Skip to content

[Relax][ONNX] Preserve NaN in Relu to align with ONNX Runtime#19750

Merged
guan404ming merged 2 commits into
apache:mainfrom
cchung100m:issue-19572-relu
Jun 19, 2026
Merged

[Relax][ONNX] Preserve NaN in Relu to align with ONNX Runtime#19750
guan404ming merged 2 commits into
apache:mainfrom
cchung100m:issue-19572-relu

Conversation

@cchung100m

Copy link
Copy Markdown
Contributor

Hi Committers,

This PR fixes issues #19572. Any suggestions would be appreciated if you are available.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request updates the ONNX frontend's Relu operator implementation (v13) to preserve NaN values for floating-point inputs, aligning it with ONNX specifications. A corresponding unit test test_relu_nan_preserve is added to verify this behavior against ONNX Runtime. The review feedback highlights a compatibility issue in the test code where isinstance(tvm_out, list | tuple) is used; since TVM supports Python versions prior to 3.10, this should be changed to isinstance(tvm_out, (list, tuple)) to avoid runtime errors.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread tests/python/relax/test_frontend_onnx.py
@cchung100m cchung100m marked this pull request as ready for review June 15, 2026 16:23

@tlopex tlopex left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM could you resolve the conflict so that we can merge it in?

@cchung100m

Copy link
Copy Markdown
Contributor Author

Hi @tlopex

Thanks for your friendly reminder, I updated the part you mentioned.

@guan404ming guan404ming merged commit fbb9102 into apache:main Jun 19, 2026
10 checks passed
@guan404ming

Copy link
Copy Markdown
Member

Thanks @cchung100m!

@cchung100m cchung100m deleted the issue-19572-relu branch June 19, 2026 17:10
@cchung100m

Copy link
Copy Markdown
Contributor Author

Thanks to @guan404ming

@tqchen

tqchen commented Jun 19, 2026

Copy link
Copy Markdown
Member

are we sure we want to have this behavior? The report is likely mostly come from fuzzer, and while Nan preserving is nice, having an explicit nan where step here will likely slow down the computation, i would rather have us to have less strict behavior while being efficient while leaving Nan handling as undefined behavior.

Instead, i think it may boils down to how low-level handles NaN, e.g. relu should translate to Max, and it is up to each backend to decide whether max(x, 0) should be nan preserving, for cases that can be made efficient we might

@tqchen

tqchen commented Jun 19, 2026

Copy link
Copy Markdown
Member

Looking more broadly, it appears that we have a series of fixes of nan behavior all via explicit nan checking where steps, i think they are not the directions we want to go toward, e.g. paying extra computing cost for a corner case that may only appear in fuzzer setting. Instead, i think it is better to think about ways to define the right most efficient way in implementing relax.relu

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants