Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions python/tvm/relax/frontend/tflite/tflite_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None):
"DENSIFY": self.convert_densify,
"DEPTH_TO_SPACE": self.convert_depth_to_space,
"DEPTHWISE_CONV_2D": functools.partial(self.convert_conv, conv_type="depthwise"),
"DELEGATE": functools.partial(self.convert_operator_marker, op_name="DELEGATE"),
"DEQUANTIZE": self.convert_dequantize,
"DETECTION_POSTPROCESS": self.convert_detection_postprocess,
"DILATE": self.convert_dilate,
Expand Down Expand Up @@ -290,6 +291,9 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None):
"PACK": self.convert_pack,
"PAD": self.convert_pad,
"PADV2": self.convert_pad,
"PLACEHOLDER_FOR_GREATER_OP_CODES": functools.partial(
self.convert_operator_marker, op_name="PLACEHOLDER_FOR_GREATER_OP_CODES"
),
"POW": functools.partial(self._convert_elemwise, relax_op=_op.power),
"PRELU": self.convert_prelu,
"RANGE": self.convert_range,
Expand Down Expand Up @@ -479,6 +483,12 @@ def check_unsupported_ops(self):
if len(raise_msg) > 0:
raise tvm.error.OpNotImplemented(raise_msg)

def convert_operator_marker(self, op, op_name):
"""Reject TFLite marker builtins with an explicit diagnostic."""
raise tvm.error.OpNotImplemented(
f"TFLite operator marker {op_name} is not a Relax tensor operator"
)

def unbind(self, data, axis=1):
"""
This is a modified version compared to the one in common.py.
Expand Down
26 changes: 26 additions & 0 deletions tests/python/relax/test_frontend_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -4066,6 +4066,32 @@ def _get_builtin_operator(builtin_name):
return getattr(_tfl_builtin_operator, builtin_name)


def _build_tflite_operator_marker_model(builtin_name):
"""Build a minimal model containing a TFLite marker builtin."""
builder = flatbuffers.Builder(1024)
builtin_op = _get_builtin_operator(builtin_name)
op_code = _build_operator_code(builder, builtin_op)
tensors = [
_build_tensor(builder, 0, [1], tensor_type=_tfl_tensor_type.FLOAT32),
_build_tensor(builder, 0, [1], tensor_type=_tfl_tensor_type.FLOAT32),
]
op = _build_operator(builder, 0, [0], [1])
subgraph = _build_subgraph(builder, tensors=tensors, operators=[op], inputs=[0], outputs=[1])
return _finish_tflite_model(
builder,
subgraph=subgraph,
operator_codes=[op_code],
buffers=[_build_buffer(builder)],
)


@pytest.mark.parametrize("builtin_name", ["DELEGATE", "PLACEHOLDER_FOR_GREATER_OP_CODES"])
def test_operator_marker_unsupported(builtin_name):
"""TFLite marker builtins report explicit unsupported diagnostics."""
with pytest.raises(tvm.error.OpNotImplemented, match=f"TFLite operator marker {builtin_name}"):
_load_model_from_buffer(_build_tflite_operator_marker_model(builtin_name))
Comment on lines +4089 to +4092

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To prevent test failures in environments with older versions of the tflite package (where DELEGATE or PLACEHOLDER_FOR_GREATER_OP_CODES might not be defined in _tfl_builtin_operator), we should check for their existence and skip the test if they are missing.

Suggested change
def test_operator_marker_unsupported(builtin_name):
"""TFLite marker builtins report explicit unsupported diagnostics."""
with pytest.raises(tvm.error.OpNotImplemented, match=f"TFLite operator marker {builtin_name}"):
_load_model_from_buffer(_build_tflite_operator_marker_model(builtin_name))
def test_operator_marker_unsupported(builtin_name):
"""TFLite marker builtins report explicit unsupported diagnostics."""
if not hasattr(_tfl_builtin_operator, builtin_name):
pytest.skip(f"{builtin_name} is not supported by the installed tflite package version.")
with pytest.raises(tvm.error.OpNotImplemented, match=f"TFLite operator marker {builtin_name}"):
_load_model_from_buffer(_build_tflite_operator_marker_model(builtin_name))



def _run_module(mod, *inputs):
tgt = tvm.target.Target("c")
ex = tvm.compile(mod, tgt)
Expand Down
Loading