From 50fdcad4ac0cffcfede5c50cba9b5b851f67cdb4 Mon Sep 17 00:00:00 2001 From: Saoirse Stewart Date: Tue, 5 May 2026 13:26:40 +0100 Subject: [PATCH] Arm backend: Add TOSA dialect FFT ops Signed-off-by: Saoirse Stewart --- .../arm/test/misc/test_tosa_dialect_fft.py | 121 +++++++++++++++ backends/arm/tosa/dialect/__init__.py | 1 + backends/arm/tosa/dialect/ops/_common.py | 10 ++ backends/arm/tosa/dialect/ops/fft.py | 144 ++++++++++++++++++ 4 files changed, 276 insertions(+) create mode 100644 backends/arm/test/misc/test_tosa_dialect_fft.py create mode 100644 backends/arm/tosa/dialect/ops/fft.py diff --git a/backends/arm/test/misc/test_tosa_dialect_fft.py b/backends/arm/test/misc/test_tosa_dialect_fft.py new file mode 100644 index 00000000000..3922a1a88ea --- /dev/null +++ b/backends/arm/test/misc/test_tosa_dialect_fft.py @@ -0,0 +1,121 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import executorch.backends.arm.tosa.dialect # noqa: F401 +import pytest +import sympy # type: ignore[import-untyped] +import torch +from executorch.backends.arm.tosa.dialect.lib import TosaValueError +from executorch.backends.arm.tosa.specification import ( + TosaLoweringContext, + TosaSpecification, +) +from executorch.exir.dialects._ops import ops as exir_ops +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.symbolic_shapes import ShapeEnv + + +def _make_symint( + shape_env: ShapeEnv, symbol: str, hint: int, min: int = 1, max: int = 64 +) -> torch.SymInt: + symint = shape_env.create_symintnode(sympy.Symbol(symbol), hint=hint) + assert isinstance(symint, torch.SymInt) + shape_env.constrain_symbol_range( + symint.node.expr, compiler_min=min, compiler_max=max + ) + return symint + + +def _expr(sym: torch.SymInt) -> sympy.Expr: + return sympy.sympify(str(sym.node._expr)) + + +def test_fft2d_tosa_fp_fft() -> None: + input_real = torch.randn((2, 8, 16), dtype=torch.float32) + input_imag = torch.randn((2, 8, 16), dtype=torch.float32) + + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.1+FP+fft") + ), FakeTensorMode() as mode: + output_real, output_imag = exir_ops.backend.tosa.FFT2D.default( + mode.from_tensor(input_real), + mode.from_tensor(input_imag), + ) + + assert output_real.dtype == torch.float32 + assert output_imag.dtype == torch.float32 + assert tuple(output_real.shape) == (2, 8, 16) + assert tuple(output_imag.shape) == (2, 8, 16) + + +def test_fft2d_accepts_matching_symbolic_shape() -> None: + shape_env = ShapeEnv() + width = _make_symint(shape_env, "w", hint=16) + + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.1+FP+fft"), + shape_env, + ), FakeTensorMode(shape_env=shape_env) as mode: + input_real = torch.empty((2, 8, width), dtype=torch.float32) + input_imag = torch.empty((2, 8, width), dtype=torch.float32) + output_real, output_imag = exir_ops.backend.tosa.FFT2D.default( + mode.from_tensor(input_real), + mode.from_tensor(input_imag), + ) + + assert isinstance(output_real.shape[2], torch.SymInt) + assert isinstance(output_imag.shape[2], torch.SymInt) + assert sympy.simplify(_expr(output_real.shape[2]) - sympy.Symbol("w")) == 0 + assert sympy.simplify(_expr(output_imag.shape[2]) - sympy.Symbol("w")) == 0 + + +def test_rfft2d_tosa_fp_fft() -> None: + input_real = torch.randn((2, 8, 16), dtype=torch.float32) + + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.1+FP+fft") + ), FakeTensorMode() as mode: + output_real, output_imag = exir_ops.backend.tosa.RFFT2D.default( + mode.from_tensor(input_real), + ) + + assert output_real.dtype == torch.float32 + assert output_imag.dtype == torch.float32 + assert tuple(output_real.shape) == (2, 8, 9) + assert tuple(output_imag.shape) == (2, 8, 9) + + +def test_fft_requires_extension() -> None: + input_real = torch.randn((2, 8, 16), dtype=torch.float32) + input_imag = torch.randn((2, 8, 16), dtype=torch.float32) + + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.1+FP") + ), FakeTensorMode() as mode: + with pytest.raises(TosaValueError, match="doesn't support FFT2D"): + exir_ops.backend.tosa.FFT2D.default( + mode.from_tensor(input_real), + mode.from_tensor(input_imag), + ) + + +def test_rfft2d_preserves_symbolic_width() -> None: + shape_env = ShapeEnv() + width = _make_symint(shape_env, "w", hint=16) + + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.1+FP+fft"), + shape_env, + ), FakeTensorMode(shape_env=shape_env) as mode: + input_real = torch.empty((2, 8, width), dtype=torch.float32) + output_real, output_imag = exir_ops.backend.tosa.RFFT2D.default( + mode.from_tensor(input_real) + ) + + expected = sympy.floor(sympy.Symbol("w") / 2) + sympy.Integer(1) + assert isinstance(output_real.shape[2], torch.SymInt) + assert isinstance(output_imag.shape[2], torch.SymInt) + assert sympy.simplify(_expr(output_real.shape[2]) - expected) == 0 + assert sympy.simplify(_expr(output_imag.shape[2]) - expected) == 0 diff --git a/backends/arm/tosa/dialect/__init__.py b/backends/arm/tosa/dialect/__init__.py index 4678da4d118..9f16720d893 100644 --- a/backends/arm/tosa/dialect/__init__.py +++ b/backends/arm/tosa/dialect/__init__.py @@ -11,6 +11,7 @@ conv3d, custom, depthwise_conv2d, + fft, gather, identity, matmul, diff --git a/backends/arm/tosa/dialect/ops/_common.py b/backends/arm/tosa/dialect/ops/_common.py index f70b6995eeb..c05e1a9d173 100644 --- a/backends/arm/tosa/dialect/ops/_common.py +++ b/backends/arm/tosa/dialect/ops/_common.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import torch from executorch.backends.arm.tosa.dialect.lib import TosaValueError _VALID_NAN_MODES = {"PROPAGATE", "IGNORE"} @@ -14,3 +15,12 @@ def validate_nan_mode(nan_mode: str, op: str) -> None: f"Unsupported nan_mode {nan_mode}. Expected one of {_VALID_NAN_MODES}", op=op, ) + + +def validate_power_of_two(size: int | torch.SymInt, name: str, op: str) -> None: + if not isinstance(size, int): + return + if size < 1 or (size & (size - 1)) != 0: + raise TosaValueError( + f"{name} must be a positive power of two, got {size}", op=op + ) diff --git a/backends/arm/tosa/dialect/ops/fft.py b/backends/arm/tosa/dialect/ops/fft.py new file mode 100644 index 00000000000..60294e7ef4e --- /dev/null +++ b/backends/arm/tosa/dialect/ops/fft.py @@ -0,0 +1,144 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import sympy # type: ignore[import-untyped] +import torch +from executorch.backends.arm.tosa.dialect.lib import TosaValueError +from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.specification import ( + get_context_shape_env, + get_context_spec, + TosaSpecification, +) +from torch.utils._sympy.functions import FloorDiv + + +def _validate_fft_spec(op: str) -> None: + tosa_spec = get_context_spec() + if not (tosa_spec.support_float() and tosa_spec.support_extension("fft")): + raise TosaValueError( + f"TOSA spec {tosa_spec} doesn't support {op}", + op=op, + ) + + +def _is_power_of_two(value: int) -> bool: + return value > 0 and (value & (value - 1)) == 0 + + +def _validate_power_of_two(value: int | torch.SymInt, name: str, op: str) -> None: + if isinstance(value, torch.SymInt): + expr = sympy.simplify(_to_sympy_expr(value)) + value_range = get_context_shape_env().bound_sympy(expr) + if value_range.is_int and value_range.is_singleton(): + singleton = sympy.simplify(value_range.lower) + if singleton.is_integer and not _is_power_of_two(int(singleton)): + raise TosaValueError( + f"{op} requires {name} to be a power of two but got {singleton}", + op=op, + ) + return + + if not _is_power_of_two(int(value)): + raise TosaValueError( + f"{op} requires {name} to be a power of two but got {value}", + op=op, + ) + + +def _validate_fft_input(input_real: torch.Tensor, op: str) -> None: + if input_real.dtype != torch.float32: + raise TosaValueError(f"{op} requires float32 inputs", op=op) + if input_real.dim() != 3: + raise TosaValueError(f"{op} requires a rank-3 input", op=op) + + _, height, width = input_real.shape + _validate_power_of_two(height, "height", op) + _validate_power_of_two(width, "width", op) + + +def _to_sympy_expr(value: int | torch.SymInt) -> sympy.Expr: + if isinstance(value, torch.SymInt): + return value.node._expr + return sympy.Integer(int(value)) + + +def _rfft_output_width(width: int | torch.SymInt) -> int | torch.SymInt: + if isinstance(width, torch.SymInt): + expr = FloorDiv(_to_sympy_expr(width), sympy.Integer(2)) + sympy.Integer(1) + return get_context_shape_env().create_symintnode(expr, hint=None) + return width // 2 + 1 + + +def _same_fft_dimension(lhs: int | torch.SymInt, rhs: int | torch.SymInt) -> bool: + if not isinstance(lhs, torch.SymInt) and not isinstance(rhs, torch.SymInt): + return lhs == rhs + + diff = sympy.simplify(_to_sympy_expr(lhs) - _to_sympy_expr(rhs)) + if diff == 0: + return True + + value_range = get_context_shape_env().bound_sympy(diff) + return ( + value_range.is_int + and value_range.is_singleton() + and sympy.simplify(value_range.lower) == 0 + ) + + +def _same_fft_shape( + lhs: torch.Size | tuple[int | torch.SymInt, ...], + rhs: torch.Size | tuple[int | torch.SymInt, ...], +) -> bool: + return len(lhs) == len(rhs) and all( + _same_fft_dimension(lhs_dim, rhs_dim) for lhs_dim, rhs_dim in zip(lhs, rhs) + ) + + +@register_fake_tosa_op( + "FFT2D(Tensor input_real, Tensor input_imag, *, bool inverse=False, bool local_bound=False) -> (Tensor output_real, Tensor output_imag)", + TosaSpecification.all_versions_and_profiles(), +) +def FFT2D( + input_real: torch.Tensor, + input_imag: torch.Tensor, + *, + inverse: bool = False, + local_bound: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + _validate_fft_spec("FFT2D") + _validate_fft_input(input_real, "FFT2D") + _validate_fft_input(input_imag, "FFT2D") + + if not _same_fft_shape(input_real.shape, input_imag.shape): + raise TosaValueError( + f"FFT2D expects matching input shapes but got {tuple(input_real.shape)} and {tuple(input_imag.shape)}", + op="FFT2D", + ) + + return ( + torch.empty_like(input_real, dtype=input_real.dtype), + torch.empty_like(input_imag, dtype=input_imag.dtype), + ) + + +@register_fake_tosa_op( + "RFFT2D(Tensor input_real, *, bool local_bound=False) -> (Tensor output_real, Tensor output_imag)", + TosaSpecification.all_versions_and_profiles(), +) +def RFFT2D( + input_real: torch.Tensor, + *, + local_bound: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + _validate_fft_spec("RFFT2D") + _validate_fft_input(input_real, "RFFT2D") + + batch, height, width = input_real.shape + output_shape = (batch, height, _rfft_output_width(width)) + return ( + torch.empty(output_shape, dtype=input_real.dtype), + torch.empty(output_shape, dtype=input_real.dtype), + )