Skip to content

Add argmax/argmin as symbolic reduce ops with Z3 reasoning#316

Merged
mark14wu merged 14 commits intomainfrom
support-tl-max-1-argmax-argmin-z3
Mar 14, 2026
Merged

Add argmax/argmin as symbolic reduce ops with Z3 reasoning#316
mark14wu merged 14 commits intomainfrom
support-tl-max-1-argmax-argmin-z3

Conversation

@mark14wu
Copy link
Copy Markdown
Collaborator

@mark14wu mark14wu commented Mar 6, 2026

Summary

  • Register argmax and argmin in REDUCE_OPS, _SUPPORTED_OPS, and _Z3_BUILDERS

Test plan

  • pytest tests/unit/test_sanitizer.py::test_reduce_argmax_argmin_z3_through_where -xvs

[FEAT] Add argmax/argmin as symbolic reduce ops with Z3 reasoning

Register argmax and argmin in REDUCE_OPS, _SUPPORTED_OPS, and
_Z3_BUILDERS.

PR chain

  1. 👉 Add argmax/argmin as symbolic reduce ops with Z3 reasoning #316 👈 YOU ARE HERE
  2. Add concretize() fallback for ReduceSymbolicExpr #317
  3. Support tl.max/tl.min with return_indices=True #318

mark14wu added 2 commits March 5, 2026 21:38
Register argmax and argmin in REDUCE_OPS, _SUPPORTED_OPS, and
_Z3_BUILDERS. The Z3 implementation uses an If-chain that tracks
the best value and index, avoiding the concretize() fallback which
would fail for nodes that only implement _to_z3_impl (e.g. where).

GPC: support-tl-max-1-argmax-argmin-z3
Add _NUMPY_REDUCE_OPS mapping and concretize() method so that reduce
expressions can fall back to concrete NumPy evaluation when the Z3
path is not needed (e.g. for simple constant inputs). argmax/argmin
results are cast to int32 to match the declared dtype.

GPC: support-tl-max-2-concretize
@github-actions
Copy link
Copy Markdown

github-actions Bot commented Mar 6, 2026

Sanitizer Performance Benchmark

Benchmark main (min) PR (min) Change
simple_load_store 0.005s 0.005s +0.1%
gemm 0.021s 0.021s +0.3%
gemm_oob 0.021s 0.022s +0.6%
indirect_load 0.066s 0.066s -0.6%
nested_loop 0.022s 0.022s +0.3%
block_pointer_loop_advance 0.007s 0.007s -0.8%
liger_jsd 0.138s 0.138s -0.3%
flaggems_layernorm 2.506s 2.524s +0.7%
Total 2.786s 2.803s +0.6%

Iterations: 1 warmup + 40 measured

@mark14wu
Copy link
Copy Markdown
Collaborator Author

mark14wu commented Mar 6, 2026

@codex review.

Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 9e7ad4954a

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread triton_viz/clients/symbolic_engine.py Outdated
Comment on lines +1192 to +1196
for i in range(1, len(arr)):
is_better = arr[i] > best_val
best_idx = If(is_better, IntVal(i), best_idx)
best_val = If(is_better, arr[i], best_val)
return best_idx, constraints
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Respect axis in argmax/argmin reduction logic

This loop always scans the flattened arr and returns a single global index, but ReduceSymbolicExpr accepts axis/keepdims and computes shaped outputs from them. For any multi-dimensional input with axis set (the normal argmax/argmin case), the result is incorrect: it returns a flat index over all elements instead of per-axis indices, so symbolic checks can validate the wrong element and wrong index range.

Useful? React with 👍 / 👎.

@mark14wu
Copy link
Copy Markdown
Collaborator Author

mark14wu commented Mar 6, 2026

Code review

Found 1 issue:

  1. Misleading test docstring references a non-existent "old concretize() fallback" (factually incorrect documentation)

The docstring for test_reduce_argmax_argmin_z3_through_where claims "the old concretize() fallback would raise NotImplementedError." Before this PR, argmax/argmin were not registered ops at all — calling SymbolicExpr.create("argmax", ...) would raise NotImplementedError("Unsupported reduce op: argmax") from ReduceSymbolicExpr.__init__, not from any concretize() path. The docstring appears to describe intermediate development commits that were squashed and never existed on main.

def test_reduce_argmax_argmin_z3_through_where(op: str, np_op):
"""argmax/argmin should use Z3 symbolic path, not concretize().
When the input flows through a node that only has _to_z3_impl (like
``where``), the old concretize() fallback would raise
NotImplementedError. The Z3 If-chain implementation avoids this by
staying on the symbolic path end-to-end.
"""

