Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions backends/arm/test/misc/test_tosa_dialect_fft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Copyright 2026 Arm Limited and/or its affiliates.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file should be placed in misc/tosa_dialect

#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import executorch.backends.arm.tosa.dialect # noqa: F401
import pytest
import sympy # type: ignore[import-untyped]
import torch
from executorch.backends.arm.tosa.dialect.lib import TosaValueError
from executorch.backends.arm.tosa.specification import (
TosaLoweringContext,
TosaSpecification,
)
from executorch.exir.dialects._ops import ops as exir_ops
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv


def _make_symint(
shape_env: ShapeEnv, symbol: str, hint: int, min: int = 1, max: int = 64
) -> torch.SymInt:
symint = shape_env.create_symintnode(sympy.Symbol(symbol), hint=hint)
assert isinstance(symint, torch.SymInt)
shape_env.constrain_symbol_range(
symint.node.expr, compiler_min=min, compiler_max=max
)
return symint


def _expr(sym: torch.SymInt) -> sympy.Expr:
return sympy.sympify(str(sym.node._expr))


def test_fft2d_tosa_fp_fft() -> None:
input_real = torch.randn((2, 8, 16), dtype=torch.float32)
input_imag = torch.randn((2, 8, 16), dtype=torch.float32)

with TosaLoweringContext(
TosaSpecification.create_from_string("TOSA-1.1+FP+fft")
), FakeTensorMode() as mode:
output_real, output_imag = exir_ops.backend.tosa.FFT2D.default(
mode.from_tensor(input_real),
mode.from_tensor(input_imag),
)

assert output_real.dtype == torch.float32
assert output_imag.dtype == torch.float32
assert tuple(output_real.shape) == (2, 8, 16)
assert tuple(output_imag.shape) == (2, 8, 16)


def test_fft2d_accepts_matching_symbolic_shape() -> None:
shape_env = ShapeEnv()
width = _make_symint(shape_env, "w", hint=16)

with TosaLoweringContext(
TosaSpecification.create_from_string("TOSA-1.1+FP+fft"),
shape_env,
), FakeTensorMode(shape_env=shape_env) as mode:
input_real = torch.empty((2, 8, width), dtype=torch.float32)
input_imag = torch.empty((2, 8, width), dtype=torch.float32)
output_real, output_imag = exir_ops.backend.tosa.FFT2D.default(
mode.from_tensor(input_real),
mode.from_tensor(input_imag),
)

assert isinstance(output_real.shape[2], torch.SymInt)
assert isinstance(output_imag.shape[2], torch.SymInt)
assert sympy.simplify(_expr(output_real.shape[2]) - sympy.Symbol("w")) == 0
assert sympy.simplify(_expr(output_imag.shape[2]) - sympy.Symbol("w")) == 0


def test_rfft2d_tosa_fp_fft() -> None:
input_real = torch.randn((2, 8, 16), dtype=torch.float32)

with TosaLoweringContext(
TosaSpecification.create_from_string("TOSA-1.1+FP+fft")
), FakeTensorMode() as mode:
output_real, output_imag = exir_ops.backend.tosa.RFFT2D.default(
mode.from_tensor(input_real),
)

assert output_real.dtype == torch.float32
assert output_imag.dtype == torch.float32
assert tuple(output_real.shape) == (2, 8, 9)
assert tuple(output_imag.shape) == (2, 8, 9)


def test_fft_requires_extension() -> None:
input_real = torch.randn((2, 8, 16), dtype=torch.float32)
input_imag = torch.randn((2, 8, 16), dtype=torch.float32)

with TosaLoweringContext(
TosaSpecification.create_from_string("TOSA-1.1+FP")
), FakeTensorMode() as mode:
with pytest.raises(TosaValueError, match="doesn't support FFT2D"):
exir_ops.backend.tosa.FFT2D.default(
mode.from_tensor(input_real),
mode.from_tensor(input_imag),
)


def test_rfft2d_preserves_symbolic_width() -> None:
shape_env = ShapeEnv()
width = _make_symint(shape_env, "w", hint=16)

with TosaLoweringContext(
TosaSpecification.create_from_string("TOSA-1.1+FP+fft"),
shape_env,
), FakeTensorMode(shape_env=shape_env) as mode:
input_real = torch.empty((2, 8, width), dtype=torch.float32)
output_real, output_imag = exir_ops.backend.tosa.RFFT2D.default(
mode.from_tensor(input_real)
)

expected = sympy.floor(sympy.Symbol("w") / 2) + sympy.Integer(1)
assert isinstance(output_real.shape[2], torch.SymInt)
assert isinstance(output_imag.shape[2], torch.SymInt)
assert sympy.simplify(_expr(output_real.shape[2]) - expected) == 0
assert sympy.simplify(_expr(output_imag.shape[2]) - expected) == 0
1 change: 1 addition & 0 deletions backends/arm/tosa/dialect/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
conv3d,
custom,
depthwise_conv2d,
fft,
gather,
identity,
matmul,
Expand Down
10 changes: 10 additions & 0 deletions backends/arm/tosa/dialect/ops/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.arm.tosa.dialect.lib import TosaValueError

