Add argmax/argmin as symbolic reduce ops with Z3 reasoning#316
Register argmax and argmin in REDUCE_OPS, _SUPPORTED_OPS, and _Z3_BUILDERS. The Z3 implementation uses an If-chain that tracks the best value and index, avoiding the concretize() fallback which would fail for nodes that only implement _to_z3_impl (e.g. where). GPC: support-tl-max-1-argmax-argmin-z3
Add _NUMPY_REDUCE_OPS mapping and concretize() method so that reduce expressions can fall back to concrete NumPy evaluation when the Z3 path is not needed (e.g. for simple constant inputs). argmax/argmin results are cast to int32 to match the declared dtype. GPC: support-tl-max-2-concretize
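The concrete fallback described in this commit message could look roughly like the sketch below. The contents of the mapping and the helper name `concretize_reduce` are assumptions made for illustration; only the `_NUMPY_REDUCE_OPS` name and the int32 cast come from the commit message.

```python
import numpy as np

# Hypothetical contents: maps a reduce-op name to its NumPy counterpart.
_NUMPY_REDUCE_OPS = {
    "sum": np.sum,
    "max": np.max,
    "min": np.min,
    "argmax": np.argmax,
    "argmin": np.argmin,
}


def concretize_reduce(op, values, axis=None):
    """Evaluate a reduce op on concrete data with NumPy (illustrative helper)."""
    data = np.asarray(values)
    result = np.asarray(_NUMPY_REDUCE_OPS[op](data, axis=axis))
    if op in ("argmax", "argmin"):
        # NumPy returns platform-sized integers for indices;
        # cast so the result matches the declared int32 dtype.
        result = result.astype(np.int32)
    return result
```

For example, `concretize_reduce("argmax", [[1, 5], [7, 2]], axis=0)` yields an int32 array of per-column indices.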
Force-pushed from f9ee11c to 2c8c0c7.
Sanitizer Performance Benchmark
Iterations: 1 warmup + 40 measured
@codex review.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 9e7ad4954a
    for i in range(1, len(arr)):
        is_better = arr[i] > best_val
        best_idx = If(is_better, IntVal(i), best_idx)
        best_val = If(is_better, arr[i], best_val)
    return best_idx, constraints
Respect axis in argmax/argmin reduction logic
This loop always scans the flattened arr and returns a single global index, but ReduceSymbolicExpr accepts axis/keepdims and computes shaped outputs from them. For any multi-dimensional input with axis set (the normal argmax/argmin case), the result is incorrect: it returns a flat index over all elements instead of per-axis indices, so symbolic checks can validate the wrong element and wrong index range.
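To make the reviewer's point concrete, here is a small NumPy sketch (not taken from the PR) contrasting a flat argmax with a per-axis argmax on the same input. The array values are arbitrary and chosen so that the flat index falls outside the valid range of the reduced axis.

```python
import numpy as np

x = np.array([[1, 2, 3],
              [7, 9, 8]])

# Flat argmax: one global index into the flattened array.
# The maximum (9) sits at row 1, column 1, so the flat index is 1 * 3 + 1.
flat_idx = int(np.argmax(x))

# Per-axis argmax: one index per column, each in the range [0, x.shape[0]).
axis0_idx = np.argmax(x, axis=0)

# Per-axis argmax along rows: one index per row, each in [0, x.shape[1]).
axis1_idx = np.argmax(x, axis=1)
```

Here the flat index is 4, which is not a valid index along axis 0 (size 2), while the per-axis result has shape `(3,)` and only contains values below 2.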
Code review
Found 1 issue:
The docstring at triton-viz/tests/unit/test_sanitizer.py, lines 99 to 106 in 2c8c0c7.
The docstring incorrectly referenced a non-existent "old concretize() fallback". Before this branch, argmax/argmin were not registered ops at all, so no such fallback ever existed on main.
Fixed.
@pytest.mark.parametrize("op,np_op", [("argmax", np.argmax), ("argmin", np.argmin)])
def test_reduce_argmax_argmin_z3_through_where(op: str, np_op):
Do you have any end-to-end examples?
Yes, just added one — see test_tl_max_return_indices in tests/end_to_end/test_sanitizer.py. It runs tl.max(x, axis=0, return_indices=True) through the sanitizer end-to-end.
        return IntVal(0), constraints
    best_val = arr[0]
    best_idx: Z3Expr = IntVal(0)
    for i in range(1, len(arr)):
This is confusing to me. If arr is a SymbolicExpr, what's gonna happen?
arr here is the result of self.input._to_z3(), not self.input itself — so by the time we reach the loop, arr is already a Python list/tuple of Z3 expressions (e.g. [BitVecVal(1), BitVecVal(5), ...]), not a SymbolicExpr. The isinstance(arr, (list, tuple)) check on line 1192 handles the edge case where _to_z3() returns a scalar Z3 expression (single-element reduction), in which case the argmax index is trivially 0.
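The loop under discussion can be sketched in a solver-agnostic way, as below. The function name `argmax_chain` and the `select`/`const` parameters are hypothetical and exist only so the sketch runs on plain Python values; the control flow mirrors the quoted snippet, including the scalar edge case.

```python
def _py_select(cond, a, b):
    """Plain-Python stand-in for a symbolic If(cond, a, b)."""
    return a if cond else b


def _py_const(i):
    """Plain-Python stand-in for a symbolic integer constant."""
    return i


def argmax_chain(arr, select=_py_select, const=_py_const):
    """Build an argmax as a chain of conditional selects over a flat sequence."""
    # A scalar input means a single-element reduction: the index is trivially 0.
    if not isinstance(arr, (list, tuple)):
        return const(0)
    best_val = arr[0]
    best_idx = const(0)
    for i in range(1, len(arr)):
        # Strict comparison keeps the first occurrence on ties.
        is_better = arr[i] > best_val
        best_idx = select(is_better, const(i), best_idx)
        best_val = select(is_better, arr[i], best_val)
    return best_idx
```

With Z3 expressions as elements, one would presumably pass `z3.If` as `select` and `z3.IntVal` as `const`, which produces the nested If-chain instead of a concrete integer.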
In any case, this doesn't match argmin's semantics: the index returned should be the index of a specific element within the tensor, not an index into an array of tensors.
So we can just leave the Z3 path unimplemented here, since most address calculations won't involve argmin/argmax.
Fixed.
Resolve conflict in ReduceSymbolicExpr.__init__: keep argmax/argmin dtype logic from this branch and add self.shape from main.
Reproduces the crash where tl.max(x, axis=0, return_indices=True) fails under triton-sanitizer with TypeError in tensor.__getitem__.
The symbolic engine invariant requires self.dtype to always be a scalar type with shape stored separately. The argmax/argmin commit broke this by wrapping dtype in tl.block_type, causing downstream cast failures.
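The invariant stated here (a scalar dtype, with shape tracked separately) can be illustrated with a minimal sketch. The names `ReduceResultType`, `reduced_shape`, and `reduce_result_type` are hypothetical and not part of triton-viz; dtypes are represented as plain strings for simplicity.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


def reduced_shape(shape: Tuple[int, ...], axis: Optional[int],
                  keepdims: bool = False) -> Tuple[int, ...]:
    """Shape of a reduction result over `axis` (None reduces all dimensions)."""
    if axis is None:
        return tuple(1 for _ in shape) if keepdims else ()
    axis = axis % len(shape)  # normalize negative axes; assumes a non-empty shape
    middle = (1,) if keepdims else ()
    return shape[:axis] + middle + shape[axis + 1:]


@dataclass(frozen=True)
class ReduceResultType:
    dtype: str               # always a scalar element type, never a block type
    shape: Tuple[int, ...]   # stored separately from dtype


def reduce_result_type(op, input_dtype, input_shape, axis=None, keepdims=False):
    # Index-producing ops yield int32; other reductions keep the input dtype.
    dtype = "int32" if op in ("argmax", "argmin") else input_dtype
    return ReduceResultType(dtype, reduced_shape(tuple(input_shape), axis, keepdims))
```

Keeping the two fields apart means downstream casts only ever see a scalar dtype, which is the property the commit message says was broken by wrapping the dtype in a block type.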
…o support-tl-max-1-argmax-argmin-z3
Conflicts: triton_viz/core/patch.py
…back
The Z3 symbolic engine tracks addresses, not data values, so argmax/argmin cannot be meaningfully computed on the Z3 path. Remove the broken _reduce_argmax/_reduce_argmin Z3 builders and let these ops fall back to concretize(), which uses NumPy and produces correct results.
argmax/argmin results are stored as values, never used in address calculations, so neither the Z3 path nor concretize is needed.
Already covered by test_tl_max_return_indices; the masked variant was redundant (N == BLOCK made the mask trivially all-true).
Summary
- Register argmax and argmin in REDUCE_OPS, _SUPPORTED_OPS, and _Z3_BUILDERS

Test plan
- pytest tests/unit/test_sanitizer.py::test_reduce_argmax_argmin_z3_through_where -xvs

[FEAT] Add argmax/argmin as symbolic reduce ops with Z3 reasoning
Register argmax and argmin in REDUCE_OPS, _SUPPORTED_OPS, and _Z3_BUILDERS.
PR chain