🤖 Generated with Claude Code

- If this code review was useful, please react with 👍. Otherwise, react with 👎.

The docstring incorrectly referenced a non-existent "old concretize()
fallback". Before this branch, argmax/argmin were not registered ops
at all, so no such fallback ever existed on main.
@mark14wu
Copy link
Copy Markdown
Collaborator Author

mark14wu commented Mar 6, 2026

Code review

Found 1 issue:

  1. Misleading test docstring references a non-existent "old concretize() fallback" (factually incorrect documentation)

The docstring for test_reduce_argmax_argmin_z3_through_where claims "the old concretize() fallback would raise NotImplementedError." Before this PR, argmax/argmin were not registered ops at all — calling SymbolicExpr.create("argmax", ...) would raise NotImplementedError("Unsupported reduce op: argmax") from ReduceSymbolicExpr.__init__, not from any concretize() path. The docstring appears to describe intermediate development commits that were squashed and never existed on main.

def test_reduce_argmax_argmin_z3_through_where(op: str, np_op):
"""argmax/argmin should use Z3 symbolic path, not concretize().
When the input flows through a node that only has _to_z3_impl (like
``where``), the old concretize() fallback would raise
NotImplementedError. The Z3 If-chain implementation avoids this by
staying on the symbolic path end-to-end.
"""

🤖 Generated with Claude Code

  • If this code review was useful, please react with 👍. Otherwise, react with 👎.

fixed.

Comment thread tests/unit/test_sanitizer.py Outdated


@pytest.mark.parametrize("op,np_op", [("argmax", np.argmax), ("argmin", np.argmin)])
def test_reduce_argmax_argmin_z3_through_where(op: str, np_op):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have any end to end examples?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, just added one — see test_tl_max_return_indices in tests/end_to_end/test_sanitizer.py. It runs tl.max(x, axis=0, return_indices=True) through the sanitizer end-to-end.

Comment thread triton_viz/clients/symbolic_engine.py Outdated
return IntVal(0), constraints
best_val = arr[0]
best_idx: Z3Expr = IntVal(0)
for i in range(1, len(arr)):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is confusing to me. If arr is a SymbolicExpr, what's gonna happen?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

arr here is the result of self.input._to_z3(), not self.input itself — so by the time we reach the loop, arr is already a Python list/tuple of Z3 expressions (e.g. [BitVecVal(1), BitVecVal(5), ...]), not a SymbolicExpr. The isinstance(arr, (list, tuple)) check on line 1192 handles the edge case where _to_z3() returns a scalar Z3 expression (single-element reduction), in which case the argmax index is trivially 0.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In anyways, it doesn't match argmin's semantic, the index variable returned is the index of the specific element in the tensor, not the index pointing to an array of tensors

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So we can just leave z3 as not implemented here, since most of address calculation won't involve argmin/max.
Fixed.

mark14wu added 6 commits March 9, 2026 15:42
Resolve conflict in ReduceSymbolicExpr.__init__: keep argmax/argmin
dtype logic from this branch and add self.shape from main.
Reproduces the crash where tl.max(x, axis=0, return_indices=True)
fails under triton-sanitizer with TypeError in tensor.__getitem__.
The symbolic engine invariant requires self.dtype to always be a scalar
type with shape stored separately. The argmax/argmin commit broke this
by wrapping dtype in tl.block_type, causing downstream cast failures.
Comment thread triton_viz/clients/symbolic_engine.py Outdated
return IntVal(0), constraints
best_val = arr[0]
best_idx: Z3Expr = IntVal(0)
for i in range(1, len(arr)):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In anyways, it doesn't match argmin's semantic, the index variable returned is the index of the specific element in the tensor, not the index pointing to an array of tensors

…o support-tl-max-1-argmax-argmin-z3

# Conflicts:
#	triton_viz/core/patch.py
…back

The Z3 symbolic engine tracks addresses, not data values, so argmax/argmin
cannot be meaningfully computed on the Z3 path. Remove the broken
_reduce_argmax/_reduce_argmin Z3 builders and let these ops fall back to
concretize() which uses numpy and produces correct results.
argmax/argmin results are stored as values, never used in address
calculations, so neither the Z3 path nor concretize is needed.
Already covered by test_tl_max_return_indices; the masked variant
was redundant (N == BLOCK made the mask trivially all-true).
@mark14wu mark14wu merged commit c8aa313 into main Mar 14, 2026
4 checks passed
@mark14wu mark14wu deleted the support-tl-max-1-argmax-argmin-z3 branch March 14, 2026 02:36
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants