Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions backends/arm/common/arm_compile_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class DebugMode(Enum):
path_for_intermediates: str | None = None
tosa_debug_mode: DebugMode | None = None
preserve_io_quantization: bool = False
tosa_dev_mode: bool | None = None

_TOSA_SPEC_KEY = "tosa_spec"
_COMPILE_FLAGS_KEY = "compile_flags"
Expand All @@ -46,6 +47,7 @@ class DebugMode(Enum):
_OUTPUT_REORDER_KEY = "ouput_reorder_workaround"
_TRANSFORM_PIPELINE_CONFIG_KEY = "transform_pipeline_config"
_PRESERVE_IO_QUANT_KEY = "preserve_io_quantization"
_TOSA_DEV_MODE = "tosa_sw_dev_mode"

def _set_compile_specs(
self,
Expand All @@ -56,6 +58,7 @@ def _set_compile_specs(
output_order_workaround: bool = False,
pipeline_config: ArmPassPipelineConfig | None = None,
preserve_io_quantization: bool = False,
tosa_dev_mode: bool | None = None,
):
"""Set all values of dataclass directly."""
self.tosa_spec = tosa_spec
Expand All @@ -66,6 +69,7 @@ def _set_compile_specs(
self.output_order_workaround = output_order_workaround
self.preserve_io_quantization = preserve_io_quantization
self._warn_if_redundant_preserve_io_quantization()
self.tosa_dev_mode = tosa_dev_mode
if output_order_workaround:
warnings.warn(
"ArmCompileSpec(output_order_workaround=True) is deprecated and will be "
Expand All @@ -84,6 +88,7 @@ def _from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
output_order_workaround: bool = False
pipeline_config: ArmPassPipelineConfig | None = None
preserve_io_quantization: bool = False
tosa_dev_mode: bool | None = None
unknown_specs: dict[str, str] = {}
for spec in compile_specs:
key = spec.key
Expand Down Expand Up @@ -136,6 +141,12 @@ def _from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
pipeline_config = ArmPassPipelineConfig.from_dict(json.loads(val))
elif key == ArmCompileSpec._PRESERVE_IO_QUANT_KEY:
preserve_io_quantization = str(val).lower() in ("1", "true", "yes")
elif key == ArmCompileSpec._TOSA_DEV_MODE:
if tosa_dev_mode is not None:
raise ValueError(
"More than one tosa_sw_dev_mode entry in compile spec."
)
tosa_dev_mode = str(val).lower() in ("1", "true", "yes")
else:
unknown_specs[key] = val

Expand All @@ -160,6 +171,7 @@ def _from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
output_order_workaround=output_order_workaround,
pipeline_config=pipeline_config,
preserve_io_quantization=preserve_io_quantization,
tosa_dev_mode=tosa_dev_mode,
)
cls._from_list_hook(compile_spec, unknown_specs)
compile_spec._validate()
Expand Down Expand Up @@ -242,6 +254,15 @@ def _to_list(self):
str(bool(self.preserve_io_quantization)).encode(),
)
)

if self.tosa_dev_mode is not None:
compile_spec.append(
CompileSpec(
ArmCompileSpec._TOSA_DEV_MODE,
str(bool(self.tosa_dev_mode)).encode(),
)
)

return compile_spec

def _set_preserve_io_quantization(self, enabled: bool) -> "ArmCompileSpec":
Expand Down Expand Up @@ -326,6 +347,16 @@ def dump_debug_info(self, debug_mode: DebugMode | None):
self.tosa_debug_mode = debug_mode
return self

def _set_tosa_dev_mode(self, tosa_dev_mode: bool):
"""Sets whether to enable TOSA software development mode.

Args:
tosa_dev_mode: Boolean indicating whether to enable TOSA software development mode.

"""
self.tosa_dev_mode = tosa_dev_mode
return self
Comment on lines +350 to +358

@deprecated(
"set_output_order_workaround() is deprecated and will be removed in v1.5; please remove this call."
)
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/requirements-arm-tosa.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ flatbuffers == 24.3.25
tosa-adapter-model-explorer == 0.1.0
ai-edge-model-explorer >= 0.1.16
pytest-timeout == 2.4.0
tosa-tools == 2026.2.1
tosa-tools == 2026.5.0
6 changes: 6 additions & 0 deletions backends/arm/test/misc/test_compile_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,12 @@ def test_preserve_io_quantization_roundtrip_vgf_FP_INT():
assert roundtripped.preserve_io_quantization is True


def test_preserve_tosa_dev_mode_roundtrip_vgf_FP_INT():
compile_spec = VgfCompileSpec()
roundtripped = VgfCompileSpec._from_list(compile_spec._to_list())
assert roundtripped.tosa_dev_mode is True


def test_preserve_io_quantization_warns_for_u55_INT():
with warns(
UserWarning,
Expand Down
4 changes: 4 additions & 0 deletions backends/arm/tosa/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,9 @@ def _preprocess( # noqa: C901
targetDraft=True if version.minor > 0 else False,
)

if compile_spec.tosa_dev_mode:
tosa_graph.setExperimentalDevVersion()

if not (
tosa_spec.version.major == ts.TOSA_VERSION_MAJOR
and tosa_spec.version.minor <= ts.TOSA_VERSION_MINOR
Expand Down Expand Up @@ -484,4 +487,5 @@ def filter_tosa_compile_specs(
)
.dump_debug_info(compile_spec.tosa_debug_mode)
.set_output_order_workaround(compile_spec.output_order_workaround)
._set_tosa_dev_mode(compile_spec.tosa_dev_mode)
)
Loading
Loading