From 2da979407cb0ee9ff81f6b48cde89098f769cc22 Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Thu, 4 Dec 2025 14:59:16 +0100 Subject: [PATCH] Use default FakeTensorMode when calling module without input Before, _ExportPassBase set self.tracer.fake_tensor_mode to a default value, but didn't use it when tracing. This caused operators that only have fake tensor implementations to crash. Signed-off-by: Erik Lundell Change-Id: I8d7ef0cc841b0e46cd04ea4ed941b761798a76d2 --- backends/arm/test/ops/test_cond.py | 2 -- exir/pass_base.py | 8 ++++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/backends/arm/test/ops/test_cond.py b/backends/arm/test/ops/test_cond.py index 77405354bd4..2cb6585ae9f 100644 --- a/backends/arm/test/ops/test_cond.py +++ b/backends/arm/test/ops/test_cond.py @@ -237,8 +237,6 @@ def test_cond_tosa_FP(case: Callable[[], tuple[torch.nn.Module, tuple]]): "case", test_cases, xfails={ - "zero_args_one_output": "Since the submodules have no input, the tracer fails finding a fake tensor mode," - " and traces the graph with real tensors, which tosa.RESCALE can't handle.", "one_arg_and_scalar_one_output": "Incorrect quantization on the scalar.", "nested_one_arg_one_output": "Node submodule_0 target submodule_0 references nonexistent attribute submodule_0", }, diff --git a/exir/pass_base.py b/exir/pass_base.py index 497970fae34..eded028f691 100644 --- a/exir/pass_base.py +++ b/exir/pass_base.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -593,14 +594,13 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: ), "Multiple fake tensor mode detected." fake_tensor_mode = i.fake_mode if fake_tensor_mode is None: - self.tracer.fake_tensor_mode = FakeTensorMode(allow_non_fake_inputs=True) - fake_tensor_mode = nullcontext() # type: ignore[assignment] + fake_tensor_mode = FakeTensorMode(allow_non_fake_inputs=True) dispatcher_mode = nullcontext() # type: ignore[assignment] else: fake_tensor_mode.allow_non_fake_inputs = True - self.tracer.fake_tensor_mode = fake_tensor_mode dispatcher_mode = enable_python_dispatcher() # type: ignore[assignment] - self.fake_tensor_mode = self.tracer.fake_tensor_mode + self.tracer.fake_tensor_mode = fake_tensor_mode + self.fake_tensor_mode = fake_tensor_mode with fake_tensor_mode, dispatcher_mode: # type: ignore[assignment, union-attr] result = self.call_submodule(graph_module, tuple(inputs))