[codex] align PTODSL public surface and sync validation #388
Zhendong404 wants to merge 26 commits into
| specialized_text = compiled.mlir_text() | ||
| expect_parse_roundtrip_and_verify(specialized_text, "flash attention specialized MLIR") | ||
| expect("func.func @flash_attention_kernel" in specialized_text, "direct compile should emit the flash_attention_kernel entry") | ||
| expect("!pto.tile_buf<mat, 64x128xf32" in specialized_text, "BLOCK_Q=64 specialization should change the physical Q tile shape") | ||
| expect("func.call @materialize_tile_bounds" in specialized_text, "direct compile should still route SIMT helpers through func.call") | ||
| cached = demo.flash_attention_kernel.cached_specializations() | ||
| expect(len(cached) >= 2, "wrapper compile plus explicit compile should populate at least two cached specializations") | ||
| print("ptodsl_flash_attention_demo_compile: PASS") |
Does this only test Python DSL -> MLIR? It should also test the ptoas step that lowers to a binary.
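One way such an end-to-end check could look is sketched below. It writes the emitted MLIR to a temporary file, runs an external lowering command, and fails if the command errors or produces an empty output. The helper names `lower_and_check` and `fake_lowering_cmd` are hypothetical, and the real ptoas command line is deliberately not assumed: the caller supplies the argv through `make_cmd`.

```python
import subprocess
import sys
import tempfile
from pathlib import Path


def lower_and_check(mlir_text, make_cmd):
    """Run a lowering command on MLIR text and return the output size in bytes.

    `make_cmd(src, dst)` returns the argv list for the lowering tool.
    Assumption: the tool reads `src` and writes a non-empty file to `dst`.
    """
    with tempfile.TemporaryDirectory() as tmp:
        src = Path(tmp) / "kernel.mlir"
        dst = Path(tmp) / "kernel.bin"
        src.write_text(mlir_text)
        result = subprocess.run(make_cmd(src, dst), capture_output=True, text=True)
        if result.returncode != 0:
            raise RuntimeError(f"lowering failed: {result.stderr.strip()}")
        if not dst.exists() or dst.stat().st_size == 0:
            raise RuntimeError("lowering produced no output")
        return dst.stat().st_size


def fake_lowering_cmd(src, dst):
    # Stand-in for the real tool (hypothetical): copies input bytes to the output path.
    script = "import sys, shutil; shutil.copyfile(sys.argv[1], sys.argv[2])"
    return [sys.executable, "-c", script, str(src), str(dst)]
```

In a real test, `fake_lowering_cmd` would be replaced by a function that builds the actual ptoas invocation, and `mlir_text` would come from `compiled.mlir_text()`.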
| @pto.cube | ||
| def qk_matmul( |
Is this @pto.cube decorator necessary? Can this function just be inlined?
| @pto.cube | ||
| def pv_matmul( |
Same here. It feels cumbersome that every small utility function needs to be a separately decorated function. The actual compute is only 7 lines if inlined, but this function with its arguments is 20 lines.
| | Double-buffer handoff (compute → DMA) | `rls_buf(V, id)` + `get_buf(MTE2, id)` | | ||
| | Double-buffer handoff (DMA → compute) | `rls_buf(MTE2, id)` + `get_buf(V, id)` | | ||
| | Core A notifies core B | `set_cross_core(B, id)` + `wait_flag_dev(A, id)` | | ||
| | Core A notifies core B | `set_cross_flag(B, id)` + `wait_cross_flag(A, id)` | |
This is a leftover from the previous design; these functions accept pipes now (only Pipe.FIX).
Did CCE change this interface recently?
Are attributes like KernelRole.UKERNEL actually needed by the IR and passes? If not, we should keep only the minimum needed context managers, such as with pto.vf():, keep a single decorator, @pto.jit, and remove the redundant decorators to reduce grammar noise.
PTO IR actually needs the simd/simt/cube decorators to create the different regions/functions/sections. For ukernel, I'm considering removing it.
Force-pushed from f2824be to f8a71f9
Force-pushed from 191e7a1 to 538529a
| raise ValueError("seq must be positive") | ||
| @pto.jit( | ||
| name=name, |
Minor thing: name can be omitted; it can default to kernel.__name__ of this function object.
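The defaulting behaviour suggested here can be sketched with a plain Python decorator. The `jit` below is a hypothetical stand-in written only for illustration, not the PTODSL implementation; the `kernel_name` attribute is likewise an assumption.

```python
import functools


def jit(fn=None, *, name=None):
    """Hypothetical stand-in for a JIT decorator with an optional name."""

    def wrap(func):
        # Fall back to the function's own __name__ when no name is given.
        kernel_name = name if name is not None else func.__name__

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        wrapper.kernel_name = kernel_name
        return wrapper

    # Support both the bare form @jit and the called form @jit(name="...").
    return wrap(fn) if fn is not None else wrap


@jit
def tadd(a, b):
    return a + b


@jit(name="TADD_f32_16x64")
def tadd_named(a, b):
    return a + b
```

With this shape, an explicit `name=` stays available as an override while the common case needs no argument.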
* pip install ptoas
* use pip install in CI
* wheels pipelines use pip install
* add missing license header
* fix pip setup
| if __package__ in {None, ""}: | ||
| here = Path(__file__).resolve() | ||
| for candidate in here.parents: | ||
| if (candidate / "ptodsl" / "__init__.py").exists(): | ||
| sys.path.insert(0, str(candidate)) | ||
| break | ||
| else: | ||
| raise RuntimeError( | ||
| "Unable to locate the PTODSL Python package root from flash_attention_softmax_launch.py" | ||
| ) | ||
| from ptodsl import pto |
We can assume the user has already pip-installed the ptodsl package, so there is no need for the extra sys.path.insert here.
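If an example script still wants a friendly failure when the package is missing, a small check can replace the path search. This is a minimal sketch using only the standard library; `require_package` is a hypothetical helper name, not part of ptodsl.

```python
import importlib.util


def require_package(name):
    """Raise a clear ImportError if a top-level package is not importable."""
    # find_spec returns None when a top-level module cannot be located.
    if importlib.util.find_spec(name) is None:
        raise ImportError(
            f"{name} is not installed; install it with pip before running this example"
        )
```

An example script would call `require_package("ptodsl")` once at the top and then import normally, without touching `sys.path`.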
| def kernel( | ||
| scores: pto.tensor_spec(rank=2, dtype=pto.f32), | ||
| out: pto.tensor_spec(rank=2, dtype=pto.f32), | ||
| ): | ||
| lane_num = pto.elements_per_vreg(pto.f32) | ||
| physical_rows = ((rows + lane_num - 1) // lane_num) * lane_num | ||
| scores_tile_bytes = seq * physical_rows * pto.bytewidth(pto.f32) | ||
| runtime_seq = scores.shape[0] | ||
| runtime_rows = scores.shape[1] | ||
| total_elems = runtime_rows * runtime_seq | ||
| scores_view = pto.make_tensor_view( | ||
| scores, | ||
| shape=[1, 1, 1, runtime_seq, runtime_rows], | ||
| strides=[total_elems, total_elems, total_elems, runtime_rows, 1], | ||
| ) | ||
| out_view = pto.make_tensor_view( | ||
| out, | ||
| shape=[1, 1, 1, runtime_seq, runtime_rows], | ||
| strides=[total_elems, total_elems, total_elems, runtime_rows, 1], | ||
| ) |
In the type declaration, scores: pto.ptr(dtype=pto.f32) is more suitable than pto.tensor_spec(rank=2, dtype=pto.f32). Since scores is converted to a 5D tensor by pto.make_tensor_view anyway, the rank=2 information looks useless?
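For readers following the snippet under review, its two pieces of arithmetic can be reproduced in plain Python: rounding the row count up to a multiple of the lane count, and the row-major strides of the 5D view. The helper names are hypothetical, and the lane count of 64 used in the usage note is an illustrative value, not a claim about the hardware.

```python
def round_up(value, multiple):
    """Round value up to the next multiple, as in the physical_rows expression."""
    return ((value + multiple - 1) // multiple) * multiple


def row_major_strides(shape):
    """Strides (in elements) of a densely packed row-major tensor."""
    strides = [1] * len(shape)
    # Walk from the second-to-last dimension outwards.
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides
```

For example, `round_up(16, 64)` gives 64, and `row_major_strides([1, 1, 1, 8, 16])` gives `[128, 128, 128, 16, 1]`, which matches the `[total_elems, total_elems, total_elems, runtime_rows, 1]` pattern in the diff for `runtime_seq = 8` and `runtime_rows = 16`.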
| _DEVICE = "npu:0" | ||
| def _make_softmax_kernel(name: str, *, rows: int, seq: int): |
This uses a closure to recompile the kernel for every [rows, seq] shape. We should test a dynamic-shape kernel by passing rows: pto.i32 as a dynamic kernel argument (not as a closure/constant).
@MirkoDeVita98 can you check whether dynamic shape works? ptodsl/examples/jit/tadd_launch.py is an easier starting point.
I updated tadd_launch.py in #418 to include a dynamic-shape TADD kernel with rows: pto.i32 as a runtime kernel argument instead of capturing it as a closure/constant. The dynamic kernel reuses the same compiled kernel for different row counts (16x64 and 32x64) and passes rows at launch time. Verified with msprof and all TADD cases pass.
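The difference between the two styles can be illustrated with a toy compile cache. `MockJit` below is a hypothetical model written for this note; it does not reflect how PTODSL caches specializations, only the general idea that compile-time constants are part of the cache key while launch-time arguments are not.

```python
class MockJit:
    """Toy compile cache keyed by kernel name and compile-time constants."""

    def __init__(self):
        self.cache = {}
        self.compile_count = 0

    def compile(self, kernel_name, **constants):
        key = (kernel_name, tuple(sorted(constants.items())))
        if key not in self.cache:
            # A cache miss stands in for a real compilation.
            self.compile_count += 1
            self.cache[key] = f"binary<{kernel_name}>"
        return self.cache[key]


mock = MockJit()

# Closure style: rows is baked in as a constant, so each row count compiles.
for rows in (16, 32):
    mock.compile("tadd_static", rows=rows, cols=64)
static_compiles = mock.compile_count

# Dynamic style: rows is a launch-time argument and is not part of the key.
for rows in (16, 32):
    mock.compile("tadd_dynamic", cols=64)
dynamic_compiles = mock.compile_count - static_compiles
```

In this model the closure style compiles twice for the 16x64 and 32x64 cases, while the dynamic style compiles once and reuses the result.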
| @pto.jit( | ||
| name="TADD_f32_16x64", | ||
| kernel_kind="vector", | ||
| target="a5", | ||
| ) | ||
| def TADD_f32_16x64( | ||
| A: pto.tensor_spec(rank=2, dtype=pto.f32), | ||
| B: pto.tensor_spec(rank=2, dtype=pto.f32), | ||
| C: pto.tensor_spec(rank=2, dtype=pto.f32), | ||
| ): | ||
| _tadd_tile(A, B, C, 16, 64) |
Same issues here as in flash_attention_softmax_launch.py:
- rank=2 is useless and redundant information
- only closure-based static shape; dynamic dim is not tested (cc @MirkoDeVita98)
- name can be omitted
- sys.path.insert not needed, assuming pip-installed ptodsl
Will be merged into feature-vpto-backend directly.
Summary
Validation