_VALID_NAN_MODES = {"PROPAGATE", "IGNORE"}
Expand All @@ -14,3 +15,12 @@ def validate_nan_mode(nan_mode: str, op: str) -> None:
f"Unsupported nan_mode {nan_mode}. Expected one of {_VALID_NAN_MODES}",
op=op,
)


def validate_power_of_two(size: int | torch.SymInt, name: str, op: str) -> None:
if not isinstance(size, int):
return
if size < 1 or (size & (size - 1)) != 0:
raise TosaValueError(
f"{name} must be a positive power of two, got {size}", op=op
)
144 changes: 144 additions & 0 deletions backends/arm/tosa/dialect/ops/fft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import sympy # type: ignore[import-untyped]
import torch
from executorch.backends.arm.tosa.dialect.lib import TosaValueError
from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op
from executorch.backends.arm.tosa.specification import (
get_context_shape_env,
get_context_spec,
TosaSpecification,
)
from torch.utils._sympy.functions import FloorDiv


def _validate_fft_spec(op: str) -> None:
tosa_spec = get_context_spec()
if not (tosa_spec.support_float() and tosa_spec.support_extension("fft")):
raise TosaValueError(
f"TOSA spec {tosa_spec} doesn't support {op}",
op=op,
)


def _is_power_of_two(value: int) -> bool:
return value > 0 and (value & (value - 1)) == 0


def _validate_power_of_two(value: int | torch.SymInt, name: str, op: str) -> None:
if isinstance(value, torch.SymInt):
expr = sympy.simplify(_to_sympy_expr(value))
value_range = get_context_shape_env().bound_sympy(expr)
if value_range.is_int and value_range.is_singleton():
singleton = sympy.simplify(value_range.lower)
if singleton.is_integer and not _is_power_of_two(int(singleton)):
raise TosaValueError(
f"{op} requires {name} to be a power of two but got {singleton}",
op=op,
)
return

if not _is_power_of_two(int(value)):
raise TosaValueError(
f"{op} requires {name} to be a power of two but got {value}",
op=op,
)


def _validate_fft_input(input_real: torch.Tensor, op: str) -> None:
if input_real.dtype != torch.float32:
raise TosaValueError(f"{op} requires float32 inputs", op=op)
if input_real.dim() != 3:
raise TosaValueError(f"{op} requires a rank-3 input", op=op)

_, height, width = input_real.shape
_validate_power_of_two(height, "height", op)
_validate_power_of_two(width, "width", op)


def _to_sympy_expr(value: int | torch.SymInt) -> sympy.Expr:
if isinstance(value, torch.SymInt):
return value.node._expr
return sympy.Integer(int(value))


def _rfft_output_width(width: int | torch.SymInt) -> int | torch.SymInt:
if isinstance(width, torch.SymInt):
expr = FloorDiv(_to_sympy_expr(width), sympy.Integer(2)) + sympy.Integer(1)
return get_context_shape_env().create_symintnode(expr, hint=None)
return width // 2 + 1


def _same_fft_dimension(lhs: int | torch.SymInt, rhs: int | torch.SymInt) -> bool:
if not isinstance(lhs, torch.SymInt) and not isinstance(rhs, torch.SymInt):
return lhs == rhs

diff = sympy.simplify(_to_sympy_expr(lhs) - _to_sympy_expr(rhs))
if diff == 0:
return True

value_range = get_context_shape_env().bound_sympy(diff)
return (
value_range.is_int
and value_range.is_singleton()
and sympy.simplify(value_range.lower) == 0
)


def _same_fft_shape(
lhs: torch.Size | tuple[int | torch.SymInt, ...],
rhs: torch.Size | tuple[int | torch.SymInt, ...],
) -> bool:
return len(lhs) == len(rhs) and all(
_same_fft_dimension(lhs_dim, rhs_dim) for lhs_dim, rhs_dim in zip(lhs, rhs)
)


@register_fake_tosa_op(
"FFT2D(Tensor input_real, Tensor input_imag, *, bool inverse=False, bool local_bound=False) -> (Tensor output_real, Tensor output_imag)",
TosaSpecification.all_versions_and_profiles(),
)
def FFT2D(
input_real: torch.Tensor,
input_imag: torch.Tensor,
*,
inverse: bool = False,
local_bound: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
_validate_fft_spec("FFT2D")
_validate_fft_input(input_real, "FFT2D")
_validate_fft_input(input_imag, "FFT2D")

if not _same_fft_shape(input_real.shape, input_imag.shape):
raise TosaValueError(
f"FFT2D expects matching input shapes but got {tuple(input_real.shape)} and {tuple(input_imag.shape)}",
op="FFT2D",
)

return (
torch.empty_like(input_real, dtype=input_real.dtype),
torch.empty_like(input_imag, dtype=input_imag.dtype),
)


@register_fake_tosa_op(
"RFFT2D(Tensor input_real, *, bool local_bound=False) -> (Tensor output_real, Tensor output_imag)",
TosaSpecification.all_versions_and_profiles(),
)
def RFFT2D(
input_real: torch.Tensor,
*,
local_bound: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
_validate_fft_spec("RFFT2D")
_validate_fft_input(input_real, "RFFT2D")

batch, height, width = input_real.shape
output_shape = (batch, height, _rfft_output_width(width))
return (
torch.empty(output_shape, dtype=input_real.dtype),
torch.empty(output_shape, dtype=input_real.dtype),
)
Loading