From 0cd74858b095bae859f14219bba15ce8086057cd Mon Sep 17 00:00:00 2001 From: Will Trojak Date: Tue, 2 Dec 2025 22:13:19 +0000 Subject: [PATCH 01/11] [wip] added ptx generator for bstream --- gimmik/__init__.py | 4 +- gimmik/kernels/ptx/base.mako | 4 + gimmik/kernels/ptx/bstream.mako | 139 ++++++++++++++++++++++++++++++++ gimmik/ptx.py | 63 +++++++++++++++ 4 files changed, 209 insertions(+), 1 deletion(-) create mode 100644 gimmik/kernels/ptx/base.mako create mode 100644 gimmik/kernels/ptx/bstream.mako create mode 100644 gimmik/ptx.py diff --git a/gimmik/__init__.py b/gimmik/__init__.py index b32ebdc..cd21134 100644 --- a/gimmik/__init__.py +++ b/gimmik/__init__.py @@ -8,6 +8,7 @@ from gimmik.hip import HIPMatMul from gimmik.metal import MetalMatMul from gimmik.opencl import OpenCLMatMul +from gimmik.ptx import PTXMatMul def generate_mm(mat, dtype, platform, alpha=1.0, beta=0.0, funcn='gimmik_mm', @@ -22,7 +23,8 @@ def generate_mm(mat, dtype, platform, alpha=1.0, beta=0.0, funcn='gimmik_mm', 'cuda': CUDAMatMul, 'ispc': ISPCMatMul, 'hip': HIPMatMul, - 'opencl': OpenCLMatMul + 'opencl': OpenCLMatMul, + 'ptx': PTXMatMul } mm = platmap[platform](alpha*mat, beta, None, n, ldb, ldc) diff --git a/gimmik/kernels/ptx/base.mako b/gimmik/kernels/ptx/base.mako new file mode 100644 index 0000000..0521b84 --- /dev/null +++ b/gimmik/kernels/ptx/base.mako @@ -0,0 +1,4 @@ +.version 8.7 +.target sm_${cc} +.address_size 64 +${next.body()} \ No newline at end of file diff --git a/gimmik/kernels/ptx/bstream.mako b/gimmik/kernels/ptx/bstream.mako new file mode 100644 index 0000000..3ba8e93 --- /dev/null +++ b/gimmik/kernels/ptx/bstream.mako @@ -0,0 +1,139 @@ +<%inherit file='base'/> + +<% +pftype = "f32" if dtype == "float" else "f64" +putype = "u32" if dtype == "float" else "u64" +pbtype = "b32" if dtype == "float" else "b64" +rtype = "f" if dtype == "float" else "fd" +dwidth = "4" if dtype == "float" else "8" +%> + +% if n is None: +.visible .entry ${kname}(.param .u32 _n, + .param .u64 _b, + .param .u32 _ldb, + .param .u64 _c, + .param .u32 _ldc) +{ + .reg .u32 n, ldb, ldc; + ld.param.u32 n, [_n]; + ld.param.u32 ldb, [_ldb]; + ld.param.u32 ldc, [_ldc]; +% else: +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +{ + .reg .u32 n; + mov.u32 n, ${n}; +%endif + .reg .u32 id; + .reg .u64 b, c; + .reg .${pftype} csub<${m}>; + .reg .${pftype} ctmp<${m}>; + .reg .pred p1; + ld.param.u64 b, [_b]; + ld.param.u64 c, [_c]; + + { + .reg .u32 grid<3>; + mov.u32 grid0, %ntid.x; + mov.u32 grid1, %ctaid.x; + mov.u32 grid2, %tid.x; + mad.lo.u32 id, grid0, grid1, grid2; + } + setp.ge.u32 p1, id, n; + @p1 bra $L_EXIT; + cvta.to.global.u64 b, b; + cvta.to.global.u64 c, c; + + { + .reg .${pftype} bv; + .reg .u32 boff<${len(bix)}>, coff; + .reg .u64 bptr<${len(bix)}>, cptr; +%for kx in bix: +% if n is None: + mul.lo.u32 boff${kx}, ldb, ${kx}; + ${address((f"bptr{kx}", "u64"), ("b", "u64"), dwidth, (f"boff{kx}", "u32"), ("id", "u32"))} +% else: + ${address((f"bptr{kx}", "u64"), ("b", "u64"), dwidth, (f"{ldb*kx}", "u32"), ("id", "u32"))} +% endif + ld.weak.global.cg.${pftype} bv, [bptr${kx}]; + +% for j, jx in enumerate(A[:, kx]): +% if jx != 0 and kx == afix[j]: + mul.${pftype} csub${j}, bv, ${jx}; +% elif jx != 0: + fma.rn.${pftype} csub${j}, bv, ${jx}, csub${j}; +% endif + +% if kx == alix[j] and beta == 0: +% if n is None: + mul.lo.u32 coff, ldc, ${j}; + ${address(("cptr", "u64"), ("c", "u64"), dwidth, ("coff", "u32"), ("id", "u32"))} +% else: + ${address(("cptr", "u64"), ("c", "u64"), dwidth, (f"{ldc*j}", "u32"), ("id", "u32"))} +% endif: + st.weak.global.cg.${pftype} [cptr], csub${j}; + +% elif kx == alix[j] and beta == 1: +% if n is None: + mul.lo.u32 coff, ldc, ${j}; + ${address(("cptr", "u64"), ("c", "u64"), dwidth, ("coff", "u32"), ("id", "u32"))} +% else: + ${address(("cptr", "u64"), ("c", "u64"), dwidth, (f"{ldc*j}", "u32"), ("id", "u32"))} +% endif: + ld.weak.global.${pftype} ctmp${j}, [cptr]; + add.${pftype} ctmp${j}, ctmp${j}, csub${j}; + st.weak.global.cg.${pftype} [cptr], ctmp${j}; + +% elif kx == alix[j]: +% if n is None: + mul.lo.u32 coff, ldc, ${j}; + ${address(("cptr", "u64"), ("c", "u64"), dwidth, ("coff", "u32"), ("id", "u32"))} +% else: + ${address(("cptr", "u64"), ("c", "u64"), dwidth, (f"{ldc*j}", "u32"), ("id", "u32"))} +% endif: + ld.weak.global.${pftype} ctmp${j}, [cptr]; + fma.rn.${pftype} ctmp${j}, ctmp${j}, ${beta}, csub${j}; + st.weak.global.cg.${pftype} [cptr], csub${j}; +% endif +% endfor +%endfor + } + + { + .reg .u32 coff; + .reg .u64 cptr; + .reg .${pftype} fz; + .reg .${putype} uz; + .reg .${pftype} cin, cout; + mov.${putype} uz, 0; + mov.${pbtype} fz, uz; + +%for j, jx in enumerate(afix): +% if jx == -1 and beta == 0: +% if n is None: + mul.lo.u32 coff, ldc, ${j}; + ${address(("cptr", "u64"), ("c", "u64"), dwidth, ("coff", "u32"), ("id", "u32"))} +% else: + ${address(("cptr", "u64"), ("c", "u64"), dwidth, (f"{ldc*j}", "u32"), ("id", "u32"))} +% endif: + st.weak.global.cg.${pftype} [cptr], fz; + +% elif jx == -1 and beta != 1: +% if n is None: + mul.lo.u32 coff, ldc, ${j}; + ${address(("cptr", "u64"), ("c", "u64"), dwidth, ("coff", "u32"), ("id", "u32"))} +% else: + ${address(("cptr", "u64"), ("c", "u64"), dwidth, (f"{ldc*j}", "u32"), ("id", "u32"))} +% endif: + ld.weak.global.cg.${pftype} cin, [cptr]; + mul.${pftype} cout, cin, ${beta}; + st.weak.globla.cg.${pftype} [cptr], cout; +% endif +%endfor + } + +$L_EXIT: + ret; +} \ No newline at end of file diff --git a/gimmik/ptx.py b/gimmik/ptx.py new file mode 100644 index 0000000..bccdb61 --- /dev/null +++ b/gimmik/ptx.py @@ -0,0 +1,63 @@ +# -*- coding: utf-8 -*- + +from gimmik.base import MatMul + +class PTXSource: + def __init__(self): + self._src = "" + + def __iadd__(self, other): + self._src = f"{self}\n\t{other}" + return self + + def __str__(self): + return self._src + + def __repr__(self): + return self._src + + +class PTXMatMul(MatMul): + platform = 'ptx' + basemeta = {'block': (128, 1, 1), 'width': 1, 'shared': 0, + 'dynamic_shared': 0} + + def _address(self, out, base, size, *offs): + src = PTXSource() + out_type = out[1] + if out_type != base[1]: + raise RuntimeError("out and base must have the same type") + + if offs: + off_type = offs[0][1] + if not all(off[1] == off_type for off in offs): + raise RuntimeError("offsets must all have the same tpye") + + if len(offs) == 1: + off = offs[0] + mad_type = "lo" if out_type == off_type else "wide" + src += f"mad.{mad_type}.{off_type} {out[0]}, {size}, {off[0]}, {base[0]};" + else: + src += f".reg .{off_type} _addrs_acum;" + src += f"add.{off_type} _addrs_acum, {offs[0][0]}, {offs[1][0]};" + for off in offs[2:]: + src += f"add.{off_type} _addrs_acum, _addrs_acum, {off[0]};" + mad_type = "lo" if out_type == off_type else "wide" + src += f"mad.{mad_type}.{off_type} {out[0]}, {size}, _addrs_acum, {base[0]};" + else: + src += f"mov.{out_type} {out[0]}, {base[0]};" + return f"{{{src}\n\t}}" + + + def _kernel_generators(self, dtype, dsize, *, compute_capability=None): + base_args = {'address': lambda o, b, s, *off: self._address(o, b, s, + *off), 'cc': compute_capability} + + # B streaming, C accumulation kernel + args = base_args | {} + yield ('bstream', args, {}) + + def _process_meta(self, meta): + if self.n is not None: + div = meta['block'][0]*meta['width'] + meta['grid'] = (-(-self.n // div), 1, 1) From 626c2f5b4b8d41be691fd45044c81600345a3ece Mon Sep 17 00:00:00 2001 From: WillTrojak Date: Fri, 24 Apr 2026 12:40:34 -0700 Subject: [PATCH 02/11] Addtional sparse and dense work --- gimmik/base.py | 3 +- gimmik/kernels/ptx/base.mako | 2 +- gimmik/kernels/ptx/bstream-msplit.mako | 281 ++++++++++++++++++++++ gimmik/kernels/ptx/bstream.mako | 223 +++++++++-------- gimmik/kernels/ptx/cstream-ksplit.mako | 179 ++++++++++++++ gimmik/kernels/ptx/cstream-w2.mako | 98 ++++++++ gimmik/kernels/ptx/cstream.mako | 157 ++++++++++++ gimmik/kernels/ptx/dense-mma-gAd.mako | 210 ++++++++++++++++ gimmik/kernels/ptx/dense-mma-smem-gA.mako | 264 ++++++++++++++++++++ gimmik/ptx.py | 94 +++++++- 10 files changed, 1410 insertions(+), 101 deletions(-) create mode 100644 gimmik/kernels/ptx/bstream-msplit.mako create mode 100644 gimmik/kernels/ptx/cstream-ksplit.mako create mode 100644 gimmik/kernels/ptx/cstream-w2.mako create mode 100644 gimmik/kernels/ptx/cstream.mako create mode 100644 gimmik/kernels/ptx/dense-mma-gAd.mako create mode 100644 gimmik/kernels/ptx/dense-mma-smem-gA.mako diff --git a/gimmik/base.py b/gimmik/base.py index f547afc..0ecc29a 100644 --- a/gimmik/base.py +++ b/gimmik/base.py @@ -144,7 +144,8 @@ def _render_kernel(self, dtype, tplname, tplargs): src = tpl.render(**tplargs) # At single precision suffix all floating point constants by 'f' - if dtype == 'float': + # (PTX doesn't use an 'f' suffix for FP literals) + if dtype == 'float' and self.platform != 'ptx': src = re.sub(r'(?=\d*[.eE])(?=\.?\d)\d*\.?\d*(?:[eE][+-]?\d+)?', r'\g<0>f', src) diff --git a/gimmik/kernels/ptx/base.mako b/gimmik/kernels/ptx/base.mako index 0521b84..e380f1b 100644 --- a/gimmik/kernels/ptx/base.mako +++ b/gimmik/kernels/ptx/base.mako @@ -1,4 +1,4 @@ .version 8.7 -.target sm_${cc} +.target sm_${cc[0]}${cc[1]}${"a" if cc[0] >= 9 else ""} .address_size 64 ${next.body()} \ No newline at end of file diff --git a/gimmik/kernels/ptx/bstream-msplit.mako b/gimmik/kernels/ptx/bstream-msplit.mako new file mode 100644 index 0000000..77e7ce7 --- /dev/null +++ b/gimmik/kernels/ptx/bstream-msplit.mako @@ -0,0 +1,281 @@ +<%inherit file='base'/> + +<% +pftype = "f32" if dtype == "float" else "f64" +dwidth_i = 4 if dtype == "float" else 8 +fzero = "0f00000000" if dtype == "float" else "0d0000000000000000" +has_zero_rows = any(jx == -1 for jx in afix) +mx = partition(A, into=msplit, by='rows') +bix_list = list(bix) +bchunks = chunk(bix_list, bsz) +nchunks = len(bchunks) +m_per_group = max(len(mcx) for mcx in mx) +bsub_bytes = 2 * bsz * blockx * dwidth_i +def bsub_off(buf, idx): + return (buf * bsz + idx) * blockx * dwidth_i +use_cpasync = cc is not None and (cc[0], cc[1]) >= (8, 0) and dwidth_i in (4, 8) +%> + +% if n is None: +.visible .entry ${kname}(.param .u32 _n, + .param .u64 _b, + .param .u32 _ldb, + .param .u64 _c, + .param .u32 _ldc) +{ + .reg .u32 ldb, ldc; + ld.param.u32 ldb, [_ldb]; + ld.param.u32 ldc, [_ldc]; +% else: +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +{ +% endif + .reg .u32 n, id, tid_x, tid_y; + .reg .u64 b, c, b_base, c_base, bsub_thread; +% if use_cpasync: + .reg .u32 bsub_sm_thread; +% endif + .reg .${pftype} bv, csub<${m_per_group}>; + .reg .pred p1, p_skip; + .shared .align 8 .b8 _bsub[${bsub_bytes}]; + +% if n is None: + ld.param.u32 n, [_n]; +% else: + mov.u32 n, ${n}; +% endif + ld.param.u64 b, [_b]; + ld.param.u64 c, [_c]; + + { + .reg .u32 _ctaid_x; + mov.u32 _ctaid_x, %ctaid.x; + mov.u32 tid_x, %tid.x; + mov.u32 tid_y, %tid.y; + mad.lo.u32 id, _ctaid_x, ${blockx}, tid_x; + } + + setp.ge.u32 p1, id, n; + @p1 bra $L_EXIT; + + cvta.to.global.u64 b, b; + cvta.to.global.u64 c, c; + + { + .reg .u64 _id64; + cvt.u64.u32 _id64, id; + mad.lo.u64 b_base, _id64, ${dwidth_i}, b; + mad.lo.u64 c_base, _id64, ${dwidth_i}, c; + } + + { + .reg .u64 _tx_off; + mul.wide.u32 _tx_off, tid_x, ${dwidth_i}; + mov.u64 bsub_thread, _bsub; + add.u64 bsub_thread, bsub_thread, _tx_off; + } +% if use_cpasync: + { + .reg .u64 _sm64; + cvta.to.shared.u64 _sm64, bsub_thread; + cvt.u32.u64 bsub_sm_thread, _sm64; + } +% endif + +% for cid, mcx in enumerate(mx): +## cid = ${cid}, rows ${mcx} + setp.ne.u32 p_skip, tid_y, ${cid}; + @p_skip bra $L_END_CID_${cid}; + +% if use_cpasync: +## Async fill of chunk 0 +% for idx, kx in enumerate(bchunks[0]): +% if idx % msplit == cid: +% if n is None: + { + .reg .u32 _boff; + .reg .u64 _bptr; + mul.lo.u32 _boff, ldb, ${kx}; + mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + cp.async.ca.shared::cta.global [bsub_sm_thread + ${bsub_off(0, idx)}], [_bptr], ${dwidth_i}; + } +% else: + cp.async.ca.shared::cta.global [bsub_sm_thread + ${bsub_off(0, idx)}], [b_base + ${ldb*kx*dwidth_i}], ${dwidth_i}; +% endif +% endif +% endfor + cp.async.commit_group; + cp.async.wait_all; + bar.sync 0; +% else: +## Sync fill of chunk 0 +% for idx, kx in enumerate(bchunks[0]): +% if idx % msplit == cid: +% if n is None: + { + .reg .u32 _boff; + .reg .u64 _bptr; + .reg .${pftype} _bv; + mul.lo.u32 _boff, ldb, ${kx}; + mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + ld.global.cg.${pftype} _bv, [_bptr]; + st.shared.${pftype} [bsub_thread + ${bsub_off(0, idx)}], _bv; + } +% else: + { + .reg .${pftype} _bv; + ld.global.cg.${pftype} _bv, [b_base + ${ldb*kx*dwidth_i}]; + st.shared.${pftype} [bsub_thread + ${bsub_off(0, idx)}], _bv; + } +% endif +% endif +% endfor + bar.sync 0; +% endif + +## Main loop over B-chunks (double-buffered) +% for bb in range(nchunks): +<% + buf_cur = bb % 2 + buf_next = (bb + 1) % 2 + is_last = (bb == nchunks - 1) +%> +% if not is_last: +% for idx, kx in enumerate(bchunks[bb + 1]): +% if idx % msplit == cid: +% if use_cpasync: +% if n is None: + { + .reg .u32 _boff; + .reg .u64 _bptr; + mul.lo.u32 _boff, ldb, ${kx}; + mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + cp.async.ca.shared::cta.global [bsub_sm_thread + ${bsub_off(buf_next, idx)}], [_bptr], ${dwidth_i}; + } +% else: + cp.async.ca.shared::cta.global [bsub_sm_thread + ${bsub_off(buf_next, idx)}], [b_base + ${ldb*kx*dwidth_i}], ${dwidth_i}; +% endif +% else: +% if n is None: + { + .reg .u32 _boff; + .reg .u64 _bptr; + .reg .${pftype} _bv; + mul.lo.u32 _boff, ldb, ${kx}; + mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + ld.global.cg.${pftype} _bv, [_bptr]; + st.shared.${pftype} [bsub_thread + ${bsub_off(buf_next, idx)}], _bv; + } +% else: + { + .reg .${pftype} _bv; + ld.global.cg.${pftype} _bv, [b_base + ${ldb*kx*dwidth_i}]; + st.shared.${pftype} [bsub_thread + ${bsub_off(buf_next, idx)}], _bv; + } +% endif +% endif +% endif +% endfor +% if use_cpasync: + cp.async.commit_group; +% endif +% endif + +% for idx, kx in enumerate(bchunks[bb]): + ld.shared.${pftype} bv, [bsub_thread + ${bsub_off(buf_cur, idx)}]; +% for j, row_j in enumerate(mcx): +<% jx = A[row_j, kx] %> +% if jx != 0 and kx == afix[row_j]: + mul.${pftype} csub${j}, bv, ${jx}; +% elif jx != 0: + fma.rn.${pftype} csub${j}, bv, ${jx}, csub${j}; +% endif +% if kx == alix[row_j]: +% if beta == 0: +% if n is None: + { + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${row_j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], csub${j}; + } +% else: + st.weak.global.cg.${pftype} [c_base + ${ldc*row_j*dwidth_i}], csub${j}; +% endif +% else: + { + .reg .${pftype} _ctmp; +% if n is None: + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${row_j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.global.${pftype} _ctmp, [_cptr]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, csub${j}; + st.global.${pftype} [_cptr], _ctmp; +% else: + ld.global.${pftype} _ctmp, [c_base + ${ldc*row_j*dwidth_i}]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, csub${j}; + st.global.${pftype} [c_base + ${ldc*row_j*dwidth_i}], _ctmp; +% endif + } +% endif +% endif +% endfor +% endfor +% if use_cpasync: +% if not is_last: + cp.async.wait_all; +% endif +% endif + bar.sync 0; +% endfor + +## Handle zero rows in this cid's group +% if has_zero_rows: +% for row_j in mcx: +% if afix[row_j] == -1: +% if beta == 0: + { + .reg .${pftype} _tmp; + mov.${pftype} _tmp, ${fzero}; +% if n is None: + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${row_j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], _tmp; +% else: + st.weak.global.cg.${pftype} [c_base + ${ldc*row_j*dwidth_i}], _tmp; +% endif + } +% elif beta != 1: + { + .reg .${pftype} _tmp; +% if n is None: + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${row_j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.global.${pftype} _tmp, [_cptr]; + mul.${pftype} _tmp, _tmp, ${float(beta)}; + st.global.${pftype} [_cptr], _tmp; +% else: + ld.global.${pftype} _tmp, [c_base + ${ldc*row_j*dwidth_i}]; + mul.${pftype} _tmp, _tmp, ${float(beta)}; + st.global.${pftype} [c_base + ${ldc*row_j*dwidth_i}], _tmp; +% endif + } +% endif +% endif +% endfor +% endif + +$L_END_CID_${cid}: +% endfor + +$L_EXIT: + ret; +} diff --git a/gimmik/kernels/ptx/bstream.mako b/gimmik/kernels/ptx/bstream.mako index 3ba8e93..465eac3 100644 --- a/gimmik/kernels/ptx/bstream.mako +++ b/gimmik/kernels/ptx/bstream.mako @@ -2,10 +2,13 @@ <% pftype = "f32" if dtype == "float" else "f64" -putype = "u32" if dtype == "float" else "u64" -pbtype = "b32" if dtype == "float" else "b64" -rtype = "f" if dtype == "float" else "fd" -dwidth = "4" if dtype == "float" else "8" +dwidth_i = 4 if dtype == "float" else 8 +fzero = "0f00000000" if dtype == "float" else "0d0000000000000000" +has_zero_rows = any(jx == -1 for jx in afix) +bix_list = list(bix) +bix_idx = {kx: i for i, kx in enumerate(bix_list)} +preload_c = beta != 0 +need_scale = beta != 0 and beta != 1 %> % if n is None: @@ -15,125 +18,157 @@ dwidth = "4" if dtype == "float" else "8" .param .u64 _c, .param .u32 _ldc) { - .reg .u32 n, ldb, ldc; - ld.param.u32 n, [_n]; + .reg .u32 ldb, ldc; ld.param.u32 ldb, [_ldb]; ld.param.u32 ldc, [_ldc]; % else: .visible .entry ${kname}(.param .u64 _b, .param .u64 _c) { - .reg .u32 n; - mov.u32 n, ${n}; -%endif - .reg .u32 id; - .reg .u64 b, c; - .reg .${pftype} csub<${m}>; - .reg .${pftype} ctmp<${m}>; +% endif + .reg .u32 n, id; + .reg .u64 b, c, b_base, c_base; + .reg .${pftype} csub<${m}>, bv<${len(bix_list)}>; .reg .pred p1; + +% if n is None: + ld.param.u32 n, [_n]; +% else: + mov.u32 n, ${n}; +% endif ld.param.u64 b, [_b]; ld.param.u64 c, [_c]; { - .reg .u32 grid<3>; - mov.u32 grid0, %ntid.x; - mov.u32 grid1, %ctaid.x; - mov.u32 grid2, %tid.x; - mad.lo.u32 id, grid0, grid1, grid2; + .reg .u32 _grd<3>; + mov.u32 _grd0, %ntid.x; + mov.u32 _grd1, %ctaid.x; + mov.u32 _grd2, %tid.x; + mad.lo.u32 id, _grd0, _grd1, _grd2; } - setp.ge.u32 p1, id, n; + + setp.ge.u32 p1, id, n; @p1 bra $L_EXIT; + cvta.to.global.u64 b, b; cvta.to.global.u64 c, c; { - .reg .${pftype} bv; - .reg .u32 boff<${len(bix)}>, coff; - .reg .u64 bptr<${len(bix)}>, cptr; -%for kx in bix: -% if n is None: - mul.lo.u32 boff${kx}, ldb, ${kx}; - ${address((f"bptr{kx}", "u64"), ("b", "u64"), dwidth, (f"boff{kx}", "u32"), ("id", "u32"))} -% else: - ${address((f"bptr{kx}", "u64"), ("b", "u64"), dwidth, (f"{ldb*kx}", "u32"), ("id", "u32"))} -% endif - ld.weak.global.cg.${pftype} bv, [bptr${kx}]; + .reg .u64 _id64; + cvt.u64.u32 _id64, id; + mad.lo.u64 b_base, _id64, ${dwidth_i}, b; + mad.lo.u64 c_base, _id64, ${dwidth_i}, c; + } -% for j, jx in enumerate(A[:, kx]): -% if jx != 0 and kx == afix[j]: - mul.${pftype} csub${j}, bv, ${jx}; -% elif jx != 0: - fma.rn.${pftype} csub${j}, bv, ${jx}, csub${j}; -% endif +## Batch-load active B columns +%for i, kx in enumerate(bix_list): +% if n is None: + { + .reg .u32 _boff; + .reg .u64 _bptr; + mul.lo.u32 _boff, ldb, ${kx}; + mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + ld.weak.global.cg.${pftype} bv${i}, [_bptr]; + } +% else: + ld.weak.global.cg.${pftype} bv${i}, [b_base + ${ldb*kx*dwidth_i}]; +% endif +%endfor -% if kx == alix[j] and beta == 0: -% if n is None: - mul.lo.u32 coff, ldc, ${j}; - ${address(("cptr", "u64"), ("c", "u64"), dwidth, ("coff", "u32"), ("id", "u32"))} -% else: - ${address(("cptr", "u64"), ("c", "u64"), dwidth, (f"{ldc*j}", "u32"), ("id", "u32"))} -% endif: - st.weak.global.cg.${pftype} [cptr], csub${j}; +% if preload_c: +## Pre-load C so per-row completion is a plain store +% for j in range(m): +% if afix[j] != -1: +% if n is None: + { + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.weak.global.cg.${pftype} csub${j}, [_cptr]; + } +% else: + ld.weak.global.cg.${pftype} csub${j}, [c_base + ${ldc*j*dwidth_i}]; +% endif +% endif +% endfor +% if need_scale: +% for j in range(m): +% if afix[j] != -1: + mul.${pftype} csub${j}, csub${j}, ${float(beta)}; +% endif +% endfor +% endif +% endif -% elif kx == alix[j] and beta == 1: -% if n is None: - mul.lo.u32 coff, ldc, ${j}; - ${address(("cptr", "u64"), ("c", "u64"), dwidth, ("coff", "u32"), ("id", "u32"))} +## Main compute +%for kx in bix_list: +% for j, jx in enumerate(A[:, kx]): +% if jx != 0: +% if preload_c: + fma.rn.${pftype} csub${j}, bv${bix_idx[kx]}, ${jx}, csub${j}; +% elif kx == afix[j]: + mul.${pftype} csub${j}, bv${bix_idx[kx]}, ${jx}; % else: - ${address(("cptr", "u64"), ("c", "u64"), dwidth, (f"{ldc*j}", "u32"), ("id", "u32"))} -% endif: - ld.weak.global.${pftype} ctmp${j}, [cptr]; - add.${pftype} ctmp${j}, ctmp${j}, csub${j}; - st.weak.global.cg.${pftype} [cptr], ctmp${j}; - -% elif kx == alix[j]: + fma.rn.${pftype} csub${j}, bv${bix_idx[kx]}, ${jx}, csub${j}; +% endif +% endif +% if kx == alix[j]: % if n is None: - mul.lo.u32 coff, ldc, ${j}; - ${address(("cptr", "u64"), ("c", "u64"), dwidth, ("coff", "u32"), ("id", "u32"))} + { + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], csub${j}; + } % else: - ${address(("cptr", "u64"), ("c", "u64"), dwidth, (f"{ldc*j}", "u32"), ("id", "u32"))} -% endif: - ld.weak.global.${pftype} ctmp${j}, [cptr]; - fma.rn.${pftype} ctmp${j}, ctmp${j}, ${beta}, csub${j}; - st.weak.global.cg.${pftype} [cptr], csub${j}; + st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], csub${j}; +% endif + % endif % endfor %endfor - } - { - .reg .u32 coff; - .reg .u64 cptr; - .reg .${pftype} fz; - .reg .${putype} uz; - .reg .${pftype} cin, cout; - mov.${putype} uz, 0; - mov.${pbtype} fz, uz; - -%for j, jx in enumerate(afix): -% if jx == -1 and beta == 0: -% if n is None: - mul.lo.u32 coff, ldc, ${j}; - ${address(("cptr", "u64"), ("c", "u64"), dwidth, ("coff", "u32"), ("id", "u32"))} -% else: - ${address(("cptr", "u64"), ("c", "u64"), dwidth, (f"{ldc*j}", "u32"), ("id", "u32"))} -% endif: - st.weak.global.cg.${pftype} [cptr], fz; +% if has_zero_rows: + { + .reg .${pftype} _tmp; + mov.${pftype} _tmp, ${fzero}; +% for j, jx in enumerate(afix): +% if jx == -1 and beta == 0: +% if n is None: + { + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], _tmp; + } +% else: + st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], _tmp; +% endif -% elif jx == -1 and beta != 1: -% if n is None: - mul.lo.u32 coff, ldc, ${j}; - ${address(("cptr", "u64"), ("c", "u64"), dwidth, ("coff", "u32"), ("id", "u32"))} -% else: - ${address(("cptr", "u64"), ("c", "u64"), dwidth, (f"{ldc*j}", "u32"), ("id", "u32"))} -% endif: - ld.weak.global.cg.${pftype} cin, [cptr]; - mul.${pftype} cout, cin, ${beta}; - st.weak.globla.cg.${pftype} [cptr], cout; +% elif jx == -1: +% if n is None: + { + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.global.${pftype} _tmp, [_cptr]; + mul.${pftype} _tmp, _tmp, ${float(beta)}; + st.global.${pftype} [_cptr], _tmp; + } +% else: + ld.global.${pftype} _tmp, [c_base + ${ldc*j*dwidth_i}]; + mul.${pftype} _tmp, _tmp, ${float(beta)}; + st.global.${pftype} [c_base + ${ldc*j*dwidth_i}], _tmp; +% endif % endif -%endfor - } +% endfor + } +% endif $L_EXIT: ret; -} \ No newline at end of file +} diff --git a/gimmik/kernels/ptx/cstream-ksplit.mako b/gimmik/kernels/ptx/cstream-ksplit.mako new file mode 100644 index 0000000..06e8a77 --- /dev/null +++ b/gimmik/kernels/ptx/cstream-ksplit.mako @@ -0,0 +1,179 @@ +<%inherit file='base'/> + +<% +pftype = "f32" if dtype == "float" else "f64" +dwidth_i = 4 if dtype == "float" else 8 +fzero = "0f00000000" if dtype == "float" else "0d0000000000000000" +kparts = partition(A, ksplit, by='cols') +cchunks = chunk(list(range(m)), csz) +cv_per_thread = -(-csz // ksplit) +bv_per_thread = max(len(kbx) for kbx in kparts) +csub_bytes = (ksplit - 1) * csz * blockx * dwidth_i +%> + +% if n is None: +.visible .entry ${kname}(.param .u32 _n, + .param .u64 _b, + .param .u32 _ldb, + .param .u64 _c, + .param .u32 _ldc) +{ + .reg .u32 ldb, ldc; + ld.param.u32 ldb, [_ldb]; + ld.param.u32 ldc, [_ldc]; +% else: +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +{ +% endif + .reg .u32 n, id, tid_x, tid_y; + .reg .u64 b, c, b_base, c_base, csub_thread; + .reg .${pftype} bv<${bv_per_thread}>, cv<${cv_per_thread}>, dotp; + .reg .pred p1, p_skip; + .shared .align 8 .b8 _csub[${csub_bytes}]; + +% if n is None: + ld.param.u32 n, [_n]; +% else: + mov.u32 n, ${n}; +% endif + ld.param.u64 b, [_b]; + ld.param.u64 c, [_c]; + + { + .reg .u32 _ctaid_x; + mov.u32 _ctaid_x, %ctaid.x; + mov.u32 tid_x, %tid.x; + mov.u32 tid_y, %tid.y; + mad.lo.u32 id, _ctaid_x, ${blockx}, tid_x; + } + + setp.ge.u32 p1, id, n; + @p1 bra $L_EXIT; + + cvta.to.global.u64 b, b; + cvta.to.global.u64 c, c; + + { + .reg .u64 _id64; + cvt.u64.u32 _id64, id; + mad.lo.u64 b_base, _id64, ${dwidth_i}, b; + mad.lo.u64 c_base, _id64, ${dwidth_i}, c; + } + + { + .reg .u64 _tx_off; + mul.wide.u32 _tx_off, tid_x, ${dwidth_i}; + mov.u64 csub_thread, _csub; + add.u64 csub_thread, csub_thread, _tx_off; + } + +% for bid, kbx in enumerate(kparts): +## bid = ${bid}: ${len(kbx)} B columns, ksplit=${ksplit} + setp.ne.u32 p_skip, tid_y, ${bid}; + @p_skip bra $L_END_BID_${bid}; + +<% + loaded = set() + kbx_idx = {kx: i for i, kx in enumerate(kbx)} +%> + +% for cchunk_i, cchunk in enumerate(cchunks): +## Chunk ${cchunk_i}: partial dot-product +% for row_idx, j in enumerate(cchunk): +<% + nz = [(kbx_idx[kx], kx, A[j, kx]) for kx in kbx if A[j, kx] != 0] + owner_bid = row_idx % ksplit +%> +% for (kxi, kx, jx) in nz: +% if kx not in loaded: +% if n is None: + { + .reg .u32 _boff; + .reg .u64 _bptr; + mul.lo.u32 _boff, ldb, ${kx}; + mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + ld.global.nc.${pftype} bv${kxi}, [_bptr]; + } +% else: + ld.global.nc.${pftype} bv${kxi}, [b_base + ${ldb*kx*dwidth_i}]; +% endif +<% loaded.add(kx) %> +% endif +% endfor +% if nz: +% for i, (kxi, kx, jx) in enumerate(nz): +% if i == 0: + mul.${pftype} dotp, bv${kxi}, ${jx}; +% else: + fma.rn.${pftype} dotp, bv${kxi}, ${jx}, dotp; +% endif +% endfor +% else: + mov.${pftype} dotp, ${fzero}; +% endif +% if owner_bid == bid: + mov.${pftype} cv${row_idx // ksplit}, dotp; +% else: +<% csub_idx = bid - (1 if bid > owner_bid else 0) %> + st.shared.${pftype} [csub_thread + ${(csub_idx * csz + row_idx) * blockx * dwidth_i}], dotp; +% endif +% endfor + bar.sync 0; + +## Combine phase (owned rows only) +% for row_idx, j in enumerate(cchunk): +% if row_idx % ksplit == bid: + mov.${pftype} dotp, cv${row_idx // ksplit}; +% for other_bid in range(ksplit): +% if other_bid != bid: +<% csub_idx = other_bid - (1 if other_bid > (row_idx % ksplit) else 0) %> + { + .reg .${pftype} _tmp; + ld.shared.${pftype} _tmp, [csub_thread + ${(csub_idx * csz + row_idx) * blockx * dwidth_i}]; + add.${pftype} dotp, dotp, _tmp; + } +% endif +% endfor +% if beta == 0: +% if n is None: + { + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], dotp; + } +% else: + st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], dotp; +% endif +% else: + { + .reg .${pftype} _ctmp; +% if n is None: + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.global.${pftype} _ctmp, [_cptr]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, dotp; + st.global.${pftype} [_cptr], _ctmp; +% else: + ld.global.${pftype} _ctmp, [c_base + ${ldc*j*dwidth_i}]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, dotp; + st.global.${pftype} [c_base + ${ldc*j*dwidth_i}], _ctmp; +% endif + } +% endif + +% endif +% endfor + bar.sync 0; +% endfor + +$L_END_BID_${bid}: +% endfor + +$L_EXIT: + ret; +} diff --git a/gimmik/kernels/ptx/cstream-w2.mako b/gimmik/kernels/ptx/cstream-w2.mako new file mode 100644 index 0000000..150cf57 --- /dev/null +++ b/gimmik/kernels/ptx/cstream-w2.mako @@ -0,0 +1,98 @@ +<%inherit file='base'/> + +<% +pftype = "f64" +dwidth_i = 8 +fzero = "0d0000000000000000" +bix_list = list(bix) +bix_pos = {kx: i for i, kx in enumerate(bix_list)} +K_used = len(bix_list) +row_nz = [[(kx, A[j, kx]) for kx in range(k) if A[j, kx] != 0] for j in range(m)] +assert dtype == 'double', 'cstream-w2 is double-precision only' +assert n is not None, 'cstream-w2 requires compile-time n' +%> + +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +{ + .reg .u32 n, id; + .reg .u64 b, c, b_base, c_base; + .reg .f64 bv_a<${K_used}>, bv_b<${K_used}>, dotp_a, dotp_b; + .reg .pred p1; + + mov.u32 n, ${-(-n // 2)}; + ld.param.u64 b, [_b]; + ld.param.u64 c, [_c]; + + { + .reg .u32 _ctaid_x, _tid_x; + mov.u32 _ctaid_x, %ctaid.x; + mov.u32 _tid_x, %tid.x; + mad.lo.u32 id, _ctaid_x, ${blockx}, _tid_x; + } + + setp.ge.u32 p1, id, n; + @p1 bra $L_EXIT; + + cvta.to.global.u64 b, b; + cvta.to.global.u64 c, c; + + { + .reg .u64 _id64; + cvt.u64.u32 _id64, id; + mad.lo.u64 b_base, _id64, 16, b; + mad.lo.u64 c_base, _id64, 16, c; + } + +## Batch-load B column pairs +%for i, kx in enumerate(bix_list): + ld.global.nc.v2.f64 {bv_a${i}, bv_b${i}}, [b_base + ${ldb*kx*dwidth_i}]; +%endfor + +## Main compute: two parallel dot-product streams per thread +%for j in range(m): +% if row_nz[j]: +% for i_nz, (kx, jx) in enumerate(row_nz[j]): +% if i_nz == 0: + mul.f64 dotp_a, bv_a${bix_pos[kx]}, ${jx}; + mul.f64 dotp_b, bv_b${bix_pos[kx]}, ${jx}; +% else: + fma.rn.f64 dotp_a, bv_a${bix_pos[kx]}, ${jx}, dotp_a; + fma.rn.f64 dotp_b, bv_b${bix_pos[kx]}, ${jx}, dotp_b; +% endif +% endfor +% if beta == 0: + st.weak.global.cg.v2.f64 [c_base + ${ldc*j*dwidth_i}], {dotp_a, dotp_b}; +% else: + { + .reg .f64 _ca, _cb; + ld.global.v2.f64 {_ca, _cb}, [c_base + ${ldc*j*dwidth_i}]; + fma.rn.f64 _ca, _ca, ${float(beta)}, dotp_a; + fma.rn.f64 _cb, _cb, ${float(beta)}, dotp_b; + st.global.v2.f64 [c_base + ${ldc*j*dwidth_i}], {_ca, _cb}; + } +% endif + +% else: +## Zero row of A +% if beta == 0: + { + .reg .f64 _z; + mov.f64 _z, ${fzero}; + st.weak.global.cg.v2.f64 [c_base + ${ldc*j*dwidth_i}], {_z, _z}; + } +% elif beta != 1: + { + .reg .f64 _ca, _cb; + ld.global.v2.f64 {_ca, _cb}, [c_base + ${ldc*j*dwidth_i}]; + mul.f64 _ca, _ca, ${float(beta)}; + mul.f64 _cb, _cb, ${float(beta)}; + st.global.v2.f64 [c_base + ${ldc*j*dwidth_i}], {_ca, _cb}; + } +% endif +% endif +%endfor + +$L_EXIT: + ret; +} diff --git a/gimmik/kernels/ptx/cstream.mako b/gimmik/kernels/ptx/cstream.mako new file mode 100644 index 0000000..f26abeb --- /dev/null +++ b/gimmik/kernels/ptx/cstream.mako @@ -0,0 +1,157 @@ +<%inherit file='base'/> + +<% +pftype = "f32" if dtype == "float" else "f64" +dwidth_i = 4 if dtype == "float" else 8 +fzero = "0f00000000" if dtype == "float" else "0d0000000000000000" +bix_list = list(bix) +bix_pos = {kx: i for i, kx in enumerate(bix_list)} +K_used = len(bix_list) +row_nz = [[(kx, A[j, kx]) for kx in range(k) if A[j, kx] != 0] for j in range(m)] +%> + +% if n is None: +.visible .entry ${kname}(.param .u32 _n, + .param .u64 _b, + .param .u32 _ldb, + .param .u64 _c, + .param .u32 _ldc) +{ + .reg .u32 ldb, ldc; + ld.param.u32 ldb, [_ldb]; + ld.param.u32 ldc, [_ldc]; +% else: +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +{ +% endif + .reg .u32 n, id; + .reg .u64 b, c, b_base, c_base; + .reg .${pftype} bv<${K_used}>, dotp; + .reg .pred p1; + +% if n is None: + ld.param.u32 n, [_n]; +% else: + mov.u32 n, ${n}; +% endif + ld.param.u64 b, [_b]; + ld.param.u64 c, [_c]; + + { + .reg .u32 _grd<3>; + mov.u32 _grd0, %ntid.x; + mov.u32 _grd1, %ctaid.x; + mov.u32 _grd2, %tid.x; + mad.lo.u32 id, _grd0, _grd1, _grd2; + } + + setp.ge.u32 p1, id, n; + @p1 bra $L_EXIT; + + cvta.to.global.u64 b, b; + cvta.to.global.u64 c, c; + + { + .reg .u64 _id64; + cvt.u64.u32 _id64, id; + mad.lo.u64 b_base, _id64, ${dwidth_i}, b; + mad.lo.u64 c_base, _id64, ${dwidth_i}, c; + } + +## Batch-load active B columns +%for i, kx in enumerate(bix_list): +% if n is None: + { + .reg .u32 _boff; + .reg .u64 _bptr; + mul.lo.u32 _boff, ldb, ${kx}; + mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + ld.global.nc.${pftype} bv${i}, [_bptr]; + } +% else: + ld.global.nc.${pftype} bv${i}, [b_base + ${ldb*kx*dwidth_i}]; +% endif +%endfor + +## Compute and store each output row +%for j in range(m): +% if row_nz[j]: +% for i_nz, (kx, jx) in enumerate(row_nz[j]): +% if i_nz == 0: + mul.${pftype} dotp, bv${bix_pos[kx]}, ${jx}; +% else: + fma.rn.${pftype} dotp, bv${bix_pos[kx]}, ${jx}, dotp; +% endif +% endfor +% if beta == 0: +% if n is None: + { + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], dotp; + } +% else: + st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], dotp; +% endif +% else: + { + .reg .${pftype} _ctmp; +% if n is None: + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.global.${pftype} _ctmp, [_cptr]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, dotp; + st.global.${pftype} [_cptr], _ctmp; +% else: + ld.global.${pftype} _ctmp, [c_base + ${ldc*j*dwidth_i}]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, dotp; + st.global.${pftype} [c_base + ${ldc*j*dwidth_i}], _ctmp; +% endif + } +% endif + +% else: +## Zero row of A +% if beta == 0: + { + .reg .${pftype} _tmp; + mov.${pftype} _tmp, ${fzero}; +% if n is None: + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], _tmp; +% else: + st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], _tmp; +% endif + } +% elif beta != 1: + { + .reg .${pftype} _tmp; +% if n is None: + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.global.${pftype} _tmp, [_cptr]; + mul.${pftype} _tmp, _tmp, ${float(beta)}; + st.global.${pftype} [_cptr], _tmp; +% else: + ld.global.${pftype} _tmp, [c_base + ${ldc*j*dwidth_i}]; + mul.${pftype} _tmp, _tmp, ${float(beta)}; + st.global.${pftype} [c_base + ${ldc*j*dwidth_i}], _tmp; +% endif + } +% endif +% endif +%endfor + +$L_EXIT: + ret; +} diff --git a/gimmik/kernels/ptx/dense-mma-gAd.mako b/gimmik/kernels/ptx/dense-mma-gAd.mako new file mode 100644 index 0000000..dcb8463 --- /dev/null +++ b/gimmik/kernels/ptx/dense-mma-gAd.mako @@ -0,0 +1,210 @@ +<%inherit file='base'/> + +<%! +import struct +import math +%> + +<% +assert dtype == "double" +assert n is not None and ldb is not None and ldc is not None + +M, K_ = A.shape +assert K_ == k +M_PAD = -(-M // 8) * 8 +M_TILES = M_PAD // 8 +K_REM = k % 4 +K_PAD = k if K_REM == 0 else k + (4 - K_REM) +K_ITERS = K_PAD // 4 + +# A in fragment-layout (32 contiguous elements per fragment) +a_u64 = [] +for m_tile in range(M_TILES): + for k_iter in range(K_ITERS): + for lane in range(32): + r_div4 = lane // 4 + r_mod4 = lane % 4 + i = m_tile * 8 + r_div4 + j = k_iter * 4 + r_mod4 + v = float(A[i, j]) if (i < M and j < k) else 0.0 + u = struct.unpack(' + +.global .align 16 .b64 ${kname}_Ag[${A_ELEMS}] = { + ${', '.join(a_u64)} +}; + +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +{ + .reg .u32 tid, warp, lane, r_mod4, r_div4; + .reg .u64 b_ptr, c_ptr; + .reg .u32 warp_n_base; + .reg .u64 ag_thr_base, b_thr_base, c_thr_base; + .reg .pred pwarp_exit; + .reg .f64 a_frag; +% for nt in range(NN): + .reg .u32 b_col_${nt}, c_col0_${nt}, c_col1_${nt}; + .reg .pred pvalid_bcol_${nt}, pvalid_c0col_${nt}, pvalid_c1col_${nt}; + .reg .f64 b_frag_${nt}; + .reg .f64 c0_${nt}_<${M_TILES}>, c1_${nt}_<${M_TILES}>; +% endfor + + ld.param.u64 b_ptr, [_b]; + ld.param.u64 c_ptr, [_c]; + cvta.to.global.u64 b_ptr, b_ptr; + cvta.to.global.u64 c_ptr, c_ptr; + + mov.u32 tid, %tid.x; + shr.u32 warp, tid, 5; + and.b32 lane, tid, 31; + shr.u32 r_div4, lane, 2; + and.b32 r_mod4, lane, 3; + + { + .reg .u32 cta; + mov.u32 cta, %ctaid.x; + mul.lo.u32 cta, cta, ${N_PER_CTA}; + mul.lo.u32 warp_n_base, warp, ${N_PER_WARP}; + add.u32 warp_n_base, warp_n_base, cta; + } + setp.ge.u32 pwarp_exit, warp_n_base, ${n}; + @pwarp_exit bra $L_EXIT; + +% for nt in range(NN): + add.u32 b_col_${nt}, warp_n_base, ${nt * 8}; + add.u32 b_col_${nt}, b_col_${nt}, r_div4; + { + .reg .u32 t; + shl.b32 t, r_mod4, 1; + add.u32 c_col0_${nt}, warp_n_base, ${nt * 8}; + add.u32 c_col0_${nt}, c_col0_${nt}, t; + add.u32 c_col1_${nt}, c_col0_${nt}, 1; + } + setp.lt.u32 pvalid_bcol_${nt}, b_col_${nt}, ${n}; + setp.lt.u32 pvalid_c0col_${nt}, c_col0_${nt}, ${n}; + setp.lt.u32 pvalid_c1col_${nt}, c_col1_${nt}, ${n}; +% endfor + + // A global thread base: &Ag[0] (generic -> global) + lane*8 + { + .reg .u64 t64, a_glb_base, lane64; + mov.u64 a_glb_base, ${kname}_Ag; + cvta.to.global.u64 a_glb_base, a_glb_base; + cvt.u64.u32 lane64, lane; + shl.b64 t64, lane64, 3; + add.u64 ag_thr_base, a_glb_base, t64; + } + + { + .reg .u64 t64, bcol64; + mul.wide.u32 t64, r_mod4, ${ldb}; + cvt.u64.u32 bcol64, b_col_0; + add.u64 t64, t64, bcol64; + shl.b64 t64, t64, 3; + add.u64 b_thr_base, b_ptr, t64; + } + + { + .reg .u64 t64, ccol64; + mul.wide.u32 t64, r_div4, ${ldc}; + cvt.u64.u32 ccol64, c_col0_0; + add.u64 t64, t64, ccol64; + shl.b64 t64, t64, 3; + add.u64 c_thr_base, c_ptr, t64; + } + +% for mt in range(M_TILES): + .reg .pred pm_${mt}; + { + .reg .u32 crow; + add.u32 crow, r_div4, ${mt * 8}; + setp.lt.u32 pm_${mt}, crow, ${M}; + } +% endfor + +% for nt in range(NN): +% for mt in range(M_TILES): +% if beta == 0: + mov.f64 c0_${nt}_${mt}, 0d0000000000000000; + mov.f64 c1_${nt}_${mt}, 0d0000000000000000; +% else: + { + .reg .u64 caddr; + .reg .pred p0, p1; + add.u64 caddr, c_thr_base, ${mt * C_MTILE_STRIDE + nt * C_NTILE_STRIDE}; + and.pred p0, pm_${mt}, pvalid_c0col_${nt}; + and.pred p1, pm_${mt}, pvalid_c1col_${nt}; + mov.f64 c0_${nt}_${mt}, 0d0000000000000000; + mov.f64 c1_${nt}_${mt}, 0d0000000000000000; + @p0 ld.global.f64 c0_${nt}_${mt}, [caddr]; + @p1 ld.global.f64 c1_${nt}_${mt}, [caddr + 8]; + } +% endif +% endfor +% endfor + +% for ki in range(K_ITERS): +% for nt in range(NN): + { + .reg .u64 baddr; + .reg .pred pb_load; + add.u64 baddr, b_thr_base, ${ki * B_KITER_STRIDE + nt * B_NTILE_STRIDE}; +% if K_REM != 0 and ki == K_ITERS - 1: + { + .reg .u32 brow; + .reg .pred pbrow; + add.u32 brow, r_mod4, ${ki * 4}; + setp.lt.u32 pbrow, brow, ${k}; + and.pred pb_load, pbrow, pvalid_bcol_${nt}; + } +% else: + and.pred pb_load, pvalid_bcol_${nt}, pvalid_bcol_${nt}; +% endif + mov.f64 b_frag_${nt}, 0d0000000000000000; + @pb_load ld.global.nc.f64 b_frag_${nt}, [baddr]; + } +% endfor +% for mt in range(M_TILES): + ld.global.nc.f64 a_frag, [ag_thr_base + ${(mt * K_ITERS + ki) * FRAG_STRIDE_BYTES}]; +% for nt in range(NN): + mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 + {c0_${nt}_${mt}, c1_${nt}_${mt}}, + {a_frag}, + {b_frag_${nt}}, + {c0_${nt}_${mt}, c1_${nt}_${mt}}; +% endfor +% endfor +% endfor + +% for nt in range(NN): +% for mt in range(M_TILES): + { + .reg .u64 caddr; + .reg .pred p0, p1; + add.u64 caddr, c_thr_base, ${mt * C_MTILE_STRIDE + nt * C_NTILE_STRIDE}; + and.pred p0, pm_${mt}, pvalid_c0col_${nt}; + and.pred p1, pm_${mt}, pvalid_c1col_${nt}; + @p0 st.global.f64 [caddr], c0_${nt}_${mt}; + @p1 st.global.f64 [caddr + 8], c1_${nt}_${mt}; + } +% endfor +% endfor + +$L_EXIT: + ret; +} diff --git a/gimmik/kernels/ptx/dense-mma-smem-gA.mako b/gimmik/kernels/ptx/dense-mma-smem-gA.mako new file mode 100644 index 0000000..d395b2e --- /dev/null +++ b/gimmik/kernels/ptx/dense-mma-smem-gA.mako @@ -0,0 +1,264 @@ +<%inherit file='base'/> + +<%! +import struct +import math +%> + +<% +assert dtype == "double" +assert n is not None and ldb is not None and ldc is not None + +M, K_ = A.shape +assert K_ == k +M_PAD = -(-M // 8) * 8 +M_TILES = M_PAD // 8 +K_REM = k % 4 +K_PAD = k if K_REM == 0 else k + (4 - K_REM) +K_ITERS = K_PAD // 4 + +# A in fragment-layout (same as dense-mma-smem-nn) +a_u64 = [] +for m_tile in range(M_TILES): + for k_iter in range(K_ITERS): + for lane in range(32): + r_div4 = lane // 4 + r_mod4 = lane % 4 + i = m_tile * 8 + r_div4 + j = k_iter * 4 + r_mod4 + v = float(A[i, j]) if (i < M and j < k) else 0.0 + u = struct.unpack(' 2*BLOCKX elements per copy iter +A_PAIRS = A_ELEMS // 2 # number of f64x2 pairs +A_PAIRS_TAIL = A_ELEMS % 2 # 0 if even, 1 if odd +COPY_V2_ITERS = math.ceil(A_PAIRS / BLOCKX) + +FRAG_STRIDE_BYTES = 32 * 8 +B_KITER_STRIDE = 4 * ldb * 8 +B_NTILE_STRIDE = 8 * 8 +C_MTILE_STRIDE = 8 * ldc * 8 +C_NTILE_STRIDE = 8 * 8 +%> + +.global .align 16 .b64 ${kname}_Ag[${A_ELEMS}] = { + ${', '.join(a_u64)} +}; +.shared .align 16 .b64 ${kname}_As[${A_ELEMS}]; + +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +{ + .reg .u32 tid, warp, lane, r_mod4, r_div4; + .reg .u64 b_ptr, c_ptr; + .reg .u32 warp_n_base; + .reg .u64 as_thr_base, b_thr_base, c_thr_base; + .reg .pred pwarp_exit; + .reg .f64 a_frag; +% for nt in range(NN): + .reg .u32 b_col_${nt}, c_col0_${nt}, c_col1_${nt}; + .reg .pred pvalid_bcol_${nt}, pvalid_c0col_${nt}, pvalid_c1col_${nt}; + .reg .f64 b_frag_${nt}; + .reg .f64 c0_${nt}_<${M_TILES}>, c1_${nt}_<${M_TILES}>; +% endfor + + ld.param.u64 b_ptr, [_b]; + ld.param.u64 c_ptr, [_c]; + cvta.to.global.u64 b_ptr, b_ptr; + cvta.to.global.u64 c_ptr, c_ptr; + + mov.u32 tid, %tid.x; + shr.u32 warp, tid, 5; + and.b32 lane, tid, 31; + shr.u32 r_div4, lane, 2; + and.b32 r_mod4, lane, 3; + + // ---- Cooperative copy A from .global to .shared using v2 loads ---- + { + .reg .u64 a_glb_base, a_smem_base; + mov.u64 a_glb_base, ${kname}_Ag; + cvta.to.global.u64 a_glb_base, a_glb_base; + mov.u64 a_smem_base, ${kname}_As; +% for ci in range(COPY_V2_ITERS): +<% + base_pair = ci * BLOCKX + is_last = ci == COPY_V2_ITERS - 1 + pairs_this = min(BLOCKX, A_PAIRS - base_pair) +%> + { + .reg .u32 pidx; + .reg .u64 off64, gaddr, saddr; + .reg .f64 v0, v1; +% if is_last and pairs_this < BLOCKX: + .reg .pred plast; + add.u32 pidx, tid, ${base_pair}; + setp.lt.u32 plast, pidx, ${A_PAIRS}; + mul.wide.u32 off64, pidx, 16; + add.u64 gaddr, a_glb_base, off64; + add.u64 saddr, a_smem_base, off64; + @plast ld.global.nc.v2.f64 {v0, v1}, [gaddr]; + @plast st.shared.v2.f64 [saddr], {v0, v1}; +% else: + add.u32 pidx, tid, ${base_pair}; + mul.wide.u32 off64, pidx, 16; + add.u64 gaddr, a_glb_base, off64; + add.u64 saddr, a_smem_base, off64; + ld.global.nc.v2.f64 {v0, v1}, [gaddr]; + st.shared.v2.f64 [saddr], {v0, v1}; +% endif + } +% endfor +% if A_PAIRS_TAIL: + // Odd element at the very end (rare; A_ELEMS odd) + { + .reg .pred plast; + .reg .u64 gaddr, saddr; + .reg .f64 v; + setp.eq.u32 plast, tid, 0; + add.u64 gaddr, a_glb_base, ${(A_ELEMS-1) * 8}; + add.u64 saddr, a_smem_base, ${(A_ELEMS-1) * 8}; + @plast ld.global.nc.f64 v, [gaddr]; + @plast st.shared.f64 [saddr], v; + } +% endif + } + bar.sync 0; + + { + .reg .u32 cta; + mov.u32 cta, %ctaid.x; + mul.lo.u32 cta, cta, ${N_PER_CTA}; + mul.lo.u32 warp_n_base, warp, ${N_PER_WARP}; + add.u32 warp_n_base, warp_n_base, cta; + } + setp.ge.u32 pwarp_exit, warp_n_base, ${n}; + @pwarp_exit bra $L_EXIT; + +% for nt in range(NN): + add.u32 b_col_${nt}, warp_n_base, ${nt * 8}; + add.u32 b_col_${nt}, b_col_${nt}, r_div4; + { + .reg .u32 t; + shl.b32 t, r_mod4, 1; + add.u32 c_col0_${nt}, warp_n_base, ${nt * 8}; + add.u32 c_col0_${nt}, c_col0_${nt}, t; + add.u32 c_col1_${nt}, c_col0_${nt}, 1; + } + setp.lt.u32 pvalid_bcol_${nt}, b_col_${nt}, ${n}; + setp.lt.u32 pvalid_c0col_${nt}, c_col0_${nt}, ${n}; + setp.lt.u32 pvalid_c1col_${nt}, c_col1_${nt}, ${n}; +% endfor + + { + .reg .u64 t64, a_smem_base, lane64; + mov.u64 a_smem_base, ${kname}_As; + cvt.u64.u32 lane64, lane; + shl.b64 t64, lane64, 3; + add.u64 as_thr_base, a_smem_base, t64; + } + + { + .reg .u64 t64, bcol64; + mul.wide.u32 t64, r_mod4, ${ldb}; + cvt.u64.u32 bcol64, b_col_0; + add.u64 t64, t64, bcol64; + shl.b64 t64, t64, 3; + add.u64 b_thr_base, b_ptr, t64; + } + + { + .reg .u64 t64, ccol64; + mul.wide.u32 t64, r_div4, ${ldc}; + cvt.u64.u32 ccol64, c_col0_0; + add.u64 t64, t64, ccol64; + shl.b64 t64, t64, 3; + add.u64 c_thr_base, c_ptr, t64; + } + +% for mt in range(M_TILES): + .reg .pred pm_${mt}; + { + .reg .u32 crow; + add.u32 crow, r_div4, ${mt * 8}; + setp.lt.u32 pm_${mt}, crow, ${M}; + } +% endfor + +% for nt in range(NN): +% for mt in range(M_TILES): +% if beta == 0: + mov.f64 c0_${nt}_${mt}, 0d0000000000000000; + mov.f64 c1_${nt}_${mt}, 0d0000000000000000; +% else: + { + .reg .u64 caddr; + .reg .pred p0, p1; + add.u64 caddr, c_thr_base, ${mt * C_MTILE_STRIDE + nt * C_NTILE_STRIDE}; + and.pred p0, pm_${mt}, pvalid_c0col_${nt}; + and.pred p1, pm_${mt}, pvalid_c1col_${nt}; + mov.f64 c0_${nt}_${mt}, 0d0000000000000000; + mov.f64 c1_${nt}_${mt}, 0d0000000000000000; + @p0 ld.global.f64 c0_${nt}_${mt}, [caddr]; + @p1 ld.global.f64 c1_${nt}_${mt}, [caddr + 8]; + } +% endif +% endfor +% endfor + +% for ki in range(K_ITERS): +% for nt in range(NN): + { + .reg .u64 baddr; + .reg .pred pb_load; + add.u64 baddr, b_thr_base, ${ki * B_KITER_STRIDE + nt * B_NTILE_STRIDE}; +% if K_REM != 0 and ki == K_ITERS - 1: + { + .reg .u32 brow; + .reg .pred pbrow; + add.u32 brow, r_mod4, ${ki * 4}; + setp.lt.u32 pbrow, brow, ${k}; + and.pred pb_load, pbrow, pvalid_bcol_${nt}; + } +% else: + and.pred pb_load, pvalid_bcol_${nt}, pvalid_bcol_${nt}; +% endif + mov.f64 b_frag_${nt}, 0d0000000000000000; + @pb_load ld.global.nc.f64 b_frag_${nt}, [baddr]; + } +% endfor +% for mt in range(M_TILES): + ld.shared.f64 a_frag, [as_thr_base + ${(mt * K_ITERS + ki) * FRAG_STRIDE_BYTES}]; +% for nt in range(NN): + mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 + {c0_${nt}_${mt}, c1_${nt}_${mt}}, + {a_frag}, + {b_frag_${nt}}, + {c0_${nt}_${mt}, c1_${nt}_${mt}}; +% endfor +% endfor +% endfor + +% for nt in range(NN): +% for mt in range(M_TILES): + { + .reg .u64 caddr; + .reg .pred p0, p1; + add.u64 caddr, c_thr_base, ${mt * C_MTILE_STRIDE + nt * C_NTILE_STRIDE}; + and.pred p0, pm_${mt}, pvalid_c0col_${nt}; + and.pred p1, pm_${mt}, pvalid_c1col_${nt}; + @p0 st.global.f64 [caddr], c0_${nt}_${mt}; + @p1 st.global.f64 [caddr + 8], c1_${nt}_${mt}; + } +% endfor +% endfor + +$L_EXIT: + ret; +} diff --git a/gimmik/ptx.py b/gimmik/ptx.py index bccdb61..dd3b259 100644 --- a/gimmik/ptx.py +++ b/gimmik/ptx.py @@ -1,7 +1,10 @@ # -*- coding: utf-8 -*- +import numpy as np + from gimmik.base import MatMul + class PTXSource: def __init__(self): self._src = "" @@ -48,16 +51,97 @@ def _address(self, out, base, size, *offs): src += f"mov.{out_type} {out[0]}, {base[0]};" return f"{{{src}\n\t}}" - def _kernel_generators(self, dtype, dsize, *, compute_capability=None): base_args = {'address': lambda o, b, s, *off: self._address(o, b, s, *off), 'cc': compute_capability} - # B streaming, C accumulation kernel - args = base_args | {} - yield ('bstream', args, {}) + # Matrix-property gates + arr = self.A + nnz = int(np.count_nonzero(arr)) + nuq = int(len(np.unique(np.abs(arr)))) + density = nnz / arr.size + sparse_suitable = (nuq <= 28) or (density <= 0.15) + + cc = compute_capability or (0, 0) + dense_suitable = ( + dtype == 'double' + and cc >= (9, 0) + and self.n is not None + and self.m <= 128 + and self.k <= 128 + ) + + if sparse_suitable: + yield ('cstream', base_args | {}, {}) + + yield ('bstream', base_args | {}, {}) + + ms, bsz, blkx = 4, 24, 32 + args = base_args | {'msplit': ms, 'bsz': bsz, 'blockx': blkx} + meta = {'block': (blkx, ms, 1), 'shared': 2*bsz*blkx*dsize} + yield ('bstream-msplit', args, meta) + + ms, bsz, blkx = 1, 16, 128 + args = base_args | {'msplit': ms, 'bsz': bsz, 'blockx': blkx} + meta = {'block': (blkx, ms, 1), 'shared': 2*bsz*blkx*dsize} + yield ('bstream-msplit', args, meta) + + ks, csz, blkx = 2, 24, 32 + args = base_args | {'ksplit': ks, 'csz': csz, 'blockx': blkx} + meta = {'block': (blkx, ks, 1), 'shared': (ks - 1)*csz*blkx*dsize} + yield ('cstream-ksplit', args, meta) + + K_used = len(self.bix) + if K_used > 500: + ks, csz, blkx = 4, 20, 32 + args = base_args | {'ksplit': ks, 'csz': csz, 'blockx': blkx} + meta = {'block': (blkx, ks, 1), + 'shared': (ks - 1)*csz*blkx*dsize} + yield ('cstream-ksplit', args, meta) + + if (dtype == 'double' and self.n is not None and self.n % 2 == 0 + and K_used <= 100 + and (self.aligne is None or self.aligne % 2 == 0)): + blkx = 128 + args = base_args | {'blockx': blkx} + meta = {'block': (blkx, 1, 1), 'width': 2} + yield ('cstream-w2', args, meta) + + if dense_suitable: + # Dense DMMA m8n8k4 templates. Yields a small cover of the nn × w + # space that empirically spans the autotune winners seen on tet + # p=3,4 at N=500k. The PyFR wrapper's _benchmark picks the fastest. + for tpl in ('dense-mma-smem-gA', 'dense-mma-gAd'): + for nn in (1, 2, 4): + for w in (2, 4, 8): + blkx = 32 * w + n_per_cta = 8 * nn * w + if n_per_cta > self.n: + continue + args = base_args | {'warps_per_cta': w, 'nn': nn} + meta = { + 'block': (blkx, 1, 1), + 'grid': (-(-self.n // n_per_cta), 1, 1), + } + yield (tpl, args, meta) + + # Extra fine-grained nn for shapes where a specific nn usually + # wins (p3/tet/m132, p4/tet/m132). + for tpl in ('dense-mma-smem-gA', 'dense-mma-gAd'): + for nn in (6,): + for w in (1, 4): + blkx = 32 * w + n_per_cta = 8 * nn * w + if n_per_cta > self.n: + continue + args = base_args | {'warps_per_cta': w, 'nn': nn} + meta = { + 'block': (blkx, 1, 1), + 'grid': (-(-self.n // n_per_cta), 1, 1), + } + yield (tpl, args, meta) def _process_meta(self, meta): - if self.n is not None: + if self.n is not None and 'grid' not in meta: div = meta['block'][0]*meta['width'] meta['grid'] = (-(-self.n // div), 1, 1) From bbbb8ef94f7803d5449cbcc3c51a1409d8efa630 Mon Sep 17 00:00:00 2001 From: WillTrojak Date: Wed, 29 Apr 2026 06:49:21 -0700 Subject: [PATCH 03/11] Dense and sparse optimisation --- gimmik/kernels/ptx/dense-mma-gAd.mako | 134 +++++------ gimmik/kernels/ptx/dense-mma-smem-gA.mako | 257 ++++++++++++--------- gimmik/ptx.py | 259 ++++++++++++---------- 3 files changed, 348 insertions(+), 302 deletions(-) diff --git a/gimmik/kernels/ptx/dense-mma-gAd.mako b/gimmik/kernels/ptx/dense-mma-gAd.mako index dcb8463..7fc6572 100644 --- a/gimmik/kernels/ptx/dense-mma-gAd.mako +++ b/gimmik/kernels/ptx/dense-mma-gAd.mako @@ -1,50 +1,11 @@ <%inherit file='base'/> -<%! -import struct -import math -%> - <% assert dtype == "double" assert n is not None and ldb is not None and ldc is not None - -M, K_ = A.shape -assert K_ == k -M_PAD = -(-M // 8) * 8 -M_TILES = M_PAD // 8 -K_REM = k % 4 -K_PAD = k if K_REM == 0 else k + (4 - K_REM) -K_ITERS = K_PAD // 4 - -# A in fragment-layout (32 contiguous elements per fragment) -a_u64 = [] -for m_tile in range(M_TILES): - for k_iter in range(K_ITERS): - for lane in range(32): - r_div4 = lane // 4 - r_mod4 = lane % 4 - i = m_tile * 8 + r_div4 - j = k_iter * 4 + r_mod4 - v = float(A[i, j]) if (i < M and j < k) else 0.0 - u = struct.unpack(' -.global .align 16 .b64 ${kname}_Ag[${A_ELEMS}] = { +.global .align 16 .b64 ${kname}_Ag[${a_elems}] = { ${', '.join(a_u64)} }; @@ -57,11 +18,13 @@ C_NTILE_STRIDE = 8 * 8 .reg .u64 ag_thr_base, b_thr_base, c_thr_base; .reg .pred pwarp_exit; .reg .f64 a_frag; -% for nt in range(NN): +% for nt in range(nn): .reg .u32 b_col_${nt}, c_col0_${nt}, c_col1_${nt}; +% if not n_col_aligned: .reg .pred pvalid_bcol_${nt}, pvalid_c0col_${nt}, pvalid_c1col_${nt}; +% endif .reg .f64 b_frag_${nt}; - .reg .f64 c0_${nt}_<${M_TILES}>, c1_${nt}_<${M_TILES}>; + .reg .f64 c0_${nt}_<${m_tiles}>, c1_${nt}_<${m_tiles}>; % endfor ld.param.u64 b_ptr, [_b]; @@ -78,14 +41,14 @@ C_NTILE_STRIDE = 8 * 8 { .reg .u32 cta; mov.u32 cta, %ctaid.x; - mul.lo.u32 cta, cta, ${N_PER_CTA}; - mul.lo.u32 warp_n_base, warp, ${N_PER_WARP}; + mul.lo.u32 cta, cta, ${n_per_cta}; + mul.lo.u32 warp_n_base, warp, ${n_per_warp}; add.u32 warp_n_base, warp_n_base, cta; } setp.ge.u32 pwarp_exit, warp_n_base, ${n}; @pwarp_exit bra $L_EXIT; -% for nt in range(NN): +% for nt in range(nn): add.u32 b_col_${nt}, warp_n_base, ${nt * 8}; add.u32 b_col_${nt}, b_col_${nt}, r_div4; { @@ -95,12 +58,14 @@ C_NTILE_STRIDE = 8 * 8 add.u32 c_col0_${nt}, c_col0_${nt}, t; add.u32 c_col1_${nt}, c_col0_${nt}, 1; } +% if not n_col_aligned: setp.lt.u32 pvalid_bcol_${nt}, b_col_${nt}, ${n}; setp.lt.u32 pvalid_c0col_${nt}, c_col0_${nt}, ${n}; setp.lt.u32 pvalid_c1col_${nt}, c_col1_${nt}, ${n}; +% endif % endfor - // A global thread base: &Ag[0] (generic -> global) + lane*8 + // A thread base: &Ag[0] + lane*8 { .reg .u64 t64, a_glb_base, lane64; mov.u64 a_glb_base, ${kname}_Ag; @@ -128,60 +93,71 @@ C_NTILE_STRIDE = 8 * 8 add.u64 c_thr_base, c_ptr, t64; } -% for mt in range(M_TILES): +% for mt in range(m_tiles): +% if pm_runtime(mt): .reg .pred pm_${mt}; { .reg .u32 crow; add.u32 crow, r_div4, ${mt * 8}; - setp.lt.u32 pm_${mt}, crow, ${M}; + setp.lt.u32 pm_${mt}, crow, ${m}; } +% endif % endfor -% for nt in range(NN): -% for mt in range(M_TILES): +% for nt in range(nn): +% for mt in range(m_tiles): % if beta == 0: mov.f64 c0_${nt}_${mt}, 0d0000000000000000; mov.f64 c1_${nt}_${mt}, 0d0000000000000000; % else: +<% + pm = f'pm_{mt}' if pm_runtime(mt) else None + pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None + pvc1 = f'pvalid_c1col_{nt}' if not n_col_aligned else None + needs_zero_init = pm is not None or pvc0 is not None or pvc1 is not None +%> { .reg .u64 caddr; - .reg .pred p0, p1; - add.u64 caddr, c_thr_base, ${mt * C_MTILE_STRIDE + nt * C_NTILE_STRIDE}; - and.pred p0, pm_${mt}, pvalid_c0col_${nt}; - and.pred p1, pm_${mt}, pvalid_c1col_${nt}; + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; +% if needs_zero_init: mov.f64 c0_${nt}_${mt}, 0d0000000000000000; mov.f64 c1_${nt}_${mt}, 0d0000000000000000; - @p0 ld.global.f64 c0_${nt}_${mt}, [caddr]; - @p1 ld.global.f64 c1_${nt}_${mt}, [caddr + 8]; +% endif + ${pred_emit(f'ld.global.f64 c0_{nt}_{mt}, [caddr];', pm, pvc0, pred_reg=f'p0_{nt}_{mt}')} + ${pred_emit(f'ld.global.f64 c1_{nt}_{mt}, [caddr + 8];', pm, pvc1, pred_reg=f'p1_{nt}_{mt}')} } % endif % endfor % endfor -% for ki in range(K_ITERS): -% for nt in range(NN): +% for ki in range(k_iters): +% for nt in range(nn): +<% + pvb = f'pvalid_bcol_{nt}' if not n_col_aligned else None + k_tail = (k_rem != 0 and ki == k_iters - 1) + needs_zero = pvb is not None or k_tail + pbrow = 'pbrow' if k_tail else None +%> { .reg .u64 baddr; - .reg .pred pb_load; - add.u64 baddr, b_thr_base, ${ki * B_KITER_STRIDE + nt * B_NTILE_STRIDE}; -% if K_REM != 0 and ki == K_ITERS - 1: + add.u64 baddr, b_thr_base, ${ki * b_kiter_stride + nt * b_ntile_stride}; +% if needs_zero: + mov.f64 b_frag_${nt}, 0d0000000000000000; +% endif +% if k_tail: + .reg .pred pbrow; { .reg .u32 brow; - .reg .pred pbrow; add.u32 brow, r_mod4, ${ki * 4}; setp.lt.u32 pbrow, brow, ${k}; - and.pred pb_load, pbrow, pvalid_bcol_${nt}; } -% else: - and.pred pb_load, pvalid_bcol_${nt}, pvalid_bcol_${nt}; % endif - mov.f64 b_frag_${nt}, 0d0000000000000000; - @pb_load ld.global.nc.f64 b_frag_${nt}, [baddr]; + ${pred_emit(f'ld.global.nc.f64 b_frag_{nt}, [baddr];', pbrow, pvb, pred_reg=f'pb_{ki}_{nt}')} } % endfor -% for mt in range(M_TILES): - ld.global.nc.f64 a_frag, [ag_thr_base + ${(mt * K_ITERS + ki) * FRAG_STRIDE_BYTES}]; -% for nt in range(NN): +% for mt in range(m_tiles): + ld.global.nc.f64 a_frag, [ag_thr_base + ${(mt * k_iters + ki) * frag_stride_bytes}]; +% for nt in range(nn): mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 {c0_${nt}_${mt}, c1_${nt}_${mt}}, {a_frag}, @@ -191,16 +167,18 @@ C_NTILE_STRIDE = 8 * 8 % endfor % endfor -% for nt in range(NN): -% for mt in range(M_TILES): +% for nt in range(nn): +% for mt in range(m_tiles): +<% + pm = f'pm_{mt}' if pm_runtime(mt) else None + pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None + pvc1 = f'pvalid_c1col_{nt}' if not n_col_aligned else None +%> { .reg .u64 caddr; - .reg .pred p0, p1; - add.u64 caddr, c_thr_base, ${mt * C_MTILE_STRIDE + nt * C_NTILE_STRIDE}; - and.pred p0, pm_${mt}, pvalid_c0col_${nt}; - and.pred p1, pm_${mt}, pvalid_c1col_${nt}; - @p0 st.global.f64 [caddr], c0_${nt}_${mt}; - @p1 st.global.f64 [caddr + 8], c1_${nt}_${mt}; + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; + ${pred_emit(f'st.global.f64 [caddr], c0_{nt}_{mt};', pm, pvc0, pred_reg=f'p0s_{nt}_{mt}')} + ${pred_emit(f'st.global.f64 [caddr + 8], c1_{nt}_{mt};', pm, pvc1, pred_reg=f'p1s_{nt}_{mt}')} } % endfor % endfor diff --git a/gimmik/kernels/ptx/dense-mma-smem-gA.mako b/gimmik/kernels/ptx/dense-mma-smem-gA.mako index d395b2e..8451e06 100644 --- a/gimmik/kernels/ptx/dense-mma-smem-gA.mako +++ b/gimmik/kernels/ptx/dense-mma-smem-gA.mako @@ -1,57 +1,24 @@ <%inherit file='base'/> -<%! -import struct -import math -%> - <% assert dtype == "double" assert n is not None and ldb is not None and ldc is not None - -M, K_ = A.shape -assert K_ == k -M_PAD = -(-M // 8) * 8 -M_TILES = M_PAD // 8 -K_REM = k % 4 -K_PAD = k if K_REM == 0 else k + (4 - K_REM) -K_ITERS = K_PAD // 4 - -# A in fragment-layout (same as dense-mma-smem-nn) -a_u64 = [] -for m_tile in range(M_TILES): - for k_iter in range(K_ITERS): - for lane in range(32): - r_div4 = lane // 4 - r_mod4 = lane % 4 - i = m_tile * 8 + r_div4 - j = k_iter * 4 + r_mod4 - v = float(A[i, j]) if (i < M and j < k) else 0.0 - u = struct.unpack(' 2*BLOCKX elements per copy iter -A_PAIRS = A_ELEMS // 2 # number of f64x2 pairs -A_PAIRS_TAIL = A_ELEMS % 2 # 0 if even, 1 if odd -COPY_V2_ITERS = math.ceil(A_PAIRS / BLOCKX) - -FRAG_STRIDE_BYTES = 32 * 8 -B_KITER_STRIDE = 4 * ldb * 8 -B_NTILE_STRIDE = 8 * 8 -C_MTILE_STRIDE = 8 * ldc * 8 -C_NTILE_STRIDE = 8 * 8 +# Cooperative-copy params (gA-only) +blockx = 32 * warps_per_cta +a_pairs = a_elems // 2 +a_pairs_tail = a_elems % 2 +copy_v2_iters = (a_pairs + blockx - 1) // blockx +bs = bool(context.get('block_stealing', False)) %> -.global .align 16 .b64 ${kname}_Ag[${A_ELEMS}] = { +% if bs: +.shared .align 8 .b64 ${kname}_mbar; +.shared .align 16 .b8 ${kname}_workid[16]; +% endif +.global .align 16 .b64 ${kname}_Ag[${a_elems}] = { ${', '.join(a_u64)} }; -.shared .align 16 .b64 ${kname}_As[${A_ELEMS}]; +.shared .align 16 .b64 ${kname}_As[${a_elems}]; .visible .entry ${kname}(.param .u64 _b, .param .u64 _c) @@ -62,11 +29,18 @@ C_NTILE_STRIDE = 8 * 8 .reg .u64 as_thr_base, b_thr_base, c_thr_base; .reg .pred pwarp_exit; .reg .f64 a_frag; -% for nt in range(NN): +% if bs: + .reg .u32 ctaid; + .reg .u32 mbar_a, work_a; + .reg .pred p_root, p_done, p_have; +% endif +% for nt in range(nn): .reg .u32 b_col_${nt}, c_col0_${nt}, c_col1_${nt}; +% if not n_col_aligned: .reg .pred pvalid_bcol_${nt}, pvalid_c0col_${nt}, pvalid_c1col_${nt}; +% endif .reg .f64 b_frag_${nt}; - .reg .f64 c0_${nt}_<${M_TILES}>, c1_${nt}_<${M_TILES}>; + .reg .f64 c0_${nt}_<${m_tiles}>, c1_${nt}_<${m_tiles}>; % endfor ld.param.u64 b_ptr, [_b]; @@ -80,26 +54,34 @@ C_NTILE_STRIDE = 8 * 8 shr.u32 r_div4, lane, 2; and.b32 r_mod4, lane, 3; - // ---- Cooperative copy A from .global to .shared using v2 loads ---- +% if bs: + setp.eq.u32 p_root, tid, 0; + mov.u32 mbar_a, ${kname}_mbar; + mov.u32 work_a, ${kname}_workid; + @p_root mbarrier.init.shared::cta.b64 [mbar_a], 1; + bar.sync 0; +% endif + + // Cooperative copy A from .global to .shared via v2 loads { .reg .u64 a_glb_base, a_smem_base; mov.u64 a_glb_base, ${kname}_Ag; cvta.to.global.u64 a_glb_base, a_glb_base; mov.u64 a_smem_base, ${kname}_As; -% for ci in range(COPY_V2_ITERS): +% for ci in range(copy_v2_iters): <% - base_pair = ci * BLOCKX - is_last = ci == COPY_V2_ITERS - 1 - pairs_this = min(BLOCKX, A_PAIRS - base_pair) + base_pair = ci * blockx + is_last = ci == copy_v2_iters - 1 + pairs_this = min(blockx, a_pairs - base_pair) %> { .reg .u32 pidx; .reg .u64 off64, gaddr, saddr; .reg .f64 v0, v1; -% if is_last and pairs_this < BLOCKX: +% if is_last and pairs_this < blockx: .reg .pred plast; add.u32 pidx, tid, ${base_pair}; - setp.lt.u32 plast, pidx, ${A_PAIRS}; + setp.lt.u32 plast, pidx, ${a_pairs}; mul.wide.u32 off64, pidx, 16; add.u64 gaddr, a_glb_base, off64; add.u64 saddr, a_smem_base, off64; @@ -115,15 +97,15 @@ C_NTILE_STRIDE = 8 * 8 % endif } % endfor -% if A_PAIRS_TAIL: - // Odd element at the very end (rare; A_ELEMS odd) +% if a_pairs_tail: + // Tail element (only when a_elems is odd) { .reg .pred plast; .reg .u64 gaddr, saddr; .reg .f64 v; setp.eq.u32 plast, tid, 0; - add.u64 gaddr, a_glb_base, ${(A_ELEMS-1) * 8}; - add.u64 saddr, a_smem_base, ${(A_ELEMS-1) * 8}; + add.u64 gaddr, a_glb_base, ${(a_elems-1) * 8}; + add.u64 saddr, a_smem_base, ${(a_elems-1) * 8}; @plast ld.global.nc.f64 v, [gaddr]; @plast st.shared.f64 [saddr], v; } @@ -131,17 +113,50 @@ C_NTILE_STRIDE = 8 * 8 } bar.sync 0; + // Lane-only base; lifted out of the optional steal loop + { + .reg .u64 t64, a_smem_base, lane64; + mov.u64 a_smem_base, ${kname}_As; + cvt.u64.u32 lane64, lane; + shl.b64 t64, lane64, 3; + add.u64 as_thr_base, a_smem_base, t64; + } + +% for mt in range(m_tiles): +% if pm_runtime(mt): + .reg .pred pm_${mt}; + { + .reg .u32 crow; + add.u32 crow, r_div4, ${mt * 8}; + setp.lt.u32 pm_${mt}, crow, ${m}; + } +% endif +% endfor + +% if bs: + mov.u32 ctaid, %ctaid.x; +$L_LOOP: +% endif + { .reg .u32 cta; +% if bs: + mov.u32 cta, ctaid; +% else: mov.u32 cta, %ctaid.x; - mul.lo.u32 cta, cta, ${N_PER_CTA}; - mul.lo.u32 warp_n_base, warp, ${N_PER_WARP}; +% endif + mul.lo.u32 cta, cta, ${n_per_cta}; + mul.lo.u32 warp_n_base, warp, ${n_per_warp}; add.u32 warp_n_base, warp_n_base, cta; } setp.ge.u32 pwarp_exit, warp_n_base, ${n}; +% if bs: + @pwarp_exit bra $L_STEAL; +% else: @pwarp_exit bra $L_EXIT; +% endif -% for nt in range(NN): +% for nt in range(nn): add.u32 b_col_${nt}, warp_n_base, ${nt * 8}; add.u32 b_col_${nt}, b_col_${nt}, r_div4; { @@ -151,19 +166,13 @@ C_NTILE_STRIDE = 8 * 8 add.u32 c_col0_${nt}, c_col0_${nt}, t; add.u32 c_col1_${nt}, c_col0_${nt}, 1; } +% if not n_col_aligned: setp.lt.u32 pvalid_bcol_${nt}, b_col_${nt}, ${n}; setp.lt.u32 pvalid_c0col_${nt}, c_col0_${nt}, ${n}; setp.lt.u32 pvalid_c1col_${nt}, c_col1_${nt}, ${n}; +% endif % endfor - { - .reg .u64 t64, a_smem_base, lane64; - mov.u64 a_smem_base, ${kname}_As; - cvt.u64.u32 lane64, lane; - shl.b64 t64, lane64, 3; - add.u64 as_thr_base, a_smem_base, t64; - } - { .reg .u64 t64, bcol64; mul.wide.u32 t64, r_mod4, ${ldb}; @@ -182,60 +191,60 @@ C_NTILE_STRIDE = 8 * 8 add.u64 c_thr_base, c_ptr, t64; } -% for mt in range(M_TILES): - .reg .pred pm_${mt}; - { - .reg .u32 crow; - add.u32 crow, r_div4, ${mt * 8}; - setp.lt.u32 pm_${mt}, crow, ${M}; - } -% endfor - -% for nt in range(NN): -% for mt in range(M_TILES): +% for nt in range(nn): +% for mt in range(m_tiles): % if beta == 0: mov.f64 c0_${nt}_${mt}, 0d0000000000000000; mov.f64 c1_${nt}_${mt}, 0d0000000000000000; % else: +<% + pm = f'pm_{mt}' if pm_runtime(mt) else None + pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None + pvc1 = f'pvalid_c1col_{nt}' if not n_col_aligned else None + needs_zero_init = pm is not None or pvc0 is not None or pvc1 is not None +%> { .reg .u64 caddr; - .reg .pred p0, p1; - add.u64 caddr, c_thr_base, ${mt * C_MTILE_STRIDE + nt * C_NTILE_STRIDE}; - and.pred p0, pm_${mt}, pvalid_c0col_${nt}; - and.pred p1, pm_${mt}, pvalid_c1col_${nt}; + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; +% if needs_zero_init: mov.f64 c0_${nt}_${mt}, 0d0000000000000000; mov.f64 c1_${nt}_${mt}, 0d0000000000000000; - @p0 ld.global.f64 c0_${nt}_${mt}, [caddr]; - @p1 ld.global.f64 c1_${nt}_${mt}, [caddr + 8]; +% endif + ${pred_emit(f'ld.global.f64 c0_{nt}_{mt}, [caddr];', pm, pvc0, pred_reg=f'p0_{nt}_{mt}')} + ${pred_emit(f'ld.global.f64 c1_{nt}_{mt}, [caddr + 8];', pm, pvc1, pred_reg=f'p1_{nt}_{mt}')} } % endif % endfor % endfor -% for ki in range(K_ITERS): -% for nt in range(NN): +% for ki in range(k_iters): +% for nt in range(nn): +<% + pvb = f'pvalid_bcol_{nt}' if not n_col_aligned else None + k_tail = (k_rem != 0 and ki == k_iters - 1) + needs_zero = pvb is not None or k_tail + pbrow = 'pbrow' if k_tail else None +%> { .reg .u64 baddr; - .reg .pred pb_load; - add.u64 baddr, b_thr_base, ${ki * B_KITER_STRIDE + nt * B_NTILE_STRIDE}; -% if K_REM != 0 and ki == K_ITERS - 1: + add.u64 baddr, b_thr_base, ${ki * b_kiter_stride + nt * b_ntile_stride}; +% if needs_zero: + mov.f64 b_frag_${nt}, 0d0000000000000000; +% endif +% if k_tail: + .reg .pred pbrow; { .reg .u32 brow; - .reg .pred pbrow; add.u32 brow, r_mod4, ${ki * 4}; setp.lt.u32 pbrow, brow, ${k}; - and.pred pb_load, pbrow, pvalid_bcol_${nt}; } -% else: - and.pred pb_load, pvalid_bcol_${nt}, pvalid_bcol_${nt}; % endif - mov.f64 b_frag_${nt}, 0d0000000000000000; - @pb_load ld.global.nc.f64 b_frag_${nt}, [baddr]; + ${pred_emit(f'ld.global.nc.f64 b_frag_{nt}, [baddr];', pbrow, pvb, pred_reg=f'pb_{ki}_{nt}')} } % endfor -% for mt in range(M_TILES): - ld.shared.f64 a_frag, [as_thr_base + ${(mt * K_ITERS + ki) * FRAG_STRIDE_BYTES}]; -% for nt in range(NN): +% for mt in range(m_tiles): + ld.shared.f64 a_frag, [as_thr_base + ${(mt * k_iters + ki) * frag_stride_bytes}]; +% for nt in range(nn): mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 {c0_${nt}_${mt}, c1_${nt}_${mt}}, {a_frag}, @@ -245,20 +254,52 @@ C_NTILE_STRIDE = 8 * 8 % endfor % endfor -% for nt in range(NN): -% for mt in range(M_TILES): +% for nt in range(nn): +% for mt in range(m_tiles): +<% + pm = f'pm_{mt}' if pm_runtime(mt) else None + pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None + pvc1 = f'pvalid_c1col_{nt}' if not n_col_aligned else None +%> { .reg .u64 caddr; - .reg .pred p0, p1; - add.u64 caddr, c_thr_base, ${mt * C_MTILE_STRIDE + nt * C_NTILE_STRIDE}; - and.pred p0, pm_${mt}, pvalid_c0col_${nt}; - and.pred p1, pm_${mt}, pvalid_c1col_${nt}; - @p0 st.global.f64 [caddr], c0_${nt}_${mt}; - @p1 st.global.f64 [caddr + 8], c1_${nt}_${mt}; + add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; + ${pred_emit(f'st.global.f64 [caddr], c0_{nt}_{mt};', pm, pvc0, pred_reg=f'p0s_{nt}_{mt}')} + ${pred_emit(f'st.global.f64 [caddr + 8], c1_{nt}_{mt};', pm, pvc1, pred_reg=f'p1s_{nt}_{mt}')} } % endfor % endfor +% if bs: +$L_STEAL: + // Root issues async try_cancel + waits; bar.sync orders the workid load + @!p_root bra $L_AFTER_WAIT; + { + .reg .u64 state; + mbarrier.arrive.expect_tx.shared::cta.b64 state, [mbar_a], 16; + clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes.b128 [work_a], [mbar_a]; +$L_WAIT: + mbarrier.try_wait.shared::cta.b64 p_done, [mbar_a], state, 10000000; + @!p_done bra $L_WAIT; + } +$L_AFTER_WAIT: + bar.sync 0; + + { + .reg .b128 resp; + ld.shared::cta.b128 resp, [work_a]; + clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 p_have, resp; + @!p_have bra $L_FIN; + // 1D grid: extract just x + clusterlaunchcontrol.query_cancel.get_first_ctaid::x.b32.b128 ctaid, resp; + } + bra.uni $L_LOOP; + +$L_FIN: + bar.sync 0; + @p_root mbarrier.inval.shared::cta.b64 [mbar_a]; +% endif + $L_EXIT: ret; } diff --git a/gimmik/ptx.py b/gimmik/ptx.py index dd3b259..d2c6894 100644 --- a/gimmik/ptx.py +++ b/gimmik/ptx.py @@ -1,145 +1,172 @@ # -*- coding: utf-8 -*- +import struct + import numpy as np from gimmik.base import MatMul -class PTXSource: - def __init__(self): - self._src = "" - - def __iadd__(self, other): - self._src = f"{self}\n\t{other}" - return self - - def __str__(self): - return self._src - - def __repr__(self): - return self._src - - class PTXMatMul(MatMul): platform = 'ptx' basemeta = {'block': (128, 1, 1), 'width': 1, 'shared': 0, 'dynamic_shared': 0} - def _address(self, out, base, size, *offs): - src = PTXSource() - out_type = out[1] - if out_type != base[1]: - raise RuntimeError("out and base must have the same type") - - if offs: - off_type = offs[0][1] - if not all(off[1] == off_type for off in offs): - raise RuntimeError("offsets must all have the same tpye") - - if len(offs) == 1: - off = offs[0] - mad_type = "lo" if out_type == off_type else "wide" - src += f"mad.{mad_type}.{off_type} {out[0]}, {size}, {off[0]}, {base[0]};" - else: - src += f".reg .{off_type} _addrs_acum;" - src += f"add.{off_type} _addrs_acum, {offs[0][0]}, {offs[1][0]};" - for off in offs[2:]: - src += f"add.{off_type} _addrs_acum, _addrs_acum, {off[0]};" - mad_type = "lo" if out_type == off_type else "wide" - src += f"mad.{mad_type}.{off_type} {out[0]}, {size}, _addrs_acum, {base[0]};" - else: - src += f"mov.{out_type} {out[0]}, {base[0]};" - return f"{{{src}\n\t}}" - - def _kernel_generators(self, dtype, dsize, *, compute_capability=None): - base_args = {'address': lambda o, b, s, *off: self._address(o, b, s, - *off), 'cc': compute_capability} - - # Matrix-property gates + def _kernel_generators(self, dtype, dsize, *, compute_capability=None, + trim_a=False): + base_args = {'cc': compute_capability, + 'pred_emit': self._pred_emit, + 'trim_a': bool(trim_a) and dtype == 'double'} + + yield from self._sparse_kernel_generators(dtype, dsize, base_args) + yield from self._dense_kernel_generators(dtype, dsize, base_args) + + def _sparse_kernel_generators(self, dtype, dsize, base_args): arr = self.A nnz = int(np.count_nonzero(arr)) nuq = int(len(np.unique(np.abs(arr)))) density = nnz / arr.size - sparse_suitable = (nuq <= 28) or (density <= 0.15) - - cc = compute_capability or (0, 0) - dense_suitable = ( - dtype == 'double' - and cc >= (9, 0) - and self.n is not None - and self.m <= 128 - and self.k <= 128 - ) + if not ((nuq <= 28) or (density <= 0.15)): + return - if sparse_suitable: - yield ('cstream', base_args | {}, {}) + # B loading, C streaming kernel + yield ('cstream', base_args | {}, {'desc': 'cstream'}) - yield ('bstream', base_args | {}, {}) + # B streaming, C accumulation kernel + yield ('bstream', base_args | {}, {'desc': 'bstream'}) - ms, bsz, blkx = 4, 24, 32 - args = base_args | {'msplit': ms, 'bsz': bsz, 'blockx': blkx} - meta = {'block': (blkx, ms, 1), 'shared': 2*bsz*blkx*dsize} - yield ('bstream-msplit', args, meta) + # Four-way m-split B streaming, C accumulation kernel + ms, bsz, blkx = 4, 24, 32 + args = base_args | {'msplit': ms, 'bsz': bsz, 'blockx': blkx} + meta = {'block': (blkx, ms, 1), 'shared': 2*bsz*blkx*dsize, + 'desc': f'bstream-msplit/m{ms}-b{bsz}-x{blkx}'} + yield ('bstream-msplit', args, meta) - ms, bsz, blkx = 1, 16, 128 + # Single-warp LDGSTS variant for medium-M beta=0 large-K cases + if self.beta == 0 and self.m <= 320 and len(self.bix) >= 64: + ms, bsz, blkx = 1, 32, 64 args = base_args | {'msplit': ms, 'bsz': bsz, 'blockx': blkx} - meta = {'block': (blkx, ms, 1), 'shared': 2*bsz*blkx*dsize} + meta = {'block': (blkx, ms, 1), + 'shared': 2*bsz*blkx*dsize, + 'desc': f'bstream-msplit/m{ms}-b{bsz}-x{blkx}'} yield ('bstream-msplit', args, meta) - ks, csz, blkx = 2, 24, 32 + # Two-way k-split B loading, C streaming kernel + ks, csz, blkx = 2, 24, 32 + args = base_args | {'ksplit': ks, 'csz': csz, 'blockx': blkx} + meta = {'block': (blkx, ks, 1), 'shared': (ks - 1)*csz*blkx*dsize, + 'desc': f'cstream-ksplit/k{ks}-c{csz}-x{blkx}'} + yield ('cstream-ksplit', args, meta) + + # Four-way k-split for large K + K_used = len(self.bix) + if K_used > 500: + ks, csz, blkx = 4, 20, 32 args = base_args | {'ksplit': ks, 'csz': csz, 'blockx': blkx} - meta = {'block': (blkx, ks, 1), 'shared': (ks - 1)*csz*blkx*dsize} + meta = {'block': (blkx, ks, 1), + 'shared': (ks - 1)*csz*blkx*dsize, + 'desc': f'cstream-ksplit/k{ks}-c{csz}-x{blkx}'} yield ('cstream-ksplit', args, meta) - K_used = len(self.bix) - if K_used > 500: - ks, csz, blkx = 4, 20, 32 - args = base_args | {'ksplit': ks, 'csz': csz, 'blockx': blkx} - meta = {'block': (blkx, ks, 1), - 'shared': (ks - 1)*csz*blkx*dsize} - yield ('cstream-ksplit', args, meta) - - if (dtype == 'double' and self.n is not None and self.n % 2 == 0 - and K_used <= 100 - and (self.aligne is None or self.aligne % 2 == 0)): - blkx = 128 - args = base_args | {'blockx': blkx} - meta = {'block': (blkx, 1, 1), 'width': 2} - yield ('cstream-w2', args, meta) - - if dense_suitable: - # Dense DMMA m8n8k4 templates. Yields a small cover of the nn × w - # space that empirically spans the autotune winners seen on tet - # p=3,4 at N=500k. The PyFR wrapper's _benchmark picks the fastest. - for tpl in ('dense-mma-smem-gA', 'dense-mma-gAd'): - for nn in (1, 2, 4): - for w in (2, 4, 8): - blkx = 32 * w - n_per_cta = 8 * nn * w - if n_per_cta > self.n: - continue - args = base_args | {'warps_per_cta': w, 'nn': nn} - meta = { - 'block': (blkx, 1, 1), - 'grid': (-(-self.n // n_per_cta), 1, 1), - } - yield (tpl, args, meta) - - # Extra fine-grained nn for shapes where a specific nn usually - # wins (p3/tet/m132, p4/tet/m132). - for tpl in ('dense-mma-smem-gA', 'dense-mma-gAd'): - for nn in (6,): - for w in (1, 4): - blkx = 32 * w - n_per_cta = 8 * nn * w - if n_per_cta > self.n: - continue - args = base_args | {'warps_per_cta': w, 'nn': nn} - meta = { - 'block': (blkx, 1, 1), - 'grid': (-(-self.n // n_per_cta), 1, 1), - } - yield (tpl, args, meta) + # Width-2 vector cstream for fp64 small-K + if (dtype == 'double' and self.n is not None and self.n % 2 == 0 + and K_used <= 100 + and (self.aligne is None or self.aligne % 2 == 0)): + blkx = 128 + args = base_args | {'blockx': blkx} + meta = {'block': (blkx, 1, 1), 'width': 2, + 'desc': f'cstream-w2/x{blkx}'} + yield ('cstream-w2', args, meta) + + def _dense_kernel_generators(self, dtype, dsize, base_args): + cc = base_args['cc'] or (0, 0) + if not (dtype == 'double' and cc >= (9, 0) and self.n is not None + and self.m <= 128 and self.k <= 128): + return + + # Dense DMMA m8n8k4; block stealing default on sm_100+ for gA + bs_default = cc >= (10, 0) + dense_configs = [ + ('dense-mma-smem-gA', 1, 8), + ('dense-mma-smem-gA', 2, 4), + ('dense-mma-smem-gA', 4, 4), + ('dense-mma-gAd', 2, 2), + ('dense-mma-gAd', 4, 2), + ] + for tpl, nn, w in dense_configs: + blkx = 32 * w + n_per_cta = 8 * nn * w + if n_per_cta > self.n: + continue + bs = (tpl == 'dense-mma-smem-gA') and bs_default + setup = self._dense_mma_setup(nn=nn, warps_per_cta=w) + args = (base_args | {'warps_per_cta': w, 'nn': nn, + 'block_stealing': bs} | setup) + meta = { + 'block': (blkx, 1, 1), + 'grid': (-(-self.n // n_per_cta), 1, 1), + 'desc': f'{tpl}/nn{nn}-w{w}{"-bs" if bs else ""}', + } + yield (tpl, args, meta) + + def _dense_mma_setup(self, *, nn, warps_per_cta): + a = self.A + m, k = a.shape + m_tiles = -(-m // 8) + k_rem = k % 4 + k_iters = (k + (4 - k_rem if k_rem else 0)) // 4 + + # A in fragment layout: lane l -> A[m_tile*8 + l/4][k_iter*4 + l%4] + a_u64 = [] + for m_tile in range(m_tiles): + for k_iter in range(k_iters): + for lane in range(32): + i = m_tile * 8 + lane // 4 + j = k_iter * 4 + lane % 4 + v = float(a[i, j]) if (i < m and j < k) else 0.0 + u = struct.unpack(' m + + return { + 'm_tiles': m_tiles, + 'k_rem': k_rem, 'k_iters': k_iters, + 'a_u64': a_u64, + 'n_per_warp': n_per_warp, 'n_per_cta': n_per_cta, + 'a_elems': a_elems, + 'frag_stride_bytes': 32 * 8, + 'b_kiter_stride': 4 * (self.ldb or 0) * 8, + 'b_ntile_stride': 8 * 8, + 'c_mtile_stride': 8 * (self.ldc or 0) * 8, + 'c_ntile_stride': 8 * 8, + 'n_col_aligned': n_col_aligned, + 'pm_runtime': pm_runtime, + } + + @staticmethod + def _pred_emit(instr, *preds, pred_reg=None, indent=' ' * 8): + actual = [p for p in preds if p is not None] + if not actual: + return instr + if len(actual) == 1: + return f'@{actual[0]} {instr}' + if pred_reg is None: + raise ValueError('pred_reg required when combining multiple ' + 'predicates') + lines = [f'.reg .pred {pred_reg};', + f'and.pred {pred_reg}, {actual[0]}, {actual[1]};'] + for p in actual[2:]: + lines.append(f'and.pred {pred_reg}, {pred_reg}, {p};') + lines.append(f'@{pred_reg} {instr}') + return f'\n{indent}'.join(lines) def _process_meta(self, meta): if self.n is not None and 'grid' not in meta: From 393b4095a3985d6de32551b5d6daa0de4cd312c4 Mon Sep 17 00:00:00 2001 From: WillTrojak Date: Mon, 11 May 2026 05:35:20 -0700 Subject: [PATCH 04/11] Added warp specialised dense kernel --- gimmik/kernels/ptx/dense-mma-ws.mako | 422 +++++++++++++++++++++++++++ gimmik/ptx.py | 120 +++++++- 2 files changed, 535 insertions(+), 7 deletions(-) create mode 100644 gimmik/kernels/ptx/dense-mma-ws.mako diff --git a/gimmik/kernels/ptx/dense-mma-ws.mako b/gimmik/kernels/ptx/dense-mma-ws.mako new file mode 100644 index 0000000..514837f --- /dev/null +++ b/gimmik/kernels/ptx/dense-mma-ws.mako @@ -0,0 +1,422 @@ +<%inherit file='base'/> +<% +assert dtype == "double" +assert n is not None and ldb is not None and ldc is not None +mbar_maxwait = '0x989680' +%> + +.global .align 16 .b64 ${kname}_Ag[${a_elems}] = { + ${', '.join(a_u64)} +}; +.extern .shared .align 128 .b8 ${kname}_dynm[]; +.const .align 64 .b8 ${kname}_bdesc[128]; +.const .align 64 .b8 ${kname}_cdesc[128]; + +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +.maxntid ${blockx_total}, 1, 1 +{ + .reg .b32 tid, warp, lane, phase, ctaid_x; + .reg .b32 base_brow, base_bcol, base_crow, base_ccol; + .reg .b32 work, block_idx_x, n_start_curr, n_start_next; + .reg .u64 bdesc_addr, cdesc_addr; + .reg .b32 a_smem, b1_smem, b2_smem, c_smem; + .reg .b32 tma_mbar, wid_new_mbar, bready_mbar, cready_mbar, cstored_mbar, steal_mbar; + .reg .b32 wid_used_mbar, wid_smem; + .reg .pred p_compute, p_prod, p_steal; + .reg .pred p_warp_lead; + .reg .pred p_done; + .reg .pred p_tid0; + + mov.u32 tid, %tid.x; + shr.u32 warp, tid, 5; + and.b32 lane, tid, 31; + mov.u32 ctaid_x, %ctaid.x; + + .reg .b32 dynm_base; + mov.u32 dynm_base, ${kname}_dynm; + add.u32 b1_smem, dynm_base, ${b1_off}; + add.u32 b2_smem, dynm_base, ${b2_off}; + add.u32 c_smem, dynm_base, ${c_off}; + add.u32 a_smem, dynm_base, ${a_off}; + add.u32 wid_smem, dynm_base, ${wid_off}; + + add.u32 tma_mbar, dynm_base, ${tma_mbar_off}; + add.u32 bready_mbar, dynm_base, ${bready_mbar_off}; + add.u32 cready_mbar, dynm_base, ${cready_mbar_off}; + add.u32 cstored_mbar, dynm_base, ${cstored_mbar_off}; + add.u32 steal_mbar, dynm_base, ${steal_mbar_off}; + add.u32 wid_new_mbar, dynm_base, ${wid_new_mbar_off}; + add.u32 wid_used_mbar, dynm_base, ${wid_used_mbar_off}; + + cvta.const.u64 bdesc_addr, ${kname}_bdesc; + cvta.const.u64 cdesc_addr, ${kname}_cdesc; + + setp.eq.u32 p_tid0, tid, 0; + + setp.lt.u32 p_compute, warp, ${n_comp_warps}; + setp.eq.u32 p_prod, warp, ${prod_warp}; + setp.eq.u32 p_steal, warp, ${steal_warp}; + + { + .reg .b32 _elect_lane; + elect.sync _elect_lane|p_warp_lead, 0xffffffff; + } + + // mbarrier init (tid 0 only); pre-arrive csmem_free so compute iter 0 + // can write csmem immediately. + { + .reg .pred p_init; + setp.eq.u32 p_init, tid, 0; + .reg .b64 _state; + @p_init mbarrier.init.shared::cta.b64 [tma_mbar], 32; + @p_init mbarrier.init.shared::cta.b64 [bready_mbar], 1; + @p_init mbarrier.init.shared::cta.b64 [cready_mbar], 1; + @p_init mbarrier.init.shared::cta.b64 [cstored_mbar], 1; + @p_init mbarrier.init.shared::cta.b64 [steal_mbar], 1; + @p_init mbarrier.init.shared::cta.b64 [wid_used_mbar], ${n_comp_warps + 1}; + @p_init mbarrier.init.shared::cta.b64 [wid_new_mbar], 1; + @p_init mbarrier.arrive.shared::cta.b64 _state, [cstored_mbar]; + @p_init fence.proxy.async.shared::cta; + } + bar.sync 0; + + // Cooperative copy A: .global -> a_smem (ld.global.nc.v2.f64) + { + .reg .u64 a_glb_base; + .reg .b32 pidx; + .reg .f64 av0, av1; + mov.u64 a_glb_base, ${kname}_Ag; + cvta.to.global.u64 a_glb_base, a_glb_base; +% for ci in range(copy_v2_iters): +<% + base_pair = ci * blockx_total + is_last = ci == copy_v2_iters - 1 + pairs_this = min(blockx_total, a_pairs - base_pair) + needs_guard = is_last and pairs_this < blockx_total +%> + { + .reg .u64 ofs64, gaddr; + .reg .b32 saddr; + add.u32 pidx, tid, ${base_pair}; +% if needs_guard: + .reg .pred p_load; + setp.lt.u32 p_load, pidx, ${a_pairs}; +% endif + mul.wide.u32 ofs64, pidx, 16; + add.u64 gaddr, a_glb_base, ofs64; + cvt.u32.u64 saddr, ofs64; + add.u32 saddr, saddr, a_smem; +% if needs_guard: + @p_load ld.global.nc.v2.f64 {av0, av1}, [gaddr]; + @p_load st.shared.v2.f64 [saddr], {av0, av1}; +% else: + ld.global.nc.v2.f64 {av0, av1}, [gaddr]; + st.shared.v2.f64 [saddr], {av0, av1}; +% endif + } +% endfor +% if a_pairs_tail: + { + .reg .pred p_tail; + .reg .u64 gaddr; + .reg .b32 saddr; + .reg .f64 v; + setp.eq.u32 p_tail, tid, 0; + add.u64 gaddr, a_glb_base, ${(a_elems - 1) * 8}; + mov.u32 saddr, ${(a_elems - 1) * 8}; + add.u32 saddr, saddr, a_smem; + @p_tail ld.global.nc.f64 v, [gaddr]; + @p_tail st.shared.f64 [saddr], v; + } +% endif + } + bar.sync 0; + + // Compute-warp lane geometry (cheap; all warps execute uniformly) + { + .reg .b32 t, w_n_base; + and.b32 base_brow, lane, 3; + shr.u32 base_crow, lane, 2; + mul.lo.u32 w_n_base, warp, ${n_per_warp}; + add.u32 base_bcol, base_crow, w_n_base; + shl.b32 t, base_brow, 1; + add.u32 base_ccol, t, w_n_base; + } + + // Producer warp: initial B load for ctaid_x's work + @!p_prod bra.uni $L_AFTER_INIT_B; + { + .reg .b32 n_start0; + mul.lo.u32 n_start0, ctaid_x, ${n_per_cta}; + @p_warp_lead cp.async.bulk.tensor.2d.shared::cta.global.tile.mbarrier::complete_tx::bytes + [b1_smem], [bdesc_addr, {n_start0, 0}], [tma_mbar]; + @p_warp_lead mbarrier.expect_tx.relaxed.cta.shared::cta.b64 + [tma_mbar], ${b_tile_bytes}; + bar.warp.sync 0xffffffff; + .reg .b64 state; + .reg .pred p1; + mbarrier.arrive.shared::cta.b64 state, [tma_mbar]; +$L_TMA_INIT_W: + mbarrier.try_wait.shared::cta.b64 p1, [tma_mbar], state, ${mbar_maxwait}; + @!p1 bra.uni $L_TMA_INIT_W; + .reg .b64 _state2; + @p_warp_lead mbarrier.arrive.shared::cta.b64 _state2, [bready_mbar]; + } +$L_AFTER_INIT_B: + + mov.u32 block_idx_x, ctaid_x; + mov.u32 work, 1; + mov.u32 phase, 0; + +$L_LOOP: + setp.eq.u32 p_done, work, 0; + @p_done bra.uni $L_EXIT; + + mul.lo.u32 n_start_curr, block_idx_x, ${n_per_cta}; + + // --- Compute Warps + @!p_compute bra.uni $L_AFTER_COMPUTE; + + // Wait on B + { + .reg .pred p1; +$L_WAIT_BRDY: + mbarrier.try_wait.parity.shared::cta.b64 p1, [bready_mbar], phase, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_BRDY; + } + + // MMA + { + .reg .b32 b_sm_a; + .reg .pred p_ph; + setp.ne.u32 p_ph, phase, 0; + selp.b32 b_sm_a, b2_smem, b1_smem, p_ph; + + .reg .b32 a_thr_a; + { + .reg .b32 t; + shl.b32 t, lane, 3; + add.u32 a_thr_a, a_smem, t; + } +% for nt in range(nn): + .reg .b32 b_thr_a_${nt}; + { + .reg .b32 bcol_g, t_off; + add.u32 bcol_g, base_bcol, ${8 * nt}; + shl.b32 t_off, bcol_g, 3; + add.u32 b_thr_a_${nt}, b_sm_a, t_off; + } +% endfor + + .reg .b32 c_thr_smem; + { + .reg .b32 t1, ccol_b; + mul.lo.u32 t1, base_crow, ${n_per_cta * 8}; + shl.b32 ccol_b, base_ccol, 3; + add.u32 c_thr_smem, c_smem, t1; + add.u32 c_thr_smem, c_thr_smem, ccol_b; + } + + // Zero accumulators +% for mt in range(m_tiles): +% for nt in range(nn): + .reg .f64 d_x_${mt}_${nt}, d_y_${mt}_${nt}; + mov.f64 d_x_${mt}_${nt}, 0d0000000000000000; + mov.f64 d_y_${mt}_${nt}, 0d0000000000000000; +% endfor +% endfor + + .reg .f64 a_f; +% for mt in range(m_tiles): +% for kt in range(k_iters): +<% + k_tail = (k_rem != 0 and kt == k_iters - 1) +%> + { + .reg .b32 a_a; + add.u32 a_a, a_thr_a, ${(kt * 32 + mt * 32 * k_iters) * 8}; + ld.shared.f64 a_f, [a_a]; +% if k_tail: + .reg .pred pbrow_${mt}_${kt}; + { + .reg .b32 brow; + add.u32 brow, base_brow, ${4 * kt}; + setp.lt.u32 pbrow_${mt}_${kt}, brow, ${k}; + } +% endif +% for nt in range(nn): + { + .reg .b32 b_a, b_row; + .reg .f64 b_f; + add.u32 b_row, base_brow, ${4 * kt}; + mul.lo.u32 b_row, b_row, ${n_per_cta * 8}; + add.u32 b_a, b_thr_a_${nt}, b_row; +% if k_tail: + mov.f64 b_f, 0d0000000000000000; + @pbrow_${mt}_${kt} ld.shared.f64 b_f, [b_a]; +% else: + ld.shared.f64 b_f, [b_a]; +% endif + mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 + {d_x_${mt}_${nt}, d_y_${mt}_${nt}}, {a_f}, {b_f}, + {d_x_${mt}_${nt}, d_y_${mt}_${nt}}; + } +% endfor + } +% endfor +% endfor + + // Wait until producer's prev-iter TMA-store of C has drained. + { + .reg .pred p1; +$L_WAIT_CSTORE: + mbarrier.try_wait.parity.shared::cta.b64 p1, [cstored_mbar], phase, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_CSTORE; + } + + // Vector-store {d_x, d_y} pairs to csmem. M-tail / N-tail OOB rows + // are dropped by the C tensor map. +% for mt in range(m_tiles): +% for nt in range(nn): + { + .reg .b32 csaddr; + add.u32 csaddr, c_thr_smem, ${mt * c_mtile_smem_stride + nt * c_ntile_smem_stride}; + st.shared.v2.f64 [csaddr], {d_x_${mt}_${nt}, d_y_${mt}_${nt}}; + } +% endfor +% endfor + + bar.sync 1, ${comp_threads}; + fence.proxy.async.shared::cta; + { + .reg .b64 _state; + @p_tid0 mbarrier.arrive.shared::cta.b64 _state, [cready_mbar]; + } + + // Wait for new work and unpack + { + .reg .pred p1, p_canc; + .reg .b128 resp; +$L_WAIT_WNEW_C: + mbarrier.try_wait.parity.shared::cta.b64 p1, [wid_new_mbar], phase, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_WNEW_C; + + ld.shared::cta.b128 resp, [wid_smem]; + clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 p_canc, resp; + @p_canc clusterlaunchcontrol.query_cancel.get_first_ctaid::x.b32.b128 block_idx_x, resp; + selp.b32 work, 1, 0, p_canc; + + .reg .b64 _state; + @p_warp_lead mbarrier.arrive.shared::cta.b64 _state, [wid_used_mbar]; + } + } +$L_AFTER_COMPUTE: + + // --- Data Movement Warp + @!p_prod bra.uni $L_AFTER_DATA; + { + .reg .b32 n_c_store; + mul.lo.u32 n_c_store, block_idx_x, ${n_per_cta}; + + // Wait for new work and unpack + { + .reg .pred p1, p_canc; + .reg .b128 resp; +$L_WAIT_WNEW_D: + mbarrier.try_wait.parity.shared::cta.b64 p1, [wid_new_mbar], phase, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_WNEW_D; + + ld.shared::cta.b128 resp, [wid_smem]; + clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 p_canc, resp; + @p_canc clusterlaunchcontrol.query_cancel.get_first_ctaid::x.b32.b128 block_idx_x, resp; + selp.b32 work, 1, 0, p_canc; + .reg .b64 _state; + @p_warp_lead mbarrier.arrive.shared::cta.b64 _state, [wid_used_mbar]; + } + + // TMA loads of next B + { + mul.lo.u32 n_start_next, block_idx_x, ${n_per_cta}; + .reg .b32 b_next; + .reg .pred p_ph; + setp.ne.u32 p_ph, phase, 0; + selp.b32 b_next, b1_smem, b2_smem, p_ph; + @p_warp_lead cp.async.bulk.tensor.2d.shared::cta.global.tile.mbarrier::complete_tx::bytes + [b_next], [bdesc_addr, {n_start_next, 0}], [tma_mbar]; + @p_warp_lead mbarrier.expect_tx.relaxed.cta.shared::cta.b64 + [tma_mbar], ${b_tile_bytes}; + @p_warp_lead cp.async.bulk.commit_group; + } + bar.warp.sync 0xffffffff; + + // TMA store/reduce+store of a C + { + .reg .pred p1; + .reg .b64 _c_state; +$L_WAIT_CRDY: + mbarrier.try_wait.parity.shared::cta.b64 p1, [cready_mbar], phase, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_CRDY; +% if beta == 0: + @p_warp_lead cp.async.bulk.tensor.2d.global.shared::cta.tile.bulk_group + [cdesc_addr, {n_c_store, 0}], [c_smem]; +% else: + @p_warp_lead cp.reduce.async.bulk.tensor.2d.global.shared::cta.add.tile.bulk_group + [cdesc_addr, {n_c_store, 0}], [c_smem]; +% endif + @p_warp_lead cp.async.bulk.commit_group; + @p_warp_lead cp.async.bulk.wait_group 0; + @p_warp_lead mbarrier.arrive.shared::cta.b64 _c_state, [cstored_mbar]; + } + + // Wait for next B to be ready, then signal B and C ready + { + .reg .b64 b_state, _bready_state, _c_state; + .reg .pred p1; + mbarrier.arrive.shared::cta.b64 b_state, [tma_mbar]; +$L_WAIT_TMA: + mbarrier.try_wait.shared::cta.b64 p1, [tma_mbar], b_state, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_TMA; + + @p_warp_lead mbarrier.arrive.shared::cta.b64 _bready_state, [bready_mbar]; + } + } +$L_AFTER_DATA: + + // --- Controller Warp + @!p_steal bra.uni $L_AFTER_CTRL; + { + .reg .pred p1, p2, p_canc; + .reg .b64 _state; + .reg .b128 resp; + @p_warp_lead fence.proxy.async.shared::cta; + @p_warp_lead clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes.b128 + [wid_smem], [steal_mbar]; + @p_warp_lead mbarrier.arrive.expect_tx.shared::cta.b64 + _state, [steal_mbar], 16; + +$L_WAIT_STEAL: + mbarrier.try_wait.parity.shared::cta.b64 p1, [steal_mbar], phase, ${mbar_maxwait}; + @!p1 bra.uni $L_WAIT_STEAL; + + // Signal new work + @p_warp_lead mbarrier.arrive.shared::cta.b64 _state, [wid_new_mbar]; + + // Query if there's new work + ld.shared::cta.b128 resp, [wid_smem]; + clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 p_canc, resp; + selp.b32 work, 1, 0, p_canc; + + // Wait for old work to be used +$L_WAIT_WUSED: + mbarrier.try_wait.parity.shared::cta.b64 p2, [wid_used_mbar], phase, ${mbar_maxwait}; + @!p2 bra.uni $L_WAIT_WUSED; + } +$L_AFTER_CTRL: + + xor.b32 phase, phase, 1; + bra.uni $L_LOOP; + +$L_EXIT: + ret; +} diff --git a/gimmik/ptx.py b/gimmik/ptx.py index d2c6894..bf32a62 100644 --- a/gimmik/ptx.py +++ b/gimmik/ptx.py @@ -12,6 +12,28 @@ class PTXMatMul(MatMul): basemeta = {'block': (128, 1, 1), 'width': 1, 'shared': 0, 'dynamic_shared': 0} + @staticmethod + def is_sparse_suitable(arr): + nnz = int(np.count_nonzero(arr)) + nuq = int(len(np.unique(np.abs(arr)))) + density = nnz / arr.size + return (nuq <= 28) or (density <= 0.15) + + @staticmethod + def is_dense_suitable(arr, dtype, cc): + """True if A's shape and the target arch support the dense DMMA + template family. Does NOT check runtime args (n, ldb, ldc); those + are validated when the generator runs.""" + return (np.dtype(dtype) == np.float64 + and cc is not None and cc >= (9, 0) + and arr.shape[0] <= 128 and arr.shape[1] <= 128) + + @classmethod + def is_suitable(cls, arr, dtype, cc): + """True if either sparse or dense templates are applicable.""" + return (cls.is_sparse_suitable(arr) + or cls.is_dense_suitable(arr, dtype, cc)) + def _kernel_generators(self, dtype, dsize, *, compute_capability=None, trim_a=False): base_args = {'cc': compute_capability, @@ -22,11 +44,7 @@ def _kernel_generators(self, dtype, dsize, *, compute_capability=None, yield from self._dense_kernel_generators(dtype, dsize, base_args) def _sparse_kernel_generators(self, dtype, dsize, base_args): - arr = self.A - nnz = int(np.count_nonzero(arr)) - nuq = int(len(np.unique(np.abs(arr)))) - density = nnz / arr.size - if not ((nuq <= 28) or (density <= 0.15)): + if not self.is_sparse_suitable(self.A): return # B loading, C streaming kernel @@ -80,8 +98,8 @@ def _sparse_kernel_generators(self, dtype, dsize, base_args): def _dense_kernel_generators(self, dtype, dsize, base_args): cc = base_args['cc'] or (0, 0) - if not (dtype == 'double' and cc >= (9, 0) and self.n is not None - and self.m <= 128 and self.k <= 128): + if not (self.is_dense_suitable(self.A, dtype, cc) + and self.n is not None): return # Dense DMMA m8n8k4; block stealing default on sm_100+ for gA @@ -109,6 +127,94 @@ def _dense_kernel_generators(self, dtype, dsize, base_args): } yield (tpl, args, meta) + # Warp-specialised dense DMMA with TMA B-load + TMA C-store. + if cc >= (10, 0): + yield from self._dense_ws_kernel_generators(dtype, dsize, base_args) + + def _dense_ws_kernel_generators(self, dtype, dsize, base_args): + m_pad = -(-self.m // 8) * 8 + k_pad = -(-self.k // 4) * 4 + # (nn, w_compute) -- block has w_compute + 2 warps (producer, stealer) + ws_configs = [(1, 4), (2, 4), (4, 4)] + for nn, w in ws_configs: + n_per_cta = 8 * nn * w + if n_per_cta > self.n: + continue + blkx = 32 * (w + 2) + setup = self._dense_mma_setup(nn=nn, warps_per_cta=w) + ws_layout = self._dense_ws_layout( + n_comp_warps=w, n_per_cta=n_per_cta, + m_pad=m_pad, k_pad=k_pad, a_elems=setup['a_elems'] + ) + # sm_100 supports up to 228 KiB shared per CTA with the + # set_shared_size opt-in. Reserve some headroom for L1 carveout. + if ws_layout['dynm_total_bytes'] > 200 * 1024: + continue + args = (base_args + | {'warps_per_cta': w, 'nn': nn} + | setup | ws_layout) + yield ('dense-mma-ws', args, { + 'block': (blkx, 1, 1), + 'grid': (-(-self.n // n_per_cta), 1, 1), + 'desc': f'dense-mma-ws/nn{nn}-w{w}', + 'ws_tensor_map': True, + 'ws_n_per_cta': n_per_cta, + 'ws_k_pad': k_pad, + 'ws_m_pad': m_pad, + 'dynamic_shared': ws_layout['dynm_total_bytes'], + }) + + @staticmethod + def _dense_ws_layout(*, n_comp_warps, n_per_cta, m_pad, k_pad, a_elems): + """Render-time constants for the dense-mma-ws template: warp roles, + cooperative-copy iteration counts, smem-tile sizes, mbar timeout, + and dynamic-shared byte offsets for each buffer.""" + n_total_warps = n_comp_warps + 2 + blockx_total = 32 * n_total_warps + a_pairs = a_elems // 2 + a_pairs_tail = a_elems % 2 + + b_tile_bytes = k_pad * n_per_cta * 8 + c_tile_bytes = m_pad * n_per_cta * 8 + a_bytes = a_elems * 8 + + smem_size = {'b1': b_tile_bytes, 'b2': b_tile_bytes, 'c': c_tile_bytes, + 'a': a_bytes, 'wid': 16} + smem_off, off = {}, 0 + for k, v in smem_size.items(): + off = (off + 15) & ~15 + smem_off[f'{k}_off'] = off + off += v + + mbar_names = ('tma', 'bready', 'cready', 'cstored', + 'steal', 'wid_new', 'wid_used') + for k in mbar_names: + smem_off[f'{k}_mbar_off'] = off + off += 8 + + # Pad total to 16-byte multiple + dynm_total_bytes = (off + 15) & ~15 + + params = {'n_comp_warps': n_comp_warps, + 'blockx_total': blockx_total, + 'prod_warp': n_comp_warps, + 'steal_warp': n_comp_warps + 1, + 'comp_threads': 32 * n_comp_warps, + 'a_pairs': a_pairs, + 'a_pairs_tail': a_pairs_tail, + 'copy_v2_iters': -(-a_pairs // blockx_total), + 'm_pad': m_pad, + 'k_pad': k_pad, + 'b_tile_doubles': k_pad * n_per_cta, + 'b_tile_bytes': b_tile_bytes, + 'c_tile_doubles': m_pad * n_per_cta, + 'c_mtile_smem_stride': 8 * n_per_cta * 8, + 'c_ntile_smem_stride': 8 * 8, + 'dynm_total_bytes': dynm_total_bytes, + } + params |= smem_off + return params + def _dense_mma_setup(self, *, nn, warps_per_cta): a = self.A m, k = a.shape From 67d1bebd516e29b7d3b70460919056a6e534ab0a Mon Sep 17 00:00:00 2001 From: WillTrojak Date: Wed, 13 May 2026 09:27:38 -0700 Subject: [PATCH 05/11] Performance tuning and cleanup --- gimmik/kernels/ptx/bstream-msplit.mako | 12 +- gimmik/kernels/ptx/bstream.mako | 38 ++- gimmik/kernels/ptx/cstream-ksplit.mako | 6 +- gimmik/kernels/ptx/cstream-w2.mako | 36 ++- gimmik/kernels/ptx/cstream.mako | 36 +-- gimmik/kernels/ptx/dense-mma-gAd.mako | 13 +- gimmik/kernels/ptx/dense-mma-smem-gA.mako | 4 +- gimmik/kernels/ptx/dense-mma-ws.mako | 334 ++++++++++++---------- gimmik/ptx.py | 58 ++-- 9 files changed, 271 insertions(+), 266 deletions(-) diff --git a/gimmik/kernels/ptx/bstream-msplit.mako b/gimmik/kernels/ptx/bstream-msplit.mako index 77e7ce7..0af5091 100644 --- a/gimmik/kernels/ptx/bstream-msplit.mako +++ b/gimmik/kernels/ptx/bstream-msplit.mako @@ -1,14 +1,13 @@ <%inherit file='base'/> <% -pftype = "f32" if dtype == "float" else "f64" -dwidth_i = 4 if dtype == "float" else 8 -fzero = "0f00000000" if dtype == "float" else "0d0000000000000000" +pftype = 'f32' if dtype == 'float' else 'f64' +dwidth_i = 4 if dtype == 'float' else 8 +fzero = '0f00000000' if dtype == 'float' else '0d0000000000000000' has_zero_rows = any(jx == -1 for jx in afix) mx = partition(A, into=msplit, by='rows') bix_list = list(bix) bchunks = chunk(bix_list, bsz) -nchunks = len(bchunks) m_per_group = max(len(mcx) for mcx in mx) bsub_bytes = 2 * bsz * blockx * dwidth_i def bsub_off(buf, idx): @@ -135,11 +134,11 @@ use_cpasync = cc is not None and (cc[0], cc[1]) >= (8, 0) and dwidth_i in (4, 8) % endif ## Main loop over B-chunks (double-buffered) -% for bb in range(nchunks): +% for bb in range(len(bchunks)): <% buf_cur = bb % 2 buf_next = (bb + 1) % 2 - is_last = (bb == nchunks - 1) + is_last = (bb == len(bchunks) - 1) %> % if not is_last: % for idx, kx in enumerate(bchunks[bb + 1]): @@ -232,6 +231,7 @@ use_cpasync = cc is not None and (cc[0], cc[1]) >= (8, 0) and dwidth_i in (4, 8) % endif bar.sync 0; % endfor +## End of Main loop over B-chunks ## Handle zero rows in this cid's group % if has_zero_rows: diff --git a/gimmik/kernels/ptx/bstream.mako b/gimmik/kernels/ptx/bstream.mako index 465eac3..f58e9b3 100644 --- a/gimmik/kernels/ptx/bstream.mako +++ b/gimmik/kernels/ptx/bstream.mako @@ -1,14 +1,12 @@ <%inherit file='base'/> <% -pftype = "f32" if dtype == "float" else "f64" -dwidth_i = 4 if dtype == "float" else 8 -fzero = "0f00000000" if dtype == "float" else "0d0000000000000000" +pftype = 'f32' if dtype == 'float' else 'f64' +dwidth_i = 4 if dtype == 'float' else 8 +fzero = '0f00000000' if dtype == 'float' else '0d0000000000000000' has_zero_rows = any(jx == -1 for jx in afix) bix_list = list(bix) -bix_idx = {kx: i for i, kx in enumerate(bix_list)} -preload_c = beta != 0 -need_scale = beta != 0 and beta != 1 +bix_pos = {kx: i for i, kx in enumerate(bix_list)} %> % if n is None: @@ -61,7 +59,7 @@ need_scale = beta != 0 and beta != 1 } ## Batch-load active B columns -%for i, kx in enumerate(bix_list): +% for i, kx in enumerate(bix_list): % if n is None: { .reg .u32 _boff; @@ -73,9 +71,9 @@ need_scale = beta != 0 and beta != 1 % else: ld.weak.global.cg.${pftype} bv${i}, [b_base + ${ldb*kx*dwidth_i}]; % endif -%endfor +% endfor -% if preload_c: +% if beta != 0: ## Pre-load C so per-row completion is a plain store % for j in range(m): % if afix[j] != -1: @@ -92,7 +90,7 @@ need_scale = beta != 0 and beta != 1 % endif % endif % endfor -% if need_scale: +% if beta != 0 and beta != 1: % for j in range(m): % if afix[j] != -1: mul.${pftype} csub${j}, csub${j}, ${float(beta)}; @@ -102,15 +100,15 @@ need_scale = beta != 0 and beta != 1 % endif ## Main compute -%for kx in bix_list: -% for j, jx in enumerate(A[:, kx]): -% if jx != 0: -% if preload_c: - fma.rn.${pftype} csub${j}, bv${bix_idx[kx]}, ${jx}, csub${j}; +% for kx in bix_list: +% for j, jx in enumerate(A[:, kx]): +% if jx != 0: +% if preload_c: + fma.rn.${pftype} csub${j}, bv${bix_pos[kx]}, ${jx}, csub${j}; % elif kx == afix[j]: - mul.${pftype} csub${j}, bv${bix_idx[kx]}, ${jx}; + mul.${pftype} csub${j}, bv${bix_pos[kx]}, ${jx}; % else: - fma.rn.${pftype} csub${j}, bv${bix_idx[kx]}, ${jx}, csub${j}; + fma.rn.${pftype} csub${j}, bv${bix_pos[kx]}, ${jx}, csub${j}; % endif % endif % if kx == alix[j]: @@ -126,9 +124,9 @@ need_scale = beta != 0 and beta != 1 st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], csub${j}; % endif -% endif -% endfor -%endfor +% endif +% endfor +% endfor % if has_zero_rows: { diff --git a/gimmik/kernels/ptx/cstream-ksplit.mako b/gimmik/kernels/ptx/cstream-ksplit.mako index 06e8a77..1ba2491 100644 --- a/gimmik/kernels/ptx/cstream-ksplit.mako +++ b/gimmik/kernels/ptx/cstream-ksplit.mako @@ -1,9 +1,9 @@ <%inherit file='base'/> <% -pftype = "f32" if dtype == "float" else "f64" -dwidth_i = 4 if dtype == "float" else 8 -fzero = "0f00000000" if dtype == "float" else "0d0000000000000000" +pftype = 'f32' if dtype == 'float' else 'f64' +dwidth_i = 4 if dtype == 'float' else 8 +fzero = '0f00000000' if dtype == 'float' else '0d0000000000000000' kparts = partition(A, ksplit, by='cols') cchunks = chunk(list(range(m)), csz) cv_per_thread = -(-csz // ksplit) diff --git a/gimmik/kernels/ptx/cstream-w2.mako b/gimmik/kernels/ptx/cstream-w2.mako index 150cf57..c82ebab 100644 --- a/gimmik/kernels/ptx/cstream-w2.mako +++ b/gimmik/kernels/ptx/cstream-w2.mako @@ -1,15 +1,13 @@ <%inherit file='base'/> <% -pftype = "f64" +pftype = 'f64' dwidth_i = 8 -fzero = "0d0000000000000000" +fzero = '0d0000000000000000' bix_list = list(bix) bix_pos = {kx: i for i, kx in enumerate(bix_list)} -K_used = len(bix_list) -row_nz = [[(kx, A[j, kx]) for kx in range(k) if A[j, kx] != 0] for j in range(m)] -assert dtype == 'double', 'cstream-w2 is double-precision only' -assert n is not None, 'cstream-w2 requires compile-time n' +row_nz = [[(kx, A[j, kx]) for kx in range(k) if A[j, kx] != 0] + for j in range(m)] %> .visible .entry ${kname}(.param .u64 _b, @@ -17,7 +15,7 @@ assert n is not None, 'cstream-w2 requires compile-time n' { .reg .u32 n, id; .reg .u64 b, c, b_base, c_base; - .reg .f64 bv_a<${K_used}>, bv_b<${K_used}>, dotp_a, dotp_b; + .reg .f64 bv_a<${len(bix_list)}>, bv_b<${len(bix_list)}>, dotp_a, dotp_b; .reg .pred p1; mov.u32 n, ${-(-n // 2)}; @@ -45,22 +43,22 @@ assert n is not None, 'cstream-w2 requires compile-time n' } ## Batch-load B column pairs -%for i, kx in enumerate(bix_list): +% for i, kx in enumerate(bix_list): ld.global.nc.v2.f64 {bv_a${i}, bv_b${i}}, [b_base + ${ldb*kx*dwidth_i}]; -%endfor +% endfor ## Main compute: two parallel dot-product streams per thread -%for j in range(m): -% if row_nz[j]: -% for i_nz, (kx, jx) in enumerate(row_nz[j]): -% if i_nz == 0: +% for j in range(m): +% if row_nz[j]: +% for i_nz, (kx, jx) in enumerate(row_nz[j]): +% if i_nz == 0: mul.f64 dotp_a, bv_a${bix_pos[kx]}, ${jx}; mul.f64 dotp_b, bv_b${bix_pos[kx]}, ${jx}; -% else: +% else: fma.rn.f64 dotp_a, bv_a${bix_pos[kx]}, ${jx}, dotp_a; fma.rn.f64 dotp_b, bv_b${bix_pos[kx]}, ${jx}, dotp_b; -% endif -% endfor +% endif +% endfor % if beta == 0: st.weak.global.cg.v2.f64 [c_base + ${ldc*j*dwidth_i}], {dotp_a, dotp_b}; % else: @@ -73,7 +71,7 @@ assert n is not None, 'cstream-w2 requires compile-time n' } % endif -% else: +% else: ## Zero row of A % if beta == 0: { @@ -90,8 +88,8 @@ assert n is not None, 'cstream-w2 requires compile-time n' st.global.v2.f64 [c_base + ${ldc*j*dwidth_i}], {_ca, _cb}; } % endif -% endif -%endfor +% endif +% endfor $L_EXIT: ret; diff --git a/gimmik/kernels/ptx/cstream.mako b/gimmik/kernels/ptx/cstream.mako index f26abeb..ec46934 100644 --- a/gimmik/kernels/ptx/cstream.mako +++ b/gimmik/kernels/ptx/cstream.mako @@ -1,13 +1,13 @@ <%inherit file='base'/> <% -pftype = "f32" if dtype == "float" else "f64" -dwidth_i = 4 if dtype == "float" else 8 -fzero = "0f00000000" if dtype == "float" else "0d0000000000000000" +pftype = 'f32' if dtype == 'float' else 'f64' +dwidth_i = 4 if dtype == 'float' else 8 +fzero = '0f00000000' if dtype == 'float' else '0d0000000000000000' bix_list = list(bix) bix_pos = {kx: i for i, kx in enumerate(bix_list)} -K_used = len(bix_list) -row_nz = [[(kx, A[j, kx]) for kx in range(k) if A[j, kx] != 0] for j in range(m)] +row_nz = [[(kx, A[j, kx]) for kx in range(k) if A[j, kx] != 0] + for j in range(m)] %> % if n is None: @@ -27,7 +27,7 @@ row_nz = [[(kx, A[j, kx]) for kx in range(k) if A[j, kx] != 0] for j in range(m) % endif .reg .u32 n, id; .reg .u64 b, c, b_base, c_base; - .reg .${pftype} bv<${K_used}>, dotp; + .reg .${pftype} bv<${len(bix_list)}>, dotp; .reg .pred p1; % if n is None: @@ -60,7 +60,7 @@ row_nz = [[(kx, A[j, kx]) for kx in range(k) if A[j, kx] != 0] for j in range(m) } ## Batch-load active B columns -%for i, kx in enumerate(bix_list): +% for i, kx in enumerate(bix_list): % if n is None: { .reg .u32 _boff; @@ -72,18 +72,18 @@ row_nz = [[(kx, A[j, kx]) for kx in range(k) if A[j, kx] != 0] for j in range(m) % else: ld.global.nc.${pftype} bv${i}, [b_base + ${ldb*kx*dwidth_i}]; % endif -%endfor +% endfor ## Compute and store each output row -%for j in range(m): -% if row_nz[j]: -% for i_nz, (kx, jx) in enumerate(row_nz[j]): -% if i_nz == 0: +% for j in range(m): +% if row_nz[j]: +% for i_nz, (kx, jx) in enumerate(row_nz[j]): +% if i_nz == 0: mul.${pftype} dotp, bv${bix_pos[kx]}, ${jx}; -% else: +% else: fma.rn.${pftype} dotp, bv${bix_pos[kx]}, ${jx}, dotp; -% endif -% endfor +% endif +% endfor % if beta == 0: % if n is None: { @@ -115,7 +115,7 @@ row_nz = [[(kx, A[j, kx]) for kx in range(k) if A[j, kx] != 0] for j in range(m) } % endif -% else: +% else: ## Zero row of A % if beta == 0: { @@ -149,8 +149,8 @@ row_nz = [[(kx, A[j, kx]) for kx in range(k) if A[j, kx] != 0] for j in range(m) % endif } % endif -% endif -%endfor +% endif +% endfor $L_EXIT: ret; diff --git a/gimmik/kernels/ptx/dense-mma-gAd.mako b/gimmik/kernels/ptx/dense-mma-gAd.mako index 7fc6572..ce8066d 100644 --- a/gimmik/kernels/ptx/dense-mma-gAd.mako +++ b/gimmik/kernels/ptx/dense-mma-gAd.mako @@ -1,8 +1,7 @@ <%inherit file='base'/> <% -assert dtype == "double" -assert n is not None and ldb is not None and ldc is not None +fzero = '0d0000000000000000' %> .global .align 16 .b64 ${kname}_Ag[${a_elems}] = { @@ -107,8 +106,8 @@ assert n is not None and ldb is not None and ldc is not None % for nt in range(nn): % for mt in range(m_tiles): % if beta == 0: - mov.f64 c0_${nt}_${mt}, 0d0000000000000000; - mov.f64 c1_${nt}_${mt}, 0d0000000000000000; + mov.f64 c0_${nt}_${mt}, ${fzero}; + mov.f64 c1_${nt}_${mt}, ${fzero}; % else: <% pm = f'pm_{mt}' if pm_runtime(mt) else None @@ -120,8 +119,8 @@ assert n is not None and ldb is not None and ldc is not None .reg .u64 caddr; add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; % if needs_zero_init: - mov.f64 c0_${nt}_${mt}, 0d0000000000000000; - mov.f64 c1_${nt}_${mt}, 0d0000000000000000; + mov.f64 c0_${nt}_${mt}, ${fzero}; + mov.f64 c1_${nt}_${mt}, ${fzero}; % endif ${pred_emit(f'ld.global.f64 c0_{nt}_{mt}, [caddr];', pm, pvc0, pred_reg=f'p0_{nt}_{mt}')} ${pred_emit(f'ld.global.f64 c1_{nt}_{mt}, [caddr + 8];', pm, pvc1, pred_reg=f'p1_{nt}_{mt}')} @@ -142,7 +141,7 @@ assert n is not None and ldb is not None and ldc is not None .reg .u64 baddr; add.u64 baddr, b_thr_base, ${ki * b_kiter_stride + nt * b_ntile_stride}; % if needs_zero: - mov.f64 b_frag_${nt}, 0d0000000000000000; + mov.f64 b_frag_${nt}, ${fzero}; % endif % if k_tail: .reg .pred pbrow; diff --git a/gimmik/kernels/ptx/dense-mma-smem-gA.mako b/gimmik/kernels/ptx/dense-mma-smem-gA.mako index 8451e06..ec2f013 100644 --- a/gimmik/kernels/ptx/dense-mma-smem-gA.mako +++ b/gimmik/kernels/ptx/dense-mma-smem-gA.mako @@ -1,14 +1,12 @@ <%inherit file='base'/> <% -assert dtype == "double" -assert n is not None and ldb is not None and ldc is not None # Cooperative-copy params (gA-only) blockx = 32 * warps_per_cta a_pairs = a_elems // 2 a_pairs_tail = a_elems % 2 copy_v2_iters = (a_pairs + blockx - 1) // blockx -bs = bool(context.get('block_stealing', False)) +bs = bool(block_stealing) %> % if bs: diff --git a/gimmik/kernels/ptx/dense-mma-ws.mako b/gimmik/kernels/ptx/dense-mma-ws.mako index 514837f..e4b576a 100644 --- a/gimmik/kernels/ptx/dense-mma-ws.mako +++ b/gimmik/kernels/ptx/dense-mma-ws.mako @@ -1,158 +1,24 @@ <%inherit file='base'/> <% -assert dtype == "double" -assert n is not None and ldb is not None and ldc is not None mbar_maxwait = '0x989680' +direct_store = (beta == 0) %> -.global .align 16 .b64 ${kname}_Ag[${a_elems}] = { - ${', '.join(a_u64)} -}; -.extern .shared .align 128 .b8 ${kname}_dynm[]; -.const .align 64 .b8 ${kname}_bdesc[128]; -.const .align 64 .b8 ${kname}_cdesc[128]; - -.visible .entry ${kname}(.param .u64 _b, - .param .u64 _c) -.maxntid ${blockx_total}, 1, 1 -{ - .reg .b32 tid, warp, lane, phase, ctaid_x; - .reg .b32 base_brow, base_bcol, base_crow, base_ccol; - .reg .b32 work, block_idx_x, n_start_curr, n_start_next; - .reg .u64 bdesc_addr, cdesc_addr; - .reg .b32 a_smem, b1_smem, b2_smem, c_smem; - .reg .b32 tma_mbar, wid_new_mbar, bready_mbar, cready_mbar, cstored_mbar, steal_mbar; - .reg .b32 wid_used_mbar, wid_smem; - .reg .pred p_compute, p_prod, p_steal; - .reg .pred p_warp_lead; - .reg .pred p_done; - .reg .pred p_tid0; - - mov.u32 tid, %tid.x; - shr.u32 warp, tid, 5; - and.b32 lane, tid, 31; - mov.u32 ctaid_x, %ctaid.x; - - .reg .b32 dynm_base; - mov.u32 dynm_base, ${kname}_dynm; - add.u32 b1_smem, dynm_base, ${b1_off}; - add.u32 b2_smem, dynm_base, ${b2_off}; - add.u32 c_smem, dynm_base, ${c_off}; - add.u32 a_smem, dynm_base, ${a_off}; - add.u32 wid_smem, dynm_base, ${wid_off}; - - add.u32 tma_mbar, dynm_base, ${tma_mbar_off}; - add.u32 bready_mbar, dynm_base, ${bready_mbar_off}; - add.u32 cready_mbar, dynm_base, ${cready_mbar_off}; - add.u32 cstored_mbar, dynm_base, ${cstored_mbar_off}; - add.u32 steal_mbar, dynm_base, ${steal_mbar_off}; - add.u32 wid_new_mbar, dynm_base, ${wid_new_mbar_off}; - add.u32 wid_used_mbar, dynm_base, ${wid_used_mbar_off}; - - cvta.const.u64 bdesc_addr, ${kname}_bdesc; - cvta.const.u64 cdesc_addr, ${kname}_cdesc; - - setp.eq.u32 p_tid0, tid, 0; - - setp.lt.u32 p_compute, warp, ${n_comp_warps}; - setp.eq.u32 p_prod, warp, ${prod_warp}; - setp.eq.u32 p_steal, warp, ${steal_warp}; - - { - .reg .b32 _elect_lane; - elect.sync _elect_lane|p_warp_lead, 0xffffffff; - } - - // mbarrier init (tid 0 only); pre-arrive csmem_free so compute iter 0 - // can write csmem immediately. - { - .reg .pred p_init; - setp.eq.u32 p_init, tid, 0; - .reg .b64 _state; - @p_init mbarrier.init.shared::cta.b64 [tma_mbar], 32; - @p_init mbarrier.init.shared::cta.b64 [bready_mbar], 1; - @p_init mbarrier.init.shared::cta.b64 [cready_mbar], 1; - @p_init mbarrier.init.shared::cta.b64 [cstored_mbar], 1; - @p_init mbarrier.init.shared::cta.b64 [steal_mbar], 1; - @p_init mbarrier.init.shared::cta.b64 [wid_used_mbar], ${n_comp_warps + 1}; - @p_init mbarrier.init.shared::cta.b64 [wid_new_mbar], 1; - @p_init mbarrier.arrive.shared::cta.b64 _state, [cstored_mbar]; - @p_init fence.proxy.async.shared::cta; - } - bar.sync 0; - - // Cooperative copy A: .global -> a_smem (ld.global.nc.v2.f64) - { - .reg .u64 a_glb_base; - .reg .b32 pidx; - .reg .f64 av0, av1; - mov.u64 a_glb_base, ${kname}_Ag; - cvta.to.global.u64 a_glb_base, a_glb_base; -% for ci in range(copy_v2_iters): -<% - base_pair = ci * blockx_total - is_last = ci == copy_v2_iters - 1 - pairs_this = min(blockx_total, a_pairs - base_pair) - needs_guard = is_last and pairs_this < blockx_total -%> - { - .reg .u64 ofs64, gaddr; - .reg .b32 saddr; - add.u32 pidx, tid, ${base_pair}; -% if needs_guard: - .reg .pred p_load; - setp.lt.u32 p_load, pidx, ${a_pairs}; -% endif - mul.wide.u32 ofs64, pidx, 16; - add.u64 gaddr, a_glb_base, ofs64; - cvt.u32.u64 saddr, ofs64; - add.u32 saddr, saddr, a_smem; -% if needs_guard: - @p_load ld.global.nc.v2.f64 {av0, av1}, [gaddr]; - @p_load st.shared.v2.f64 [saddr], {av0, av1}; -% else: - ld.global.nc.v2.f64 {av0, av1}, [gaddr]; - st.shared.v2.f64 [saddr], {av0, av1}; -% endif - } -% endfor -% if a_pairs_tail: - { - .reg .pred p_tail; - .reg .u64 gaddr; - .reg .b32 saddr; - .reg .f64 v; - setp.eq.u32 p_tail, tid, 0; - add.u64 gaddr, a_glb_base, ${(a_elems - 1) * 8}; - mov.u32 saddr, ${(a_elems - 1) * 8}; - add.u32 saddr, saddr, a_smem; - @p_tail ld.global.nc.f64 v, [gaddr]; - @p_tail st.shared.f64 [saddr], v; - } -% endif - } - bar.sync 0; - - // Compute-warp lane geometry (cheap; all warps execute uniformly) - { - .reg .b32 t, w_n_base; - and.b32 base_brow, lane, 3; - shr.u32 base_crow, lane, 2; - mul.lo.u32 w_n_base, warp, ${n_per_warp}; - add.u32 base_bcol, base_crow, w_n_base; - shl.b32 t, base_brow, 1; - add.u32 base_ccol, t, w_n_base; - } - - // Producer warp: initial B load for ctaid_x's work +<%def name="producer_init_setup()"> + // Producer warp: initial A bulk-copy + B load for ctaid_x's work @!p_prod bra.uni $L_AFTER_INIT_B; { .reg .b32 n_start0; + .reg .u64 a_glb; mul.lo.u32 n_start0, ctaid_x, ${n_per_cta}; + mov.u64 a_glb, ${kname}_Ag; + cvta.to.global.u64 a_glb, a_glb; + @p_warp_lead cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes + [a_smem], [a_glb], ${a_elems * 8}, [tma_mbar]; @p_warp_lead cp.async.bulk.tensor.2d.shared::cta.global.tile.mbarrier::complete_tx::bytes [b1_smem], [bdesc_addr, {n_start0, 0}], [tma_mbar]; @p_warp_lead mbarrier.expect_tx.relaxed.cta.shared::cta.b64 - [tma_mbar], ${b_tile_bytes}; + [tma_mbar], ${b_tile_bytes + a_elems * 8}; bar.warp.sync 0xffffffff; .reg .b64 state; .reg .pred p1; @@ -164,17 +30,9 @@ $L_TMA_INIT_W: @p_warp_lead mbarrier.arrive.shared::cta.b64 _state2, [bready_mbar]; } $L_AFTER_INIT_B: + - mov.u32 block_idx_x, ctaid_x; - mov.u32 work, 1; - mov.u32 phase, 0; - -$L_LOOP: - setp.eq.u32 p_done, work, 0; - @p_done bra.uni $L_EXIT; - - mul.lo.u32 n_start_curr, block_idx_x, ${n_per_cta}; - +<%def name="compute_warp_body()"> // --- Compute Warps @!p_compute bra.uni $L_AFTER_COMPUTE; @@ -209,6 +67,13 @@ $L_WAIT_BRDY: } % endfor +% if direct_store: + // direct_store: skip shared-staging entirely; compute warps store + // MMA outputs straight to global C with N-tail predication. + .reg .u64 c_glob_addr; + ld.param.u64 c_glob_addr, [_c]; + cvta.to.global.u64 c_glob_addr, c_glob_addr; +% else: .reg .b32 c_thr_smem; { .reg .b32 t1, ccol_b; @@ -217,6 +82,7 @@ $L_WAIT_BRDY: add.u32 c_thr_smem, c_smem, t1; add.u32 c_thr_smem, c_thr_smem, ccol_b; } +% endif // Zero accumulators % for mt in range(m_tiles): @@ -267,6 +133,45 @@ $L_WAIT_BRDY: % endfor % endfor +% if direct_store: + .reg .u64 c_thr_glob_base; + { + .reg .u32 thr_col_off, thr_addr_off_lo; + add.u32 thr_col_off, base_ccol, n_start_curr; + mad.lo.u32 thr_addr_off_lo, base_crow, ${ldc}, thr_col_off; + .reg .u64 thr_byte_off; + mul.wide.u32 thr_byte_off, thr_addr_off_lo, 8; + add.u64 c_thr_glob_base, c_glob_addr, thr_byte_off; + } +% for mt in range(m_tiles): +<% + row_tail = (m_pad > m) and ((mt + 1) * 8 > m) +%> +% if row_tail: + .reg .pred p_row_${mt}; + { + .reg .b32 crow; + add.u32 crow, base_crow, ${8 * mt}; + setp.lt.u32 p_row_${mt}, crow, ${m}; + } +% endif +% for nt in range(nn): + { + .reg .pred p_st; + .reg .u32 g_ccol; + add.u32 g_ccol, base_ccol, ${8 * nt}; + add.u32 g_ccol, g_ccol, n_start_curr; + setp.lt.u32 p_st, g_ccol, ${n}; +% if row_tail: + and.pred p_st, p_st, p_row_${mt}; +% endif + .reg .u64 c_addr; + add.u64 c_addr, c_thr_glob_base, ${(mt * 8 * ldc + nt * 8) * 8}; + @p_st st.global.v2.f64 [c_addr], {d_x_${mt}_${nt}, d_y_${mt}_${nt}}; + } +% endfor +% endfor +% else: // Wait until producer's prev-iter TMA-store of C has drained. { .reg .pred p1; @@ -286,13 +191,16 @@ $L_WAIT_CSTORE: } % endfor % endfor +% endif +% if not direct_store: bar.sync 1, ${comp_threads}; fence.proxy.async.shared::cta; { .reg .b64 _state; @p_tid0 mbarrier.arrive.shared::cta.b64 _state, [cready_mbar]; } +% endif // Wait for new work and unpack { @@ -312,7 +220,9 @@ $L_WAIT_WNEW_C: } } $L_AFTER_COMPUTE: + +<%def name="data_warp_body()"> // --- Data Movement Warp @!p_prod bra.uni $L_AFTER_DATA; { @@ -350,24 +260,22 @@ $L_WAIT_WNEW_D: } bar.warp.sync 0xffffffff; - // TMA store/reduce+store of a C +% if not direct_store: + // TMA reduce+store of C (beta=1 only; beta=0 uses direct global + // stores from compute warps, so the producer does no C work). { .reg .pred p1; .reg .b64 _c_state; $L_WAIT_CRDY: mbarrier.try_wait.parity.shared::cta.b64 p1, [cready_mbar], phase, ${mbar_maxwait}; @!p1 bra.uni $L_WAIT_CRDY; -% if beta == 0: - @p_warp_lead cp.async.bulk.tensor.2d.global.shared::cta.tile.bulk_group - [cdesc_addr, {n_c_store, 0}], [c_smem]; -% else: @p_warp_lead cp.reduce.async.bulk.tensor.2d.global.shared::cta.add.tile.bulk_group [cdesc_addr, {n_c_store, 0}], [c_smem]; -% endif @p_warp_lead cp.async.bulk.commit_group; @p_warp_lead cp.async.bulk.wait_group 0; @p_warp_lead mbarrier.arrive.shared::cta.b64 _c_state, [cstored_mbar]; } +% endif // Wait for next B to be ready, then signal B and C ready { @@ -382,7 +290,9 @@ $L_WAIT_TMA: } } $L_AFTER_DATA: + +<%def name="ctrl_warp_body()"> // --- Controller Warp @!p_steal bra.uni $L_AFTER_CTRL; { @@ -413,6 +323,112 @@ $L_WAIT_WUSED: @!p2 bra.uni $L_WAIT_WUSED; } $L_AFTER_CTRL: + + +.global .align 16 .b64 ${kname}_Ag[${a_elems}] = { + ${', '.join(a_u64)} +}; +.extern .shared .align 128 .b8 ${kname}_dynm[]; +.const .align 64 .b8 ${kname}_bdesc[128]; +.const .align 64 .b8 ${kname}_cdesc[128]; + +.visible .entry ${kname}(.param .u64 _b, + .param .u64 _c) +.maxntid ${blockx_total}, 1, 1 +{ + .reg .b32 tid, warp, lane, phase, ctaid_x; + .reg .b32 base_brow, base_bcol, base_crow, base_ccol; + .reg .b32 work, block_idx_x, n_start_curr, n_start_next; + .reg .u64 bdesc_addr, cdesc_addr; + .reg .b32 a_smem, b1_smem, b2_smem, c_smem; + .reg .b32 tma_mbar, wid_new_mbar, bready_mbar, cready_mbar, cstored_mbar, steal_mbar; + .reg .b32 wid_used_mbar, wid_smem; + .reg .pred p_compute, p_prod, p_steal; + .reg .pred p_warp_lead; + .reg .pred p_done; + .reg .pred p_tid0; + + mov.u32 tid, %tid.x; + shr.u32 warp, tid, 5; + and.b32 lane, tid, 31; + mov.u32 ctaid_x, %ctaid.x; + + .reg .b32 dynm_base; + mov.u32 dynm_base, ${kname}_dynm; + add.u32 b1_smem, dynm_base, ${b1_off}; + add.u32 b2_smem, dynm_base, ${b2_off}; + add.u32 c_smem, dynm_base, ${c_off}; + add.u32 a_smem, dynm_base, ${a_off}; + add.u32 wid_smem, dynm_base, ${wid_off}; + + add.u32 tma_mbar, dynm_base, ${tma_mbar_off}; + add.u32 bready_mbar, dynm_base, ${bready_mbar_off}; + add.u32 cready_mbar, dynm_base, ${cready_mbar_off}; + add.u32 cstored_mbar, dynm_base, ${cstored_mbar_off}; + add.u32 steal_mbar, dynm_base, ${steal_mbar_off}; + add.u32 wid_new_mbar, dynm_base, ${wid_new_mbar_off}; + add.u32 wid_used_mbar, dynm_base, ${wid_used_mbar_off}; + + cvta.const.u64 bdesc_addr, ${kname}_bdesc; + cvta.const.u64 cdesc_addr, ${kname}_cdesc; + + setp.eq.u32 p_tid0, tid, 0; + + setp.lt.u32 p_compute, warp, ${n_comp_warps}; + setp.eq.u32 p_prod, warp, ${prod_warp}; + setp.eq.u32 p_steal, warp, ${steal_warp}; + + { + .reg .b32 _elect_lane; + elect.sync _elect_lane|p_warp_lead, 0xffffffff; + } + + // mbarrier init (tid 0 only); pre-arrive csmem_free so compute iter 0 + // can write csmem immediately. + { + .reg .pred p_init; + setp.eq.u32 p_init, tid, 0; + .reg .b64 _state; + @p_init mbarrier.init.shared::cta.b64 [tma_mbar], 32; + @p_init mbarrier.init.shared::cta.b64 [bready_mbar], 1; + @p_init mbarrier.init.shared::cta.b64 [cready_mbar], 1; + @p_init mbarrier.init.shared::cta.b64 [cstored_mbar], 1; + @p_init mbarrier.init.shared::cta.b64 [steal_mbar], 1; + @p_init mbarrier.init.shared::cta.b64 [wid_used_mbar], ${n_comp_warps + 1}; + @p_init mbarrier.init.shared::cta.b64 [wid_new_mbar], 1; + @p_init mbarrier.arrive.shared::cta.b64 _state, [cstored_mbar]; + @p_init fence.proxy.async.shared::cta; + } + bar.sync 0; + + // Compute-warp lane geometry (cheap; all warps execute uniformly) + { + .reg .b32 t, w_n_base; + and.b32 base_brow, lane, 3; + shr.u32 base_crow, lane, 2; + mul.lo.u32 w_n_base, warp, ${n_per_warp}; + add.u32 base_bcol, base_crow, w_n_base; + shl.b32 t, base_brow, 1; + add.u32 base_ccol, t, w_n_base; + } + + ${producer_init_setup()} + + mov.u32 block_idx_x, ctaid_x; + mov.u32 work, 1; + mov.u32 phase, 0; + +$L_LOOP: + setp.eq.u32 p_done, work, 0; + @p_done bra.uni $L_EXIT; + + mul.lo.u32 n_start_curr, block_idx_x, ${n_per_cta}; + + ${compute_warp_body()} + + ${data_warp_body()} + + ${ctrl_warp_body()} xor.b32 phase, phase, 1; bra.uni $L_LOOP; diff --git a/gimmik/ptx.py b/gimmik/ptx.py index bf32a62..ad429e8 100644 --- a/gimmik/ptx.py +++ b/gimmik/ptx.py @@ -19,26 +19,21 @@ def is_sparse_suitable(arr): density = nnz / arr.size return (nuq <= 28) or (density <= 0.15) + # Shape/arch gate for dense DMMA; n/ldb/ldc are validated at generate time @staticmethod def is_dense_suitable(arr, dtype, cc): - """True if A's shape and the target arch support the dense DMMA - template family. Does NOT check runtime args (n, ldb, ldc); those - are validated when the generator runs.""" return (np.dtype(dtype) == np.float64 and cc is not None and cc >= (9, 0) and arr.shape[0] <= 128 and arr.shape[1] <= 128) @classmethod def is_suitable(cls, arr, dtype, cc): - """True if either sparse or dense templates are applicable.""" return (cls.is_sparse_suitable(arr) or cls.is_dense_suitable(arr, dtype, cc)) - def _kernel_generators(self, dtype, dsize, *, compute_capability=None, - trim_a=False): + def _kernel_generators(self, dtype, dsize, *, compute_capability=None): base_args = {'cc': compute_capability, - 'pred_emit': self._pred_emit, - 'trim_a': bool(trim_a) and dtype == 'double'} + 'pred_emit': self._pred_emit} yield from self._sparse_kernel_generators(dtype, dsize, base_args) yield from self._dense_kernel_generators(dtype, dsize, base_args) @@ -48,10 +43,10 @@ def _sparse_kernel_generators(self, dtype, dsize, base_args): return # B loading, C streaming kernel - yield ('cstream', base_args | {}, {'desc': 'cstream'}) + yield ('cstream', base_args, {'desc': 'cstream'}) # B streaming, C accumulation kernel - yield ('bstream', base_args | {}, {'desc': 'bstream'}) + yield ('bstream', base_args, {'desc': 'bstream'}) # Four-way m-split B streaming, C accumulation kernel ms, bsz, blkx = 4, 24, 32 @@ -102,15 +97,22 @@ def _dense_kernel_generators(self, dtype, dsize, base_args): and self.n is not None): return - # Dense DMMA m8n8k4; block stealing default on sm_100+ for gA + # Some kernels can optional steal blocks bs_default = cc >= (10, 0) - dense_configs = [ - ('dense-mma-smem-gA', 1, 8), - ('dense-mma-smem-gA', 2, 4), - ('dense-mma-smem-gA', 4, 4), - ('dense-mma-gAd', 2, 2), - ('dense-mma-gAd', 4, 2), - ] + + if cc >= (10, 0): + # Warp specialised is uniformly better on sm_100+, so no need to JIT + # other versions + dense_configs = [('dense-mma-smem-gA', 4, 4)] + else: + dense_configs = [ + ('dense-mma-smem-gA', 1, 8), + ('dense-mma-smem-gA', 2, 4), + ('dense-mma-smem-gA', 4, 4), + ('dense-mma-gAd', 2, 2), + ('dense-mma-gAd', 4, 2), + ] + for tpl, nn, w in dense_configs: blkx = 32 * w n_per_cta = 8 * nn * w @@ -127,13 +129,14 @@ def _dense_kernel_generators(self, dtype, dsize, base_args): } yield (tpl, args, meta) - # Warp-specialised dense DMMA with TMA B-load + TMA C-store. + # Warp-specialised dense DMMA if cc >= (10, 0): yield from self._dense_ws_kernel_generators(dtype, dsize, base_args) def _dense_ws_kernel_generators(self, dtype, dsize, base_args): m_pad = -(-self.m // 8) * 8 k_pad = -(-self.k // 4) * 4 + # (nn, w_compute) -- block has w_compute + 2 warps (producer, stealer) ws_configs = [(1, 4), (2, 4), (4, 4)] for nn, w in ws_configs: @@ -146,14 +149,14 @@ def _dense_ws_kernel_generators(self, dtype, dsize, base_args): n_comp_warps=w, n_per_cta=n_per_cta, m_pad=m_pad, k_pad=k_pad, a_elems=setup['a_elems'] ) - # sm_100 supports up to 228 KiB shared per CTA with the - # set_shared_size opt-in. Reserve some headroom for L1 carveout. + if ws_layout['dynm_total_bytes'] > 200 * 1024: continue + args = (base_args | {'warps_per_cta': w, 'nn': nn} | setup | ws_layout) - yield ('dense-mma-ws', args, { + meta = { 'block': (blkx, 1, 1), 'grid': (-(-self.n // n_per_cta), 1, 1), 'desc': f'dense-mma-ws/nn{nn}-w{w}', @@ -162,17 +165,13 @@ def _dense_ws_kernel_generators(self, dtype, dsize, base_args): 'ws_k_pad': k_pad, 'ws_m_pad': m_pad, 'dynamic_shared': ws_layout['dynm_total_bytes'], - }) + } + yield ('dense-mma-ws', args, meta) @staticmethod def _dense_ws_layout(*, n_comp_warps, n_per_cta, m_pad, k_pad, a_elems): - """Render-time constants for the dense-mma-ws template: warp roles, - cooperative-copy iteration counts, smem-tile sizes, mbar timeout, - and dynamic-shared byte offsets for each buffer.""" n_total_warps = n_comp_warps + 2 blockx_total = 32 * n_total_warps - a_pairs = a_elems // 2 - a_pairs_tail = a_elems % 2 b_tile_bytes = k_pad * n_per_cta * 8 c_tile_bytes = m_pad * n_per_cta * 8 @@ -200,9 +199,6 @@ def _dense_ws_layout(*, n_comp_warps, n_per_cta, m_pad, k_pad, a_elems): 'prod_warp': n_comp_warps, 'steal_warp': n_comp_warps + 1, 'comp_threads': 32 * n_comp_warps, - 'a_pairs': a_pairs, - 'a_pairs_tail': a_pairs_tail, - 'copy_v2_iters': -(-a_pairs // blockx_total), 'm_pad': m_pad, 'k_pad': k_pad, 'b_tile_doubles': k_pad * n_per_cta, From e2a818bb9234d5326deda37035635dd6c0ae3129 Mon Sep 17 00:00:00 2001 From: Will Trojak Date: Fri, 15 May 2026 13:19:38 +0100 Subject: [PATCH 06/11] Whitespace --- gimmik/kernels/ptx/base.mako | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gimmik/kernels/ptx/base.mako b/gimmik/kernels/ptx/base.mako index e380f1b..b64ddc1 100644 --- a/gimmik/kernels/ptx/base.mako +++ b/gimmik/kernels/ptx/base.mako @@ -1,4 +1,4 @@ .version 8.7 .target sm_${cc[0]}${cc[1]}${"a" if cc[0] >= 9 else ""} .address_size 64 -${next.body()} \ No newline at end of file +${next.body()} From 7d7299a48bc486883f9b3e933ce7321d9e0e2dc2 Mon Sep 17 00:00:00 2001 From: WillTrojak Date: Tue, 19 May 2026 11:31:30 -0700 Subject: [PATCH 07/11] Cleanups, formating and addressign comments --- gimmik/kernels/ptx/base.mako | 4 +- gimmik/kernels/ptx/bstream-msplit.mako | 306 ++++++++++------------ gimmik/kernels/ptx/bstream.mako | 157 +++++------ gimmik/kernels/ptx/cstream-ksplit.mako | 157 ++++++----- gimmik/kernels/ptx/cstream-w2.mako | 80 +++--- gimmik/kernels/ptx/cstream.mako | 166 ++++++------ gimmik/kernels/ptx/dense-mma-gAd.mako | 86 +++--- gimmik/kernels/ptx/dense-mma-smem-gA.mako | 113 ++++---- gimmik/kernels/ptx/dense-mma-ws.mako | 114 ++++---- gimmik/ptx.py | 217 +++++++-------- 10 files changed, 667 insertions(+), 733 deletions(-) diff --git a/gimmik/kernels/ptx/base.mako b/gimmik/kernels/ptx/base.mako index b64ddc1..71eb414 100644 --- a/gimmik/kernels/ptx/base.mako +++ b/gimmik/kernels/ptx/base.mako @@ -1,4 +1,4 @@ -.version 8.7 -.target sm_${cc[0]}${cc[1]}${"a" if cc[0] >= 9 else ""} +.version 8.6 +.target sm_${cc[0]}${cc[1]}${'a' if cc[0] >= 9 else ''} .address_size 64 ${next.body()} diff --git a/gimmik/kernels/ptx/bstream-msplit.mako b/gimmik/kernels/ptx/bstream-msplit.mako index 0af5091..530b19f 100644 --- a/gimmik/kernels/ptx/bstream-msplit.mako +++ b/gimmik/kernels/ptx/bstream-msplit.mako @@ -1,12 +1,7 @@ <%inherit file='base'/> <% -pftype = 'f32' if dtype == 'float' else 'f64' -dwidth_i = 4 if dtype == 'float' else 8 -fzero = '0f00000000' if dtype == 'float' else '0d0000000000000000' -has_zero_rows = any(jx == -1 for jx in afix) mx = partition(A, into=msplit, by='rows') -bix_list = list(bix) bchunks = chunk(bix_list, bsz) m_per_group = max(len(mcx) for mcx in mx) bsub_bytes = 2 * bsz * blockx * dwidth_i @@ -48,11 +43,11 @@ use_cpasync = cc is not None and (cc[0], cc[1]) >= (8, 0) and dwidth_i in (4, 8) ld.param.u64 c, [_c]; { - .reg .u32 _ctaid_x; - mov.u32 _ctaid_x, %ctaid.x; - mov.u32 tid_x, %tid.x; - mov.u32 tid_y, %tid.y; - mad.lo.u32 id, _ctaid_x, ${blockx}, tid_x; + .reg .u32 _ctaid_x; + mov.u32 _ctaid_x, %ctaid.x; + mov.u32 tid_x, %tid.x; + mov.u32 tid_y, %tid.y; + mad.lo.u32 id, _ctaid_x, ${blockx}, tid_x; } setp.ge.u32 p1, id, n; @@ -62,23 +57,23 @@ use_cpasync = cc is not None and (cc[0], cc[1]) >= (8, 0) and dwidth_i in (4, 8) cvta.to.global.u64 c, c; { - .reg .u64 _id64; - cvt.u64.u32 _id64, id; - mad.lo.u64 b_base, _id64, ${dwidth_i}, b; - mad.lo.u64 c_base, _id64, ${dwidth_i}, c; + .reg .u64 _id64; + cvt.u64.u32 _id64, id; + mad.lo.u64 b_base, _id64, ${dwidth_i}, b; + mad.lo.u64 c_base, _id64, ${dwidth_i}, c; } { - .reg .u64 _tx_off; - mul.wide.u32 _tx_off, tid_x, ${dwidth_i}; - mov.u64 bsub_thread, _bsub; - add.u64 bsub_thread, bsub_thread, _tx_off; + .reg .u64 _tx_off; + mul.wide.u32 _tx_off, tid_x, ${dwidth_i}; + mov.u64 bsub_thread, _bsub; + add.u64 bsub_thread, bsub_thread, _tx_off; } % if use_cpasync: { - .reg .u64 _sm64; - cvta.to.shared.u64 _sm64, bsub_thread; - cvt.u32.u64 bsub_sm_thread, _sm64; + .reg .u64 _sm64; + cvta.to.shared.u64 _sm64, bsub_thread; + cvt.u32.u64 bsub_sm_thread, _sm64; } % endif @@ -87,191 +82,176 @@ use_cpasync = cc is not None and (cc[0], cc[1]) >= (8, 0) and dwidth_i in (4, 8) setp.ne.u32 p_skip, tid_y, ${cid}; @p_skip bra $L_END_CID_${cid}; -% if use_cpasync: +% if use_cpasync: ## Async fill of chunk 0 -% for idx, kx in enumerate(bchunks[0]): -% if idx % msplit == cid: -% if n is None: +% for idx, kx in [(i, k) for i, k in enumerate(bchunks[0]) if i % msplit == cid]: +% if n is None: { - .reg .u32 _boff; - .reg .u64 _bptr; - mul.lo.u32 _boff, ldb, ${kx}; - mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; - cp.async.ca.shared::cta.global [bsub_sm_thread + ${bsub_off(0, idx)}], [_bptr], ${dwidth_i}; + .reg .u32 _boff; + .reg .u64 _bptr; + mul.lo.u32 _boff, ldb, ${kx}; + mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + cp.async.ca.shared::cta.global [bsub_sm_thread + ${bsub_off(0, idx)}], [_bptr], ${dwidth_i}; } -% else: +% else: cp.async.ca.shared::cta.global [bsub_sm_thread + ${bsub_off(0, idx)}], [b_base + ${ldb*kx*dwidth_i}], ${dwidth_i}; -% endif -% endif +% endif % endfor cp.async.commit_group; cp.async.wait_all; bar.sync 0; -% else: +% else: ## Sync fill of chunk 0 -% for idx, kx in enumerate(bchunks[0]): -% if idx % msplit == cid: -% if n is None: +% for idx, kx in [(i, k) for i, k in enumerate(bchunks[0]) if i % msplit == cid]: { - .reg .u32 _boff; - .reg .u64 _bptr; - .reg .${pftype} _bv; - mul.lo.u32 _boff, ldb, ${kx}; - mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; - ld.global.cg.${pftype} _bv, [_bptr]; - st.shared.${pftype} [bsub_thread + ${bsub_off(0, idx)}], _bv; + .reg .${pftype} _bv; +% if n is None: + .reg .u32 _boff; + .reg .u64 _bptr; + mul.lo.u32 _boff, ldb, ${kx}; + mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + ld.weak.global.cg.${pftype} _bv, [_bptr]; +% else: + ld.weak.global.cg.${pftype} _bv, [b_base + ${ldb*kx*dwidth_i}]; +% endif + st.shared.${pftype} [bsub_thread + ${bsub_off(0, idx)}], _bv; } -% else: - { - .reg .${pftype} _bv; - ld.global.cg.${pftype} _bv, [b_base + ${ldb*kx*dwidth_i}]; - st.shared.${pftype} [bsub_thread + ${bsub_off(0, idx)}], _bv; - } -% endif -% endif % endfor bar.sync 0; -% endif +% endif ## Main loop over B-chunks (double-buffered) -% for bb in range(len(bchunks)): +% for bb in range(len(bchunks)): <% buf_cur = bb % 2 buf_next = (bb + 1) % 2 - is_last = (bb == len(bchunks) - 1) %> -% if not is_last: -% for idx, kx in enumerate(bchunks[bb + 1]): -% if idx % msplit == cid: -% if use_cpasync: -% if n is None: +% if not loop.last: +% for idx, kx in [(i, k) for i, k in enumerate(bchunks[bb + 1]) if i % msplit == cid]: +% if use_cpasync: +% if n is None: { - .reg .u32 _boff; - .reg .u64 _bptr; - mul.lo.u32 _boff, ldb, ${kx}; - mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; - cp.async.ca.shared::cta.global [bsub_sm_thread + ${bsub_off(buf_next, idx)}], [_bptr], ${dwidth_i}; + .reg .u32 _boff; + .reg .u64 _bptr; + mul.lo.u32 _boff, ldb, ${kx}; + mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + cp.async.ca.shared::cta.global [bsub_sm_thread + ${bsub_off(buf_next, idx)}], [_bptr], ${dwidth_i}; } -% else: +% else: cp.async.ca.shared::cta.global [bsub_sm_thread + ${bsub_off(buf_next, idx)}], [b_base + ${ldb*kx*dwidth_i}], ${dwidth_i}; -% endif -% else: -% if n is None: - { - .reg .u32 _boff; - .reg .u64 _bptr; - .reg .${pftype} _bv; - mul.lo.u32 _boff, ldb, ${kx}; - mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; - ld.global.cg.${pftype} _bv, [_bptr]; - st.shared.${pftype} [bsub_thread + ${bsub_off(buf_next, idx)}], _bv; - } -% else: +% endif +% else: { - .reg .${pftype} _bv; - ld.global.cg.${pftype} _bv, [b_base + ${ldb*kx*dwidth_i}]; - st.shared.${pftype} [bsub_thread + ${bsub_off(buf_next, idx)}], _bv; + .reg .${pftype} _bv; +% if n is None: + .reg .u32 _boff; + .reg .u64 _bptr; + mul.lo.u32 _boff, ldb, ${kx}; + mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + ld.weak.global.cg.${pftype} _bv, [_bptr]; +% else: + ld.weak.global.cg.${pftype} _bv, [b_base + ${ldb*kx*dwidth_i}]; +% endif + st.shared.${pftype} [bsub_thread + ${bsub_off(buf_next, idx)}], _bv; } -% endif -% endif -% endif -% endfor -% if use_cpasync: - cp.async.commit_group; -% endif % endif +% endfor +% if use_cpasync: + cp.async.commit_group; +% endif +% endif -% for idx, kx in enumerate(bchunks[bb]): +% for idx, kx in enumerate(bchunks[bb]): ld.shared.${pftype} bv, [bsub_thread + ${bsub_off(buf_cur, idx)}]; -% for j, row_j in enumerate(mcx): -<% jx = A[row_j, kx] %> -% if jx != 0 and kx == afix[row_j]: +% for j, row_j in enumerate(mcx): +<% jx = A[row_j, kx] %> +% if jx != 0 and kx == afix[row_j]: mul.${pftype} csub${j}, bv, ${jx}; -% elif jx != 0: +% elif jx != 0: fma.rn.${pftype} csub${j}, bv, ${jx}, csub${j}; -% endif -% if kx == alix[row_j]: -% if beta == 0: -% if n is None: +% endif +% if kx == alix[row_j]: +% if beta_zero: +% if n is None: { - .reg .u32 _coff; - .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, ${row_j}; - mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; - st.weak.global.cg.${pftype} [_cptr], csub${j}; + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${row_j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], csub${j}; } -% else: +% else: st.weak.global.cg.${pftype} [c_base + ${ldc*row_j*dwidth_i}], csub${j}; -% endif -% else: +% endif +% else: { - .reg .${pftype} _ctmp; -% if n is None: - .reg .u32 _coff; - .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, ${row_j}; - mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; - ld.global.${pftype} _ctmp, [_cptr]; - fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, csub${j}; - st.global.${pftype} [_cptr], _ctmp; -% else: - ld.global.${pftype} _ctmp, [c_base + ${ldc*row_j*dwidth_i}]; - fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, csub${j}; - st.global.${pftype} [c_base + ${ldc*row_j*dwidth_i}], _ctmp; -% endif + .reg .${pftype} _ctmp; +% if n is None: + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${row_j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.weak.global.cg.${pftype} _ctmp, [_cptr]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, csub${j}; + st.weak.global.${pftype} [_cptr], _ctmp; +% else: + ld.weak.global.cg.${pftype} _ctmp, [c_base + ${ldc*row_j*dwidth_i}]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, csub${j}; + st.weak.global.${pftype} [c_base + ${ldc*row_j*dwidth_i}], _ctmp; +% endif } -% endif -% endif -% endfor -% endfor -% if use_cpasync: -% if not is_last: - cp.async.wait_all; +% endif % endif -% endif - bar.sync 0; +% endfor % endfor +% if use_cpasync: +% if not loop.last: + cp.async.wait_all; +% endif +% endif + bar.sync 0; +% endfor ## End of Main loop over B-chunks ## Handle zero rows in this cid's group -% if has_zero_rows: -% for row_j in mcx: -% if afix[row_j] == -1: -% if beta == 0: +% if has_zero_rows: +% for row_j in mcx: +% if afix[row_j] == -1: +% if beta_zero: { - .reg .${pftype} _tmp; - mov.${pftype} _tmp, ${fzero}; -% if n is None: - .reg .u32 _coff; - .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, ${row_j}; - mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; - st.weak.global.cg.${pftype} [_cptr], _tmp; -% else: - st.weak.global.cg.${pftype} [c_base + ${ldc*row_j*dwidth_i}], _tmp; -% endif + .reg .${pftype} _tmp; + mov.${pftype} _tmp, ${fzero}; +% if n is None: + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${row_j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], _tmp; +% else: + st.weak.global.cg.${pftype} [c_base + ${ldc*row_j*dwidth_i}], _tmp; +% endif } -% elif beta != 1: +% elif beta != 1: { - .reg .${pftype} _tmp; -% if n is None: - .reg .u32 _coff; - .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, ${row_j}; - mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; - ld.global.${pftype} _tmp, [_cptr]; - mul.${pftype} _tmp, _tmp, ${float(beta)}; - st.global.${pftype} [_cptr], _tmp; -% else: - ld.global.${pftype} _tmp, [c_base + ${ldc*row_j*dwidth_i}]; - mul.${pftype} _tmp, _tmp, ${float(beta)}; - st.global.${pftype} [c_base + ${ldc*row_j*dwidth_i}], _tmp; -% endif + .reg .${pftype} _tmp; +% if n is None: + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${row_j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.weak.global.cg.${pftype} _tmp, [_cptr]; + mul.${pftype} _tmp, _tmp, ${float(beta)}; + st.weak.global.${pftype} [_cptr], _tmp; +% else: + ld.weak.global.cg.${pftype} _tmp, [c_base + ${ldc*row_j*dwidth_i}]; + mul.${pftype} _tmp, _tmp, ${float(beta)}; + st.weak.global.${pftype} [c_base + ${ldc*row_j*dwidth_i}], _tmp; +% endif } -% endif -% endif -% endfor -% endif +% endif +% endif +% endfor +% endif $L_END_CID_${cid}: % endfor diff --git a/gimmik/kernels/ptx/bstream.mako b/gimmik/kernels/ptx/bstream.mako index f58e9b3..24b0acb 100644 --- a/gimmik/kernels/ptx/bstream.mako +++ b/gimmik/kernels/ptx/bstream.mako @@ -1,14 +1,5 @@ <%inherit file='base'/> -<% -pftype = 'f32' if dtype == 'float' else 'f64' -dwidth_i = 4 if dtype == 'float' else 8 -fzero = '0f00000000' if dtype == 'float' else '0d0000000000000000' -has_zero_rows = any(jx == -1 for jx in afix) -bix_list = list(bix) -bix_pos = {kx: i for i, kx in enumerate(bix_list)} -%> - % if n is None: .visible .entry ${kname}(.param .u32 _n, .param .u64 _b, @@ -38,11 +29,11 @@ bix_pos = {kx: i for i, kx in enumerate(bix_list)} ld.param.u64 c, [_c]; { - .reg .u32 _grd<3>; - mov.u32 _grd0, %ntid.x; - mov.u32 _grd1, %ctaid.x; - mov.u32 _grd2, %tid.x; - mad.lo.u32 id, _grd0, _grd1, _grd2; + .reg .u32 _grd<3>; + mov.u32 _grd0, %ntid.x; + mov.u32 _grd1, %ctaid.x; + mov.u32 _grd2, %tid.x; + mad.lo.u32 id, _grd0, _grd1, _grd2; } setp.ge.u32 p1, id, n; @@ -52,117 +43,113 @@ bix_pos = {kx: i for i, kx in enumerate(bix_list)} cvta.to.global.u64 c, c; { - .reg .u64 _id64; - cvt.u64.u32 _id64, id; - mad.lo.u64 b_base, _id64, ${dwidth_i}, b; - mad.lo.u64 c_base, _id64, ${dwidth_i}, c; + .reg .u64 _id64; + cvt.u64.u32 _id64, id; + mad.lo.u64 b_base, _id64, ${dwidth_i}, b; + mad.lo.u64 c_base, _id64, ${dwidth_i}, c; } ## Batch-load active B columns % for i, kx in enumerate(bix_list): -% if n is None: +% if n is None: { - .reg .u32 _boff; - .reg .u64 _bptr; - mul.lo.u32 _boff, ldb, ${kx}; - mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; - ld.weak.global.cg.${pftype} bv${i}, [_bptr]; + .reg .u32 _boff; + .reg .u64 _bptr; + mul.lo.u32 _boff, ldb, ${kx}; + mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + ld.weak.global.cg.${pftype} bv${i}, [_bptr]; } -% else: +% else: ld.weak.global.cg.${pftype} bv${i}, [b_base + ${ldb*kx*dwidth_i}]; -% endif +% endif % endfor -% if beta != 0: +% if not beta_zero: ## Pre-load C so per-row completion is a plain store % for j in range(m): % if afix[j] != -1: -% if n is None: +% if n is None: { - .reg .u32 _coff; - .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, ${j}; - mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; - ld.weak.global.cg.${pftype} csub${j}, [_cptr]; + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.weak.global.cg.${pftype} csub${j}, [_cptr]; } -% else: +% else: ld.weak.global.cg.${pftype} csub${j}, [c_base + ${ldc*j*dwidth_i}]; -% endif +% endif % endif % endfor -% if beta != 0 and beta != 1: % for j in range(m): % if afix[j] != -1: mul.${pftype} csub${j}, csub${j}, ${float(beta)}; % endif % endfor % endif -% endif ## Main compute % for kx in bix_list: -% for j, jx in enumerate(A[:, kx]): -% if jx != 0: -% if preload_c: - fma.rn.${pftype} csub${j}, bv${bix_pos[kx]}, ${jx}, csub${j}; -% elif kx == afix[j]: +% for j, jx in enumerate(A[:, kx]): +% if jx != 0: +% if beta_zero and kx == afix[j]: mul.${pftype} csub${j}, bv${bix_pos[kx]}, ${jx}; -% else: +% else: fma.rn.${pftype} csub${j}, bv${bix_pos[kx]}, ${jx}, csub${j}; -% endif % endif -% if kx == alix[j]: -% if n is None: +% endif +% if kx == alix[j]: +% if n is None: { - .reg .u32 _coff; - .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, ${j}; - mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; - st.weak.global.cg.${pftype} [_cptr], csub${j}; + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], csub${j}; } -% else: +% else: st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], csub${j}; -% endif +% endif -% endif -% endfor +% endif +% endfor % endfor % if has_zero_rows: { - .reg .${pftype} _tmp; - mov.${pftype} _tmp, ${fzero}; + .reg .${pftype} _tmp; + mov.${pftype} _tmp, ${fzero}; % for j, jx in enumerate(afix): -% if jx == -1 and beta == 0: -% if n is None: - { - .reg .u32 _coff; - .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, ${j}; - mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; - st.weak.global.cg.${pftype} [_cptr], _tmp; - } -% else: - st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], _tmp; -% endif +% if jx == -1 and beta_zero: +% if n is None: + { + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], _tmp; + } +% else: + st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], _tmp; +% endif -% elif jx == -1: -% if n is None: - { - .reg .u32 _coff; - .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, ${j}; - mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; - ld.global.${pftype} _tmp, [_cptr]; - mul.${pftype} _tmp, _tmp, ${float(beta)}; - st.global.${pftype} [_cptr], _tmp; - } -% else: - ld.global.${pftype} _tmp, [c_base + ${ldc*j*dwidth_i}]; - mul.${pftype} _tmp, _tmp, ${float(beta)}; - st.global.${pftype} [c_base + ${ldc*j*dwidth_i}], _tmp; -% endif +% elif jx == -1: +% if n is None: + { + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.weak.global.cg.${pftype} _tmp, [_cptr]; + mul.${pftype} _tmp, _tmp, ${float(beta)}; + st.weak.global.cg.${pftype} [_cptr], _tmp; + } +% else: + ld.weak.global.cg.${pftype} _tmp, [c_base + ${ldc*j*dwidth_i}]; + mul.${pftype} _tmp, _tmp, ${float(beta)}; + st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], _tmp; % endif +% endif % endfor } % endif diff --git a/gimmik/kernels/ptx/cstream-ksplit.mako b/gimmik/kernels/ptx/cstream-ksplit.mako index 1ba2491..5d704de 100644 --- a/gimmik/kernels/ptx/cstream-ksplit.mako +++ b/gimmik/kernels/ptx/cstream-ksplit.mako @@ -1,9 +1,6 @@ <%inherit file='base'/> <% -pftype = 'f32' if dtype == 'float' else 'f64' -dwidth_i = 4 if dtype == 'float' else 8 -fzero = '0f00000000' if dtype == 'float' else '0d0000000000000000' kparts = partition(A, ksplit, by='cols') cchunks = chunk(list(range(m)), csz) cv_per_thread = -(-csz // ksplit) @@ -41,11 +38,11 @@ csub_bytes = (ksplit - 1) * csz * blockx * dwidth_i ld.param.u64 c, [_c]; { - .reg .u32 _ctaid_x; - mov.u32 _ctaid_x, %ctaid.x; - mov.u32 tid_x, %tid.x; - mov.u32 tid_y, %tid.y; - mad.lo.u32 id, _ctaid_x, ${blockx}, tid_x; + .reg .u32 _ctaid_x; + mov.u32 _ctaid_x, %ctaid.x; + mov.u32 tid_x, %tid.x; + mov.u32 tid_y, %tid.y; + mad.lo.u32 id, _ctaid_x, ${blockx}, tid_x; } setp.ge.u32 p1, id, n; @@ -55,17 +52,17 @@ csub_bytes = (ksplit - 1) * csz * blockx * dwidth_i cvta.to.global.u64 c, c; { - .reg .u64 _id64; - cvt.u64.u32 _id64, id; - mad.lo.u64 b_base, _id64, ${dwidth_i}, b; - mad.lo.u64 c_base, _id64, ${dwidth_i}, c; + .reg .u64 _id64; + cvt.u64.u32 _id64, id; + mad.lo.u64 b_base, _id64, ${dwidth_i}, b; + mad.lo.u64 c_base, _id64, ${dwidth_i}, c; } { - .reg .u64 _tx_off; - mul.wide.u32 _tx_off, tid_x, ${dwidth_i}; - mov.u64 csub_thread, _csub; - add.u64 csub_thread, csub_thread, _tx_off; + .reg .u64 _tx_off; + mul.wide.u32 _tx_off, tid_x, ${dwidth_i}; + mov.u64 csub_thread, _csub; + add.u64 csub_thread, csub_thread, _tx_off; } % for bid, kbx in enumerate(kparts): @@ -78,98 +75,98 @@ csub_bytes = (ksplit - 1) * csz * blockx * dwidth_i kbx_idx = {kx: i for i, kx in enumerate(kbx)} %> -% for cchunk_i, cchunk in enumerate(cchunks): +% for cchunk_i, cchunk in enumerate(cchunks): ## Chunk ${cchunk_i}: partial dot-product -% for row_idx, j in enumerate(cchunk): +% for row_idx, j in enumerate(cchunk): <% nz = [(kbx_idx[kx], kx, A[j, kx]) for kx in kbx if A[j, kx] != 0] owner_bid = row_idx % ksplit %> -% for (kxi, kx, jx) in nz: -% if kx not in loaded: -% if n is None: +% for (kxi, kx, jx) in nz: +% if kx not in loaded: +% if n is None: { - .reg .u32 _boff; - .reg .u64 _bptr; - mul.lo.u32 _boff, ldb, ${kx}; - mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; - ld.global.nc.${pftype} bv${kxi}, [_bptr]; + .reg .u32 _boff; + .reg .u64 _bptr; + mul.lo.u32 _boff, ldb, ${kx}; + mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + ld.weak.global.cg.${pftype} bv${kxi}, [_bptr]; } -% else: - ld.global.nc.${pftype} bv${kxi}, [b_base + ${ldb*kx*dwidth_i}]; -% endif +% else: + ld.weak.global.cg.${pftype} bv${kxi}, [b_base + ${ldb*kx*dwidth_i}]; +% endif <% loaded.add(kx) %> -% endif -% endfor -% if nz: -% for i, (kxi, kx, jx) in enumerate(nz): -% if i == 0: +% endif +% endfor +% if nz: +% for kxi, kx, jx in nz: +% if loop.first: mul.${pftype} dotp, bv${kxi}, ${jx}; -% else: +% else: fma.rn.${pftype} dotp, bv${kxi}, ${jx}, dotp; -% endif -% endfor -% else: +% endif +% endfor +% else: mov.${pftype} dotp, ${fzero}; -% endif -% if owner_bid == bid: +% endif +% if owner_bid == bid: mov.${pftype} cv${row_idx // ksplit}, dotp; -% else: +% else: <% csub_idx = bid - (1 if bid > owner_bid else 0) %> st.shared.${pftype} [csub_thread + ${(csub_idx * csz + row_idx) * blockx * dwidth_i}], dotp; -% endif -% endfor +% endif +% endfor bar.sync 0; ## Combine phase (owned rows only) -% for row_idx, j in enumerate(cchunk): -% if row_idx % ksplit == bid: +% for row_idx, j in enumerate(cchunk): +% if row_idx % ksplit == bid: mov.${pftype} dotp, cv${row_idx // ksplit}; -% for other_bid in range(ksplit): -% if other_bid != bid: +% for other_bid in range(ksplit): +% if other_bid != bid: <% csub_idx = other_bid - (1 if other_bid > (row_idx % ksplit) else 0) %> { - .reg .${pftype} _tmp; - ld.shared.${pftype} _tmp, [csub_thread + ${(csub_idx * csz + row_idx) * blockx * dwidth_i}]; - add.${pftype} dotp, dotp, _tmp; + .reg .${pftype} _tmp; + ld.shared.${pftype} _tmp, [csub_thread + ${(csub_idx * csz + row_idx) * blockx * dwidth_i}]; + add.${pftype} dotp, dotp, _tmp; } -% endif -% endfor -% if beta == 0: -% if n is None: +% endif +% endfor +% if beta_zero: +% if n is None: { - .reg .u32 _coff; - .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, ${j}; - mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; - st.weak.global.cg.${pftype} [_cptr], dotp; + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], dotp; } -% else: +% else: st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], dotp; -% endif -% else: +% endif +% else: { - .reg .${pftype} _ctmp; -% if n is None: - .reg .u32 _coff; - .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, ${j}; - mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; - ld.global.${pftype} _ctmp, [_cptr]; - fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, dotp; - st.global.${pftype} [_cptr], _ctmp; -% else: - ld.global.${pftype} _ctmp, [c_base + ${ldc*j*dwidth_i}]; - fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, dotp; - st.global.${pftype} [c_base + ${ldc*j*dwidth_i}], _ctmp; -% endif + .reg .${pftype} _ctmp; +% if n is None: + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.weak.global.cg.${pftype} _ctmp, [_cptr]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, dotp; + st.weak.global.${pftype} [_cptr], _ctmp; +% else: + ld.weak.global.cg.${pftype} _ctmp, [c_base + ${ldc*j*dwidth_i}]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, dotp; + st.weak.global.${pftype} [c_base + ${ldc*j*dwidth_i}], _ctmp; +% endif } -% endif +% endif -% endif -% endfor - bar.sync 0; +% endif % endfor + bar.sync 0; +% endfor $L_END_BID_${bid}: % endfor diff --git a/gimmik/kernels/ptx/cstream-w2.mako b/gimmik/kernels/ptx/cstream-w2.mako index c82ebab..e6b4d75 100644 --- a/gimmik/kernels/ptx/cstream-w2.mako +++ b/gimmik/kernels/ptx/cstream-w2.mako @@ -1,15 +1,5 @@ <%inherit file='base'/> -<% -pftype = 'f64' -dwidth_i = 8 -fzero = '0d0000000000000000' -bix_list = list(bix) -bix_pos = {kx: i for i, kx in enumerate(bix_list)} -row_nz = [[(kx, A[j, kx]) for kx in range(k) if A[j, kx] != 0] - for j in range(m)] -%> - .visible .entry ${kname}(.param .u64 _b, .param .u64 _c) { @@ -23,10 +13,10 @@ row_nz = [[(kx, A[j, kx]) for kx in range(k) if A[j, kx] != 0] ld.param.u64 c, [_c]; { - .reg .u32 _ctaid_x, _tid_x; - mov.u32 _ctaid_x, %ctaid.x; - mov.u32 _tid_x, %tid.x; - mad.lo.u32 id, _ctaid_x, ${blockx}, _tid_x; + .reg .u32 _ctaid_x, _tid_x; + mov.u32 _ctaid_x, %ctaid.x; + mov.u32 _tid_x, %tid.x; + mad.lo.u32 id, _ctaid_x, ${blockx}, _tid_x; } setp.ge.u32 p1, id, n; @@ -36,59 +26,59 @@ row_nz = [[(kx, A[j, kx]) for kx in range(k) if A[j, kx] != 0] cvta.to.global.u64 c, c; { - .reg .u64 _id64; - cvt.u64.u32 _id64, id; - mad.lo.u64 b_base, _id64, 16, b; - mad.lo.u64 c_base, _id64, 16, c; + .reg .u64 _id64; + cvt.u64.u32 _id64, id; + mad.lo.u64 b_base, _id64, 16, b; + mad.lo.u64 c_base, _id64, 16, c; } ## Batch-load B column pairs % for i, kx in enumerate(bix_list): - ld.global.nc.v2.f64 {bv_a${i}, bv_b${i}}, [b_base + ${ldb*kx*dwidth_i}]; + ld.weak.global.cg.v2.f64 {bv_a${i}, bv_b${i}}, [b_base + ${ldb*kx*dwidth_i}]; % endfor ## Main compute: two parallel dot-product streams per thread % for j in range(m): -% if row_nz[j]: -% for i_nz, (kx, jx) in enumerate(row_nz[j]): -% if i_nz == 0: +% if row_nz[j]: +% for kx, jx in row_nz[j]: +% if loop.first: mul.f64 dotp_a, bv_a${bix_pos[kx]}, ${jx}; mul.f64 dotp_b, bv_b${bix_pos[kx]}, ${jx}; -% else: +% else: fma.rn.f64 dotp_a, bv_a${bix_pos[kx]}, ${jx}, dotp_a; fma.rn.f64 dotp_b, bv_b${bix_pos[kx]}, ${jx}, dotp_b; -% endif -% endfor -% if beta == 0: +% endif +% endfor +% if beta_zero: st.weak.global.cg.v2.f64 [c_base + ${ldc*j*dwidth_i}], {dotp_a, dotp_b}; -% else: +% else: { - .reg .f64 _ca, _cb; - ld.global.v2.f64 {_ca, _cb}, [c_base + ${ldc*j*dwidth_i}]; - fma.rn.f64 _ca, _ca, ${float(beta)}, dotp_a; - fma.rn.f64 _cb, _cb, ${float(beta)}, dotp_b; - st.global.v2.f64 [c_base + ${ldc*j*dwidth_i}], {_ca, _cb}; + .reg .f64 _ca, _cb; + ld.weak.global.cg.v2.f64 {_ca, _cb}, [c_base + ${ldc*j*dwidth_i}]; + fma.rn.f64 _ca, _ca, ${float(beta)}, dotp_a; + fma.rn.f64 _cb, _cb, ${float(beta)}, dotp_b; + st.weak.global.v2.f64 [c_base + ${ldc*j*dwidth_i}], {_ca, _cb}; } -% endif +% endif -% else: +% else: ## Zero row of A -% if beta == 0: +% if beta_zero: { - .reg .f64 _z; - mov.f64 _z, ${fzero}; - st.weak.global.cg.v2.f64 [c_base + ${ldc*j*dwidth_i}], {_z, _z}; + .reg .f64 _z; + mov.f64 _z, ${fzero}; + st.weak.global.cg.v2.f64 [c_base + ${ldc*j*dwidth_i}], {_z, _z}; } -% elif beta != 1: +% elif beta != 1: { - .reg .f64 _ca, _cb; - ld.global.v2.f64 {_ca, _cb}, [c_base + ${ldc*j*dwidth_i}]; - mul.f64 _ca, _ca, ${float(beta)}; - mul.f64 _cb, _cb, ${float(beta)}; - st.global.v2.f64 [c_base + ${ldc*j*dwidth_i}], {_ca, _cb}; + .reg .f64 _ca, _cb; + ld.weak.global.cg.v2.f64 {_ca, _cb}, [c_base + ${ldc*j*dwidth_i}]; + mul.f64 _ca, _ca, ${float(beta)}; + mul.f64 _cb, _cb, ${float(beta)}; + st.weak.global.v2.f64 [c_base + ${ldc*j*dwidth_i}], {_ca, _cb}; } -% endif % endif +% endif % endfor $L_EXIT: diff --git a/gimmik/kernels/ptx/cstream.mako b/gimmik/kernels/ptx/cstream.mako index ec46934..726fe46 100644 --- a/gimmik/kernels/ptx/cstream.mako +++ b/gimmik/kernels/ptx/cstream.mako @@ -1,15 +1,5 @@ <%inherit file='base'/> -<% -pftype = 'f32' if dtype == 'float' else 'f64' -dwidth_i = 4 if dtype == 'float' else 8 -fzero = '0f00000000' if dtype == 'float' else '0d0000000000000000' -bix_list = list(bix) -bix_pos = {kx: i for i, kx in enumerate(bix_list)} -row_nz = [[(kx, A[j, kx]) for kx in range(k) if A[j, kx] != 0] - for j in range(m)] -%> - % if n is None: .visible .entry ${kname}(.param .u32 _n, .param .u64 _b, @@ -39,11 +29,11 @@ row_nz = [[(kx, A[j, kx]) for kx in range(k) if A[j, kx] != 0] ld.param.u64 c, [_c]; { - .reg .u32 _grd<3>; - mov.u32 _grd0, %ntid.x; - mov.u32 _grd1, %ctaid.x; - mov.u32 _grd2, %tid.x; - mad.lo.u32 id, _grd0, _grd1, _grd2; + .reg .u32 _grd<3>; + mov.u32 _grd0, %ntid.x; + mov.u32 _grd1, %ctaid.x; + mov.u32 _grd2, %tid.x; + mad.lo.u32 id, _grd0, _grd1, _grd2; } setp.ge.u32 p1, id, n; @@ -53,103 +43,103 @@ row_nz = [[(kx, A[j, kx]) for kx in range(k) if A[j, kx] != 0] cvta.to.global.u64 c, c; { - .reg .u64 _id64; - cvt.u64.u32 _id64, id; - mad.lo.u64 b_base, _id64, ${dwidth_i}, b; - mad.lo.u64 c_base, _id64, ${dwidth_i}, c; + .reg .u64 _id64; + cvt.u64.u32 _id64, id; + mad.lo.u64 b_base, _id64, ${dwidth_i}, b; + mad.lo.u64 c_base, _id64, ${dwidth_i}, c; } ## Batch-load active B columns % for i, kx in enumerate(bix_list): -% if n is None: +% if n is None: { - .reg .u32 _boff; - .reg .u64 _bptr; - mul.lo.u32 _boff, ldb, ${kx}; - mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; - ld.global.nc.${pftype} bv${i}, [_bptr]; + .reg .u32 _boff; + .reg .u64 _bptr; + mul.lo.u32 _boff, ldb, ${kx}; + mad.wide.u32 _bptr, ${dwidth_i}, _boff, b_base; + ld.weak.global.cg.${pftype} bv${i}, [_bptr]; } -% else: - ld.global.nc.${pftype} bv${i}, [b_base + ${ldb*kx*dwidth_i}]; -% endif +% else: + ld.weak.global.cg.${pftype} bv${i}, [b_base + ${ldb*kx*dwidth_i}]; +% endif % endfor ## Compute and store each output row % for j in range(m): -% if row_nz[j]: -% for i_nz, (kx, jx) in enumerate(row_nz[j]): -% if i_nz == 0: +% if row_nz[j]: +% for kx, jx in row_nz[j]: +% if loop.first: mul.${pftype} dotp, bv${bix_pos[kx]}, ${jx}; -% else: +% else: fma.rn.${pftype} dotp, bv${bix_pos[kx]}, ${jx}, dotp; -% endif -% endfor -% if beta == 0: -% if n is None: +% endif +% endfor +% if beta_zero: +% if n is None: { - .reg .u32 _coff; - .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, ${j}; - mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; - st.weak.global.cg.${pftype} [_cptr], dotp; + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], dotp; } -% else: +% else: st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], dotp; -% endif -% else: +% endif +% else: { - .reg .${pftype} _ctmp; -% if n is None: - .reg .u32 _coff; - .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, ${j}; - mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; - ld.global.${pftype} _ctmp, [_cptr]; - fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, dotp; - st.global.${pftype} [_cptr], _ctmp; -% else: - ld.global.${pftype} _ctmp, [c_base + ${ldc*j*dwidth_i}]; - fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, dotp; - st.global.${pftype} [c_base + ${ldc*j*dwidth_i}], _ctmp; -% endif + .reg .${pftype} _ctmp; +% if n is None: + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.weak.global.cg.${pftype} _ctmp, [_cptr]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, dotp; + st.weak.global.${pftype} [_cptr], _ctmp; +% else: + ld.weak.global.cg.${pftype} _ctmp, [c_base + ${ldc*j*dwidth_i}]; + fma.rn.${pftype} _ctmp, _ctmp, ${float(beta)}, dotp; + st.weak.global.${pftype} [c_base + ${ldc*j*dwidth_i}], _ctmp; +% endif } -% endif +% endif -% else: +% else: ## Zero row of A -% if beta == 0: +% if beta_zero: { - .reg .${pftype} _tmp; - mov.${pftype} _tmp, ${fzero}; -% if n is None: - .reg .u32 _coff; - .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, ${j}; - mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; - st.weak.global.cg.${pftype} [_cptr], _tmp; -% else: - st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], _tmp; -% endif + .reg .${pftype} _tmp; + mov.${pftype} _tmp, ${fzero}; +% if n is None: + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + st.weak.global.cg.${pftype} [_cptr], _tmp; +% else: + st.weak.global.cg.${pftype} [c_base + ${ldc*j*dwidth_i}], _tmp; +% endif } -% elif beta != 1: +% elif beta != 1: { - .reg .${pftype} _tmp; -% if n is None: - .reg .u32 _coff; - .reg .u64 _cptr; - mul.lo.u32 _coff, ldc, ${j}; - mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; - ld.global.${pftype} _tmp, [_cptr]; - mul.${pftype} _tmp, _tmp, ${float(beta)}; - st.global.${pftype} [_cptr], _tmp; -% else: - ld.global.${pftype} _tmp, [c_base + ${ldc*j*dwidth_i}]; - mul.${pftype} _tmp, _tmp, ${float(beta)}; - st.global.${pftype} [c_base + ${ldc*j*dwidth_i}], _tmp; -% endif + .reg .${pftype} _tmp; +% if n is None: + .reg .u32 _coff; + .reg .u64 _cptr; + mul.lo.u32 _coff, ldc, ${j}; + mad.wide.u32 _cptr, ${dwidth_i}, _coff, c_base; + ld.weak.global.cg.${pftype} _tmp, [_cptr]; + mul.${pftype} _tmp, _tmp, ${float(beta)}; + st.weak.global.${pftype} [_cptr], _tmp; +% else: + ld.weak.global.cg.${pftype} _tmp, [c_base + ${ldc*j*dwidth_i}]; + mul.${pftype} _tmp, _tmp, ${float(beta)}; + st.weak.global.${pftype} [c_base + ${ldc*j*dwidth_i}], _tmp; +% endif } -% endif % endif +% endif % endfor $L_EXIT: diff --git a/gimmik/kernels/ptx/dense-mma-gAd.mako b/gimmik/kernels/ptx/dense-mma-gAd.mako index ce8066d..8933e51 100644 --- a/gimmik/kernels/ptx/dense-mma-gAd.mako +++ b/gimmik/kernels/ptx/dense-mma-gAd.mako @@ -1,9 +1,5 @@ <%inherit file='base'/> -<% -fzero = '0d0000000000000000' -%> - .global .align 16 .b64 ${kname}_Ag[${a_elems}] = { ${', '.join(a_u64)} }; @@ -16,14 +12,14 @@ fzero = '0d0000000000000000' .reg .u32 warp_n_base; .reg .u64 ag_thr_base, b_thr_base, c_thr_base; .reg .pred pwarp_exit; - .reg .f64 a_frag; + .reg .${pftype} a_frag; % for nt in range(nn): .reg .u32 b_col_${nt}, c_col0_${nt}, c_col1_${nt}; -% if not n_col_aligned: +% if not n_col_aligned: .reg .pred pvalid_bcol_${nt}, pvalid_c0col_${nt}, pvalid_c1col_${nt}; -% endif - .reg .f64 b_frag_${nt}; - .reg .f64 c0_${nt}_<${m_tiles}>, c1_${nt}_<${m_tiles}>; +% endif + .reg .${pftype} b_frag_${nt}; + .reg .${pftype} c0_${nt}_<${m_tiles}>, c1_${nt}_<${m_tiles}>; % endfor ld.param.u64 b_ptr, [_b]; @@ -57,11 +53,11 @@ fzero = '0d0000000000000000' add.u32 c_col0_${nt}, c_col0_${nt}, t; add.u32 c_col1_${nt}, c_col0_${nt}, 1; } -% if not n_col_aligned: +% if not n_col_aligned: setp.lt.u32 pvalid_bcol_${nt}, b_col_${nt}, ${n}; setp.lt.u32 pvalid_c0col_${nt}, c_col0_${nt}, ${n}; setp.lt.u32 pvalid_c1col_${nt}, c_col1_${nt}, ${n}; -% endif +% endif % endfor // A thread base: &Ag[0] + lane*8 @@ -93,22 +89,22 @@ fzero = '0d0000000000000000' } % for mt in range(m_tiles): -% if pm_runtime(mt): +% if pm_runtime(mt): .reg .pred pm_${mt}; { .reg .u32 crow; add.u32 crow, r_div4, ${mt * 8}; setp.lt.u32 pm_${mt}, crow, ${m}; } -% endif +% endif % endfor % for nt in range(nn): -% for mt in range(m_tiles): -% if beta == 0: - mov.f64 c0_${nt}_${mt}, ${fzero}; - mov.f64 c1_${nt}_${mt}, ${fzero}; -% else: +% for mt in range(m_tiles): +% if beta_zero: + mov.${pftype} c0_${nt}_${mt}, ${fzero}; + mov.${pftype} c1_${nt}_${mt}, ${fzero}; +% else: <% pm = f'pm_{mt}' if pm_runtime(mt) else None pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None @@ -118,56 +114,56 @@ fzero = '0d0000000000000000' { .reg .u64 caddr; add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; -% if needs_zero_init: - mov.f64 c0_${nt}_${mt}, ${fzero}; - mov.f64 c1_${nt}_${mt}, ${fzero}; -% endif - ${pred_emit(f'ld.global.f64 c0_{nt}_{mt}, [caddr];', pm, pvc0, pred_reg=f'p0_{nt}_{mt}')} - ${pred_emit(f'ld.global.f64 c1_{nt}_{mt}, [caddr + 8];', pm, pvc1, pred_reg=f'p1_{nt}_{mt}')} +% if needs_zero_init: + mov.${pftype} c0_${nt}_${mt}, ${fzero}; + mov.${pftype} c1_${nt}_${mt}, ${fzero}; +% endif + ${pred_emit(f'ld.weak.global.cg.{pftype} c0_{nt}_{mt}, [caddr];', pm, pvc0, pred_reg=f'p0_{nt}_{mt}')} + ${pred_emit(f'ld.weak.global.cg.{pftype} c1_{nt}_{mt}, [caddr + {dwidth_i}];', pm, pvc1, pred_reg=f'p1_{nt}_{mt}')} } -% endif -% endfor +% endif +% endfor % endfor % for ki in range(k_iters): -% for nt in range(nn): +% for nt in range(nn): <% pvb = f'pvalid_bcol_{nt}' if not n_col_aligned else None - k_tail = (k_rem != 0 and ki == k_iters - 1) + k_tail = (k_rem != 0 and loop.parent.last) needs_zero = pvb is not None or k_tail pbrow = 'pbrow' if k_tail else None %> { .reg .u64 baddr; add.u64 baddr, b_thr_base, ${ki * b_kiter_stride + nt * b_ntile_stride}; -% if needs_zero: - mov.f64 b_frag_${nt}, ${fzero}; -% endif -% if k_tail: +% if needs_zero: + mov.${pftype} b_frag_${nt}, ${fzero}; +% endif +% if k_tail: .reg .pred pbrow; { .reg .u32 brow; add.u32 brow, r_mod4, ${ki * 4}; setp.lt.u32 pbrow, brow, ${k}; } -% endif - ${pred_emit(f'ld.global.nc.f64 b_frag_{nt}, [baddr];', pbrow, pvb, pred_reg=f'pb_{ki}_{nt}')} +% endif + ${pred_emit(f'ld.weak.global.cg.{pftype} b_frag_{nt}, [baddr];', pbrow, pvb, pred_reg=f'pb_{ki}_{nt}')} } -% endfor -% for mt in range(m_tiles): - ld.global.nc.f64 a_frag, [ag_thr_base + ${(mt * k_iters + ki) * frag_stride_bytes}]; -% for nt in range(nn): - mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 +% endfor +% for mt in range(m_tiles): + ld.weak.global.${pftype} a_frag, [ag_thr_base + ${(mt * k_iters + ki) * frag_stride_bytes}]; +% for nt in range(nn): + mma.sync.aligned.m8n8k4.row.col.${pftype}.${pftype}.${pftype}.${pftype} {c0_${nt}_${mt}, c1_${nt}_${mt}}, {a_frag}, {b_frag_${nt}}, {c0_${nt}_${mt}, c1_${nt}_${mt}}; -% endfor -% endfor +% endfor +% endfor % endfor % for nt in range(nn): -% for mt in range(m_tiles): +% for mt in range(m_tiles): <% pm = f'pm_{mt}' if pm_runtime(mt) else None pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None @@ -176,10 +172,10 @@ fzero = '0d0000000000000000' { .reg .u64 caddr; add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; - ${pred_emit(f'st.global.f64 [caddr], c0_{nt}_{mt};', pm, pvc0, pred_reg=f'p0s_{nt}_{mt}')} - ${pred_emit(f'st.global.f64 [caddr + 8], c1_{nt}_{mt};', pm, pvc1, pred_reg=f'p1s_{nt}_{mt}')} + ${pred_emit(f'st.weak.global.{pftype} [caddr], c0_{nt}_{mt};', pm, pvc0, pred_reg=f'p0s_{nt}_{mt}')} + ${pred_emit(f'st.weak.global.{pftype} [caddr + {dwidth_i}], c1_{nt}_{mt};', pm, pvc1, pred_reg=f'p1s_{nt}_{mt}')} } -% endfor +% endfor % endfor $L_EXIT: diff --git a/gimmik/kernels/ptx/dense-mma-smem-gA.mako b/gimmik/kernels/ptx/dense-mma-smem-gA.mako index ec2f013..d1b72a8 100644 --- a/gimmik/kernels/ptx/dense-mma-smem-gA.mako +++ b/gimmik/kernels/ptx/dense-mma-smem-gA.mako @@ -26,7 +26,7 @@ bs = bool(block_stealing) .reg .u32 warp_n_base; .reg .u64 as_thr_base, b_thr_base, c_thr_base; .reg .pred pwarp_exit; - .reg .f64 a_frag; + .reg .${pftype} a_frag; % if bs: .reg .u32 ctaid; .reg .u32 mbar_a, work_a; @@ -34,11 +34,11 @@ bs = bool(block_stealing) % endif % for nt in range(nn): .reg .u32 b_col_${nt}, c_col0_${nt}, c_col1_${nt}; -% if not n_col_aligned: +% if not n_col_aligned: .reg .pred pvalid_bcol_${nt}, pvalid_c0col_${nt}, pvalid_c1col_${nt}; -% endif - .reg .f64 b_frag_${nt}; - .reg .f64 c0_${nt}_<${m_tiles}>, c1_${nt}_<${m_tiles}>; +% endif + .reg .${pftype} b_frag_${nt}; + .reg .${pftype} c0_${nt}_<${m_tiles}>, c1_${nt}_<${m_tiles}>; % endfor ld.param.u64 b_ptr, [_b]; @@ -69,30 +69,29 @@ bs = bool(block_stealing) % for ci in range(copy_v2_iters): <% base_pair = ci * blockx - is_last = ci == copy_v2_iters - 1 pairs_this = min(blockx, a_pairs - base_pair) %> { .reg .u32 pidx; .reg .u64 off64, gaddr, saddr; - .reg .f64 v0, v1; -% if is_last and pairs_this < blockx: + .reg .${pftype} v0, v1; +% if loop.last and pairs_this < blockx: .reg .pred plast; add.u32 pidx, tid, ${base_pair}; setp.lt.u32 plast, pidx, ${a_pairs}; - mul.wide.u32 off64, pidx, 16; + mul.wide.u32 off64, pidx, ${2 * dwidth_i}; add.u64 gaddr, a_glb_base, off64; add.u64 saddr, a_smem_base, off64; - @plast ld.global.nc.v2.f64 {v0, v1}, [gaddr]; - @plast st.shared.v2.f64 [saddr], {v0, v1}; -% else: + @plast ld.weak.global.cg.v2.${pftype} {v0, v1}, [gaddr]; + @plast st.shared.v2.${pftype} [saddr], {v0, v1}; +% else: add.u32 pidx, tid, ${base_pair}; - mul.wide.u32 off64, pidx, 16; + mul.wide.u32 off64, pidx, ${2 * dwidth_i}; add.u64 gaddr, a_glb_base, off64; add.u64 saddr, a_smem_base, off64; - ld.global.nc.v2.f64 {v0, v1}, [gaddr]; - st.shared.v2.f64 [saddr], {v0, v1}; -% endif + ld.weak.global.cg.v2.${pftype} {v0, v1}, [gaddr]; + st.shared.v2.${pftype} [saddr], {v0, v1}; +% endif } % endfor % if a_pairs_tail: @@ -100,12 +99,12 @@ bs = bool(block_stealing) { .reg .pred plast; .reg .u64 gaddr, saddr; - .reg .f64 v; + .reg .${pftype} v; setp.eq.u32 plast, tid, 0; - add.u64 gaddr, a_glb_base, ${(a_elems-1) * 8}; - add.u64 saddr, a_smem_base, ${(a_elems-1) * 8}; - @plast ld.global.nc.f64 v, [gaddr]; - @plast st.shared.f64 [saddr], v; + add.u64 gaddr, a_glb_base, ${(a_elems-1) * dwidth_i}; + add.u64 saddr, a_smem_base, ${(a_elems-1) * dwidth_i}; + @plast ld.weak.global.cg.${pftype} v, [gaddr]; + @plast st.shared.${pftype} [saddr], v; } % endif } @@ -121,14 +120,14 @@ bs = bool(block_stealing) } % for mt in range(m_tiles): -% if pm_runtime(mt): +% if pm_runtime(mt): .reg .pred pm_${mt}; { .reg .u32 crow; add.u32 crow, r_div4, ${mt * 8}; setp.lt.u32 pm_${mt}, crow, ${m}; } -% endif +% endif % endfor % if bs: @@ -164,11 +163,11 @@ $L_LOOP: add.u32 c_col0_${nt}, c_col0_${nt}, t; add.u32 c_col1_${nt}, c_col0_${nt}, 1; } -% if not n_col_aligned: +% if not n_col_aligned: setp.lt.u32 pvalid_bcol_${nt}, b_col_${nt}, ${n}; setp.lt.u32 pvalid_c0col_${nt}, c_col0_${nt}, ${n}; setp.lt.u32 pvalid_c1col_${nt}, c_col1_${nt}, ${n}; -% endif +% endif % endfor { @@ -190,11 +189,11 @@ $L_LOOP: } % for nt in range(nn): -% for mt in range(m_tiles): -% if beta == 0: - mov.f64 c0_${nt}_${mt}, 0d0000000000000000; - mov.f64 c1_${nt}_${mt}, 0d0000000000000000; -% else: +% for mt in range(m_tiles): +% if beta_zero: + mov.${pftype} c0_${nt}_${mt}, ${fzero}; + mov.${pftype} c1_${nt}_${mt}, ${fzero}; +% else: <% pm = f'pm_{mt}' if pm_runtime(mt) else None pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None @@ -204,56 +203,56 @@ $L_LOOP: { .reg .u64 caddr; add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; -% if needs_zero_init: - mov.f64 c0_${nt}_${mt}, 0d0000000000000000; - mov.f64 c1_${nt}_${mt}, 0d0000000000000000; -% endif - ${pred_emit(f'ld.global.f64 c0_{nt}_{mt}, [caddr];', pm, pvc0, pred_reg=f'p0_{nt}_{mt}')} - ${pred_emit(f'ld.global.f64 c1_{nt}_{mt}, [caddr + 8];', pm, pvc1, pred_reg=f'p1_{nt}_{mt}')} +% if needs_zero_init: + mov.${pftype} c0_${nt}_${mt}, ${fzero}; + mov.${pftype} c1_${nt}_${mt}, ${fzero}; +% endif + ${pred_emit(f'ld.weak.global.cg.{pftype} c0_{nt}_{mt}, [caddr];', pm, pvc0, pred_reg=f'p0_{nt}_{mt}')} + ${pred_emit(f'ld.weak.global.cg.{pftype} c1_{nt}_{mt}, [caddr + {dwidth_i}];', pm, pvc1, pred_reg=f'p1_{nt}_{mt}')} } -% endif -% endfor +% endif +% endfor % endfor % for ki in range(k_iters): -% for nt in range(nn): +% for nt in range(nn): <% pvb = f'pvalid_bcol_{nt}' if not n_col_aligned else None - k_tail = (k_rem != 0 and ki == k_iters - 1) + k_tail = (k_rem != 0 and loop.parent.last) needs_zero = pvb is not None or k_tail pbrow = 'pbrow' if k_tail else None %> { .reg .u64 baddr; add.u64 baddr, b_thr_base, ${ki * b_kiter_stride + nt * b_ntile_stride}; -% if needs_zero: - mov.f64 b_frag_${nt}, 0d0000000000000000; -% endif -% if k_tail: +% if needs_zero: + mov.${pftype} b_frag_${nt}, ${fzero}; +% endif +% if k_tail: .reg .pred pbrow; { .reg .u32 brow; add.u32 brow, r_mod4, ${ki * 4}; setp.lt.u32 pbrow, brow, ${k}; } -% endif - ${pred_emit(f'ld.global.nc.f64 b_frag_{nt}, [baddr];', pbrow, pvb, pred_reg=f'pb_{ki}_{nt}')} +% endif + ${pred_emit(f'ld.weak.global.cg.{pftype} b_frag_{nt}, [baddr];', pbrow, pvb, pred_reg=f'pb_{ki}_{nt}')} } -% endfor -% for mt in range(m_tiles): - ld.shared.f64 a_frag, [as_thr_base + ${(mt * k_iters + ki) * frag_stride_bytes}]; -% for nt in range(nn): - mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 +% endfor +% for mt in range(m_tiles): + ld.shared.${pftype} a_frag, [as_thr_base + ${(mt * k_iters + ki) * frag_stride_bytes}]; +% for nt in range(nn): + mma.sync.aligned.m8n8k4.row.col.${pftype}.${pftype}.${pftype}.${pftype} {c0_${nt}_${mt}, c1_${nt}_${mt}}, {a_frag}, {b_frag_${nt}}, {c0_${nt}_${mt}, c1_${nt}_${mt}}; -% endfor -% endfor +% endfor +% endfor % endfor % for nt in range(nn): -% for mt in range(m_tiles): +% for mt in range(m_tiles): <% pm = f'pm_{mt}' if pm_runtime(mt) else None pvc0 = f'pvalid_c0col_{nt}' if not n_col_aligned else None @@ -262,10 +261,10 @@ $L_LOOP: { .reg .u64 caddr; add.u64 caddr, c_thr_base, ${mt * c_mtile_stride + nt * c_ntile_stride}; - ${pred_emit(f'st.global.f64 [caddr], c0_{nt}_{mt};', pm, pvc0, pred_reg=f'p0s_{nt}_{mt}')} - ${pred_emit(f'st.global.f64 [caddr + 8], c1_{nt}_{mt};', pm, pvc1, pred_reg=f'p1s_{nt}_{mt}')} + ${pred_emit(f'st.weak.global.{pftype} [caddr], c0_{nt}_{mt};', pm, pvc0, pred_reg=f'p0s_{nt}_{mt}')} + ${pred_emit(f'st.weak.global.{pftype} [caddr + {dwidth_i}], c1_{nt}_{mt};', pm, pvc1, pred_reg=f'p1s_{nt}_{mt}')} } -% endfor +% endfor % endfor % if bs: diff --git a/gimmik/kernels/ptx/dense-mma-ws.mako b/gimmik/kernels/ptx/dense-mma-ws.mako index e4b576a..a7c9f88 100644 --- a/gimmik/kernels/ptx/dense-mma-ws.mako +++ b/gimmik/kernels/ptx/dense-mma-ws.mako @@ -1,8 +1,4 @@ <%inherit file='base'/> -<% -mbar_maxwait = '0x989680' -direct_store = (beta == 0) -%> <%def name="producer_init_setup()"> // Producer warp: initial A bulk-copy + B load for ctaid_x's work @@ -67,17 +63,17 @@ $L_WAIT_BRDY: } % endfor -% if direct_store: - // direct_store: skip shared-staging entirely; compute warps store - // MMA outputs straight to global C with N-tail predication. +% if beta_zero: + // beta=0: skip shared-staging entirely; compute warps store MMA + // outputs straight to global C with N-tail predication. .reg .u64 c_glob_addr; - ld.param.u64 c_glob_addr, [_c]; + ld.param.u64 c_glob_addr, [c_desc]; cvta.to.global.u64 c_glob_addr, c_glob_addr; % else: .reg .b32 c_thr_smem; { .reg .b32 t1, ccol_b; - mul.lo.u32 t1, base_crow, ${n_per_cta * 8}; + mul.lo.u32 t1, base_crow, ${n_per_cta * dwidth_i}; shl.b32 ccol_b, base_ccol, 3; add.u32 c_thr_smem, c_smem, t1; add.u32 c_thr_smem, c_thr_smem, ccol_b; @@ -86,91 +82,91 @@ $L_WAIT_BRDY: // Zero accumulators % for mt in range(m_tiles): -% for nt in range(nn): - .reg .f64 d_x_${mt}_${nt}, d_y_${mt}_${nt}; - mov.f64 d_x_${mt}_${nt}, 0d0000000000000000; - mov.f64 d_y_${mt}_${nt}, 0d0000000000000000; -% endfor +% for nt in range(nn): + .reg .${pftype} d_x_${mt}_${nt}, d_y_${mt}_${nt}; + mov.${pftype} d_x_${mt}_${nt}, ${fzero}; + mov.${pftype} d_y_${mt}_${nt}, ${fzero}; +% endfor % endfor - .reg .f64 a_f; + .reg .${pftype} a_f; % for mt in range(m_tiles): -% for kt in range(k_iters): +% for kt in range(k_iters): <% - k_tail = (k_rem != 0 and kt == k_iters - 1) + k_tail = (k_rem != 0 and loop.last) %> { .reg .b32 a_a; - add.u32 a_a, a_thr_a, ${(kt * 32 + mt * 32 * k_iters) * 8}; - ld.shared.f64 a_f, [a_a]; -% if k_tail: + add.u32 a_a, a_thr_a, ${(kt * 32 + mt * 32 * k_iters) * dwidth_i}; + ld.shared.${pftype} a_f, [a_a]; +% if k_tail: .reg .pred pbrow_${mt}_${kt}; { .reg .b32 brow; add.u32 brow, base_brow, ${4 * kt}; setp.lt.u32 pbrow_${mt}_${kt}, brow, ${k}; } -% endif -% for nt in range(nn): +% endif +% for nt in range(nn): { .reg .b32 b_a, b_row; - .reg .f64 b_f; + .reg .${pftype} b_f; add.u32 b_row, base_brow, ${4 * kt}; - mul.lo.u32 b_row, b_row, ${n_per_cta * 8}; + mul.lo.u32 b_row, b_row, ${n_per_cta * dwidth_i}; add.u32 b_a, b_thr_a_${nt}, b_row; -% if k_tail: - mov.f64 b_f, 0d0000000000000000; - @pbrow_${mt}_${kt} ld.shared.f64 b_f, [b_a]; -% else: - ld.shared.f64 b_f, [b_a]; -% endif - mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 +% if k_tail: + mov.${pftype} b_f, ${fzero}; + @pbrow_${mt}_${kt} ld.shared.${pftype} b_f, [b_a]; +% else: + ld.shared.${pftype} b_f, [b_a]; +% endif + mma.sync.aligned.m8n8k4.row.col.${pftype}.${pftype}.${pftype}.${pftype} {d_x_${mt}_${nt}, d_y_${mt}_${nt}}, {a_f}, {b_f}, {d_x_${mt}_${nt}, d_y_${mt}_${nt}}; } -% endfor +% endfor } -% endfor +% endfor % endfor -% if direct_store: +% if beta_zero: .reg .u64 c_thr_glob_base; { .reg .u32 thr_col_off, thr_addr_off_lo; add.u32 thr_col_off, base_ccol, n_start_curr; mad.lo.u32 thr_addr_off_lo, base_crow, ${ldc}, thr_col_off; .reg .u64 thr_byte_off; - mul.wide.u32 thr_byte_off, thr_addr_off_lo, 8; + mul.wide.u32 thr_byte_off, thr_addr_off_lo, ${dwidth_i}; add.u64 c_thr_glob_base, c_glob_addr, thr_byte_off; } -% for mt in range(m_tiles): +% for mt in range(m_tiles): <% row_tail = (m_pad > m) and ((mt + 1) * 8 > m) %> -% if row_tail: +% if row_tail: .reg .pred p_row_${mt}; { .reg .b32 crow; add.u32 crow, base_crow, ${8 * mt}; setp.lt.u32 p_row_${mt}, crow, ${m}; } -% endif -% for nt in range(nn): +% endif +% for nt in range(nn): { .reg .pred p_st; .reg .u32 g_ccol; add.u32 g_ccol, base_ccol, ${8 * nt}; add.u32 g_ccol, g_ccol, n_start_curr; setp.lt.u32 p_st, g_ccol, ${n}; -% if row_tail: +% if row_tail: and.pred p_st, p_st, p_row_${mt}; -% endif - .reg .u64 c_addr; - add.u64 c_addr, c_thr_glob_base, ${(mt * 8 * ldc + nt * 8) * 8}; - @p_st st.global.v2.f64 [c_addr], {d_x_${mt}_${nt}, d_y_${mt}_${nt}}; +% endif + .reg .u64 _c_addr; + add.u64 _c_addr, c_thr_glob_base, ${(mt * 8 * ldc + nt * 8) * dwidth_i}; + @p_st st.weak.global.v2.${pftype} [_c_addr], {d_x_${mt}_${nt}, d_y_${mt}_${nt}}; } -% endfor -% endfor +% endfor +% endfor % else: // Wait until producer's prev-iter TMA-store of C has drained. { @@ -182,18 +178,18 @@ $L_WAIT_CSTORE: // Vector-store {d_x, d_y} pairs to csmem. M-tail / N-tail OOB rows // are dropped by the C tensor map. -% for mt in range(m_tiles): -% for nt in range(nn): +% for mt in range(m_tiles): +% for nt in range(nn): { .reg .b32 csaddr; add.u32 csaddr, c_thr_smem, ${mt * c_mtile_smem_stride + nt * c_ntile_smem_stride}; - st.shared.v2.f64 [csaddr], {d_x_${mt}_${nt}, d_y_${mt}_${nt}}; + st.shared.v2.${pftype} [csaddr], {d_x_${mt}_${nt}, d_y_${mt}_${nt}}; } -% endfor -% endfor +% endfor +% endfor % endif -% if not direct_store: +% if not beta_zero: bar.sync 1, ${comp_threads}; fence.proxy.async.shared::cta; { @@ -260,7 +256,7 @@ $L_WAIT_WNEW_D: } bar.warp.sync 0xffffffff; -% if not direct_store: +% if not beta_zero: // TMA reduce+store of C (beta=1 only; beta=0 uses direct global // stores from compute warps, so the producer does no C work). { @@ -329,11 +325,9 @@ $L_AFTER_CTRL: ${', '.join(a_u64)} }; .extern .shared .align 128 .b8 ${kname}_dynm[]; -.const .align 64 .b8 ${kname}_bdesc[128]; -.const .align 64 .b8 ${kname}_cdesc[128]; -.visible .entry ${kname}(.param .u64 _b, - .param .u64 _c) +.visible .entry ${kname}(.param .u64 b_desc, + .param .u64 c_desc) .maxntid ${blockx_total}, 1, 1 { .reg .b32 tid, warp, lane, phase, ctaid_x; @@ -369,8 +363,8 @@ $L_AFTER_CTRL: add.u32 wid_new_mbar, dynm_base, ${wid_new_mbar_off}; add.u32 wid_used_mbar, dynm_base, ${wid_used_mbar_off}; - cvta.const.u64 bdesc_addr, ${kname}_bdesc; - cvta.const.u64 cdesc_addr, ${kname}_cdesc; + ld.param.u64 bdesc_addr, [b_desc]; + ld.param.u64 cdesc_addr, [c_desc]; setp.eq.u32 p_tid0, tid, 0; @@ -401,7 +395,7 @@ $L_AFTER_CTRL: } bar.sync 0; - // Compute-warp lane geometry (cheap; all warps execute uniformly) + // Compute-warp lane geometry { .reg .b32 t, w_n_base; and.b32 base_brow, lane, 3; diff --git a/gimmik/ptx.py b/gimmik/ptx.py index ad429e8..1f46384 100644 --- a/gimmik/ptx.py +++ b/gimmik/ptx.py @@ -12,35 +12,53 @@ class PTXMatMul(MatMul): basemeta = {'block': (128, 1, 1), 'width': 1, 'shared': 0, 'dynamic_shared': 0} - @staticmethod - def is_sparse_suitable(arr): + DENSE_SMEM_MAX = 200*1024 + PTX_SM = {(8, 0), (9, 0), (10, 0), (10, 3), (12, 0), (12, 1)} + + @classmethod + def is_sparse_suitable(cls, arr, cc): nnz = int(np.count_nonzero(arr)) nuq = int(len(np.unique(np.abs(arr)))) density = nnz / arr.size - return (nuq <= 28) or (density <= 0.15) + return ((nuq <= 28) or (density <= 0.15)) and cc in cls.PTX_SM - # Shape/arch gate for dense DMMA; n/ldb/ldc are validated at generate time - @staticmethod - def is_dense_suitable(arr, dtype, cc): - return (np.dtype(dtype) == np.float64 - and cc is not None and cc >= (9, 0) + @classmethod + def is_dense_suitable(cls, arr, cc): + cc_appropriate = cc in cls.PTX_SM and cc >= (9, 0) + return (arr.dtype == np.float64 and cc_appropriate and arr.shape[0] <= 128 and arr.shape[1] <= 128) @classmethod - def is_suitable(cls, arr, dtype, cc): - return (cls.is_sparse_suitable(arr) - or cls.is_dense_suitable(arr, dtype, cc)) + def is_suitable(cls, arr, cc): + return cls.is_sparse_suitable(arr, cc) or cls.is_dense_suitable(arr, cc) def _kernel_generators(self, dtype, dsize, *, compute_capability=None): - base_args = {'cc': compute_capability, - 'pred_emit': self._pred_emit} - - yield from self._sparse_kernel_generators(dtype, dsize, base_args) - yield from self._dense_kernel_generators(dtype, dsize, base_args) + cc = compute_capability or (0, 0) + base_args = {'cc': cc, + 'pred_emit': self._pred_emit, + 'pftype': 'f32' if dtype == 'float' else 'f64', + 'dwidth_i': 4 if dtype == 'float' else 8, + 'fzero': ('0f00000000' if dtype == 'float' + else '0d0000000000000000'), + 'beta_zero': self.beta == 0, + 'mbar_maxwait': '0x989680' + } + + if self.is_sparse_suitable(self.A, cc): + yield from self._sparse_kernel_generators(dtype, dsize, base_args) + + if self.is_dense_suitable(self.A, cc): + yield from self._dense_kernel_generators(dtype, dsize, base_args) def _sparse_kernel_generators(self, dtype, dsize, base_args): - if not self.is_sparse_suitable(self.A): - return + # Sparse-shared template constants + base_args = base_args | { + 'bix_list': list(self.bix), + 'bix_pos': self.bix, + 'has_zero_rows': bool(self.has_zero_rows), + 'row_nz': [[(kx, self.A[j, kx]) for kx in range(self.k) + if self.A[j, kx] != 0] for j in range(self.m)], + } # B loading, C streaming kernel yield ('cstream', base_args, {'desc': 'cstream'}) @@ -93,145 +111,124 @@ def _sparse_kernel_generators(self, dtype, dsize, base_args): def _dense_kernel_generators(self, dtype, dsize, base_args): cc = base_args['cc'] or (0, 0) - if not (self.is_dense_suitable(self.A, dtype, cc) - and self.n is not None): - return - # Some kernels can optional steal blocks - bs_default = cc >= (10, 0) - - if cc >= (10, 0): - # Warp specialised is uniformly better on sm_100+, so no need to JIT - # other versions + # Block stealing requires sm_100+ + block_steal = cc >= (10, 0) + if block_steal: dense_configs = [('dense-mma-smem-gA', 4, 4)] else: dense_configs = [ ('dense-mma-smem-gA', 1, 8), ('dense-mma-smem-gA', 2, 4), ('dense-mma-smem-gA', 4, 4), - ('dense-mma-gAd', 2, 2), - ('dense-mma-gAd', 4, 2), + ('dense-mma-gAd', 2, 2), + ('dense-mma-gAd', 4, 2), ] for tpl, nn, w in dense_configs: blkx = 32 * w - n_per_cta = 8 * nn * w - if n_per_cta > self.n: + if (n_per_cta := 8 * nn * w) > self.n: continue - bs = (tpl == 'dense-mma-smem-gA') and bs_default setup = self._dense_mma_setup(nn=nn, warps_per_cta=w) args = (base_args | {'warps_per_cta': w, 'nn': nn, - 'block_stealing': bs} | setup) + 'block_stealing': block_steal} | setup) meta = { 'block': (blkx, 1, 1), 'grid': (-(-self.n // n_per_cta), 1, 1), - 'desc': f'{tpl}/nn{nn}-w{w}{"-bs" if bs else ""}', + 'desc': f'{tpl}/nn{nn}-w{w}{'-bs' if block_steal else ''}', } yield (tpl, args, meta) - # Warp-specialised dense DMMA - if cc >= (10, 0): + # Warp-specialised dense DMMA, required block stealing + if block_steal: yield from self._dense_ws_kernel_generators(dtype, dsize, base_args) def _dense_ws_kernel_generators(self, dtype, dsize, base_args): - m_pad = -(-self.m // 8) * 8 - k_pad = -(-self.k // 4) * 4 - - # (nn, w_compute) -- block has w_compute + 2 warps (producer, stealer) + # (nn, compute) -- block has compute + 2 warps (producer, stealer) ws_configs = [(1, 4), (2, 4), (4, 4)] for nn, w in ws_configs: - n_per_cta = 8 * nn * w - if n_per_cta > self.n: + if (n_per_cta := 8 * nn * w) > self.n: continue - blkx = 32 * (w + 2) + setup = self._dense_mma_setup(nn=nn, warps_per_cta=w) - ws_layout = self._dense_ws_layout( - n_comp_warps=w, n_per_cta=n_per_cta, - m_pad=m_pad, k_pad=k_pad, a_elems=setup['a_elems'] - ) + ws_setup = self._dense_ws_setup(setup, n_comp_warps=w) - if ws_layout['dynm_total_bytes'] > 200 * 1024: + if ws_setup['dynm_total_bytes'] > self.DENSE_SMEM_MAX: continue - args = (base_args - | {'warps_per_cta': w, 'nn': nn} - | setup | ws_layout) + args = base_args | {'nn': nn} | setup | ws_setup meta = { - 'block': (blkx, 1, 1), + 'block': (32 * (w + 2), 1, 1), 'grid': (-(-self.n // n_per_cta), 1, 1), 'desc': f'dense-mma-ws/nn{nn}-w{w}', - 'ws_tensor_map': True, - 'ws_n_per_cta': n_per_cta, - 'ws_k_pad': k_pad, - 'ws_m_pad': m_pad, - 'dynamic_shared': ws_layout['dynm_total_bytes'], + 'ws_b_tile': (n_per_cta, setup['k_pad']), + 'dynamic_shared': ws_setup['dynm_total_bytes'], } + if self.beta != 0: + meta |= {'ws_out_tile': (n_per_cta, setup['m_pad'])} yield ('dense-mma-ws', args, meta) @staticmethod - def _dense_ws_layout(*, n_comp_warps, n_per_cta, m_pad, k_pad, a_elems): - n_total_warps = n_comp_warps + 2 - blockx_total = 32 * n_total_warps - - b_tile_bytes = k_pad * n_per_cta * 8 - c_tile_bytes = m_pad * n_per_cta * 8 - a_bytes = a_elems * 8 - - smem_size = {'b1': b_tile_bytes, 'b2': b_tile_bytes, 'c': c_tile_bytes, - 'a': a_bytes, 'wid': 16} - smem_off, off = {}, 0 - for k, v in smem_size.items(): - off = (off + 15) & ~15 - smem_off[f'{k}_off'] = off - off += v - - mbar_names = ('tma', 'bready', 'cready', 'cstored', - 'steal', 'wid_new', 'wid_used') - for k in mbar_names: - smem_off[f'{k}_mbar_off'] = off + def _dsmem_alloc(regions, mbars, align=16): + out, off = {}, 0 + for name, size in regions: + off = (off + align - 1) & ~(align - 1) + out[f'{name}_off'] = off + off += size + for name in mbars: + out[f'{name}_mbar_off'] = off off += 8 + total = (off + align - 1) & ~(align - 1) + return out, total - # Pad total to 16-byte multiple - dynm_total_bytes = (off + 15) & ~15 - - params = {'n_comp_warps': n_comp_warps, - 'blockx_total': blockx_total, - 'prod_warp': n_comp_warps, - 'steal_warp': n_comp_warps + 1, - 'comp_threads': 32 * n_comp_warps, - 'm_pad': m_pad, - 'k_pad': k_pad, - 'b_tile_doubles': k_pad * n_per_cta, - 'b_tile_bytes': b_tile_bytes, - 'c_tile_doubles': m_pad * n_per_cta, - 'c_mtile_smem_stride': 8 * n_per_cta * 8, - 'c_ntile_smem_stride': 8 * 8, - 'dynm_total_bytes': dynm_total_bytes, - } - params |= smem_off - return params + @classmethod + def _dense_ws_setup(cls, setup, *, n_comp_warps): + n_per_cta = setup['n_per_cta'] + b_tile_bytes = setup['k_pad'] * n_per_cta * 8 + c_tile_bytes = setup['m_pad'] * n_per_cta * 8 + a_bytes = setup['a_elems'] * 8 + + regions = [('b1', b_tile_bytes), ('b2', b_tile_bytes), + ('c', c_tile_bytes), ('a', a_bytes), ('wid', 16)] + mbars = ('tma', 'bready', 'cready', 'cstored', + 'steal', 'wid_new', 'wid_used') + offsets, dynm_total_bytes = cls._dsmem_alloc(regions, mbars) + + return offsets | { + 'n_comp_warps': n_comp_warps, + 'blockx_total': 32 * (n_comp_warps + 2), + 'prod_warp': n_comp_warps, + 'steal_warp': n_comp_warps + 1, + 'comp_threads': 32 * n_comp_warps, + 'b_tile_bytes': b_tile_bytes, + 'c_mtile_smem_stride': 8 * n_per_cta * 8, + 'c_ntile_smem_stride': 8 * 8, + 'dynm_total_bytes': dynm_total_bytes, + } def _dense_mma_setup(self, *, nn, warps_per_cta): a = self.A m, k = a.shape m_tiles = -(-m // 8) - k_rem = k % 4 - k_iters = (k + (4 - k_rem if k_rem else 0)) // 4 + k_iters = -(-k // 4) + k_rem = k % 4 - # A in fragment layout: lane l -> A[m_tile*8 + l/4][k_iter*4 + l%4] + # A in fragment layout: lane l -> A[mt*8 + l//4][kt*4 + l%4] + # DMMA tiles are 8x8x4 so this loads a 8x4 tile, flattens it and + # and packs it as uint64 as a more robust way of storing the values in + # the template a_u64 = [] - for m_tile in range(m_tiles): - for k_iter in range(k_iters): + for mt in range(m_tiles): + for kt in range(k_iters): for lane in range(32): - i = m_tile * 8 + lane // 4 - j = k_iter * 4 + lane % 4 + i = mt * 8 + lane // 4 + j = kt * 4 + lane % 4 v = float(a[i, j]) if (i < m and j < k) else 0.0 - u = struct.unpack(' Date: Thu, 21 May 2026 06:26:38 -0700 Subject: [PATCH 08/11] General cleanups and moved smem to pyfr --- gimmik/cuda.py | 7 +++ gimmik/kernels/ptx/bstream-msplit.mako | 2 +- gimmik/kernels/ptx/bstream.mako | 10 ++--- gimmik/kernels/ptx/cstream-w2.mako | 12 +++--- gimmik/kernels/ptx/cstream.mako | 8 ++-- gimmik/kernels/ptx/dense-mma-gAd.mako | 4 +- gimmik/kernels/ptx/dense-mma-smem-gA.mako | 4 +- gimmik/kernels/ptx/dense-mma-ws.mako | 4 +- gimmik/ptx.py | 52 +++++++++++------------ 9 files changed, 53 insertions(+), 50 deletions(-) diff --git a/gimmik/cuda.py b/gimmik/cuda.py index b18c509..e40179c 100644 --- a/gimmik/cuda.py +++ b/gimmik/cuda.py @@ -8,6 +8,13 @@ class CUDAMatMul(MatMul): basemeta = {'block': (128, 1, 1), 'width': 1, 'shared': 0, 'dynamic_shared': 0} + @staticmethod + def is_suitable(arr): + nnz = np.count_nonzero(arr) + nuq = len(np.unique(np.abs(arr))) + density = nnz / arr.size + return (nuq <= 28) or (density <= 0.15) + def _kernel_generators(self, dtype, dsize, *, compute_capability=None): # B loading, C streaming kernel yield ('cstream', {}, {}) diff --git a/gimmik/kernels/ptx/bstream-msplit.mako b/gimmik/kernels/ptx/bstream-msplit.mako index 530b19f..2ef85e9 100644 --- a/gimmik/kernels/ptx/bstream-msplit.mako +++ b/gimmik/kernels/ptx/bstream-msplit.mako @@ -2,7 +2,7 @@ <% mx = partition(A, into=msplit, by='rows') -bchunks = chunk(bix_list, bsz) +bchunks = chunk(bix, bsz) m_per_group = max(len(mcx) for mcx in mx) bsub_bytes = 2 * bsz * blockx * dwidth_i def bsub_off(buf, idx): diff --git a/gimmik/kernels/ptx/bstream.mako b/gimmik/kernels/ptx/bstream.mako index 24b0acb..45eb1a7 100644 --- a/gimmik/kernels/ptx/bstream.mako +++ b/gimmik/kernels/ptx/bstream.mako @@ -17,7 +17,7 @@ % endif .reg .u32 n, id; .reg .u64 b, c, b_base, c_base; - .reg .${pftype} csub<${m}>, bv<${len(bix_list)}>; + .reg .${pftype} csub<${m}>, bv<${len(bix)}>; .reg .pred p1; % if n is None: @@ -50,7 +50,7 @@ } ## Batch-load active B columns -% for i, kx in enumerate(bix_list): +% for i, kx in enumerate(bix): % if n is None: { .reg .u32 _boff; @@ -89,13 +89,13 @@ % endif ## Main compute -% for kx in bix_list: +% for kx in bix: % for j, jx in enumerate(A[:, kx]): % if jx != 0: % if beta_zero and kx == afix[j]: - mul.${pftype} csub${j}, bv${bix_pos[kx]}, ${jx}; + mul.${pftype} csub${j}, bv${bix[kx]}, ${jx}; % else: - fma.rn.${pftype} csub${j}, bv${bix_pos[kx]}, ${jx}, csub${j}; + fma.rn.${pftype} csub${j}, bv${bix[kx]}, ${jx}, csub${j}; % endif % endif % if kx == alix[j]: diff --git a/gimmik/kernels/ptx/cstream-w2.mako b/gimmik/kernels/ptx/cstream-w2.mako index e6b4d75..ce7301d 100644 --- a/gimmik/kernels/ptx/cstream-w2.mako +++ b/gimmik/kernels/ptx/cstream-w2.mako @@ -5,7 +5,7 @@ { .reg .u32 n, id; .reg .u64 b, c, b_base, c_base; - .reg .f64 bv_a<${len(bix_list)}>, bv_b<${len(bix_list)}>, dotp_a, dotp_b; + .reg .f64 bv_a<${len(bix)}>, bv_b<${len(bix)}>, dotp_a, dotp_b; .reg .pred p1; mov.u32 n, ${-(-n // 2)}; @@ -33,7 +33,7 @@ } ## Batch-load B column pairs -% for i, kx in enumerate(bix_list): +% for i, kx in enumerate(bix): ld.weak.global.cg.v2.f64 {bv_a${i}, bv_b${i}}, [b_base + ${ldb*kx*dwidth_i}]; % endfor @@ -42,11 +42,11 @@ % if row_nz[j]: % for kx, jx in row_nz[j]: % if loop.first: - mul.f64 dotp_a, bv_a${bix_pos[kx]}, ${jx}; - mul.f64 dotp_b, bv_b${bix_pos[kx]}, ${jx}; + mul.f64 dotp_a, bv_a${bix[kx]}, ${jx}; + mul.f64 dotp_b, bv_b${bix[kx]}, ${jx}; % else: - fma.rn.f64 dotp_a, bv_a${bix_pos[kx]}, ${jx}, dotp_a; - fma.rn.f64 dotp_b, bv_b${bix_pos[kx]}, ${jx}, dotp_b; + fma.rn.f64 dotp_a, bv_a${bix[kx]}, ${jx}, dotp_a; + fma.rn.f64 dotp_b, bv_b${bix[kx]}, ${jx}, dotp_b; % endif % endfor % if beta_zero: diff --git a/gimmik/kernels/ptx/cstream.mako b/gimmik/kernels/ptx/cstream.mako index 726fe46..9ce4c4d 100644 --- a/gimmik/kernels/ptx/cstream.mako +++ b/gimmik/kernels/ptx/cstream.mako @@ -17,7 +17,7 @@ % endif .reg .u32 n, id; .reg .u64 b, c, b_base, c_base; - .reg .${pftype} bv<${len(bix_list)}>, dotp; + .reg .${pftype} bv<${len(bix)}>, dotp; .reg .pred p1; % if n is None: @@ -50,7 +50,7 @@ } ## Batch-load active B columns -% for i, kx in enumerate(bix_list): +% for i, kx in enumerate(bix): % if n is None: { .reg .u32 _boff; @@ -69,9 +69,9 @@ % if row_nz[j]: % for kx, jx in row_nz[j]: % if loop.first: - mul.${pftype} dotp, bv${bix_pos[kx]}, ${jx}; + mul.${pftype} dotp, bv${bix[kx]}, ${jx}; % else: - fma.rn.${pftype} dotp, bv${bix_pos[kx]}, ${jx}, dotp; + fma.rn.${pftype} dotp, bv${bix[kx]}, ${jx}, dotp; % endif % endfor % if beta_zero: diff --git a/gimmik/kernels/ptx/dense-mma-gAd.mako b/gimmik/kernels/ptx/dense-mma-gAd.mako index 8933e51..0996a17 100644 --- a/gimmik/kernels/ptx/dense-mma-gAd.mako +++ b/gimmik/kernels/ptx/dense-mma-gAd.mako @@ -125,7 +125,7 @@ % endfor % endfor -% for ki in range(k_iters): +% for ki in range(k_tiles): % for nt in range(nn): <% pvb = f'pvalid_bcol_{nt}' if not n_col_aligned else None @@ -151,7 +151,7 @@ } % endfor % for mt in range(m_tiles): - ld.weak.global.${pftype} a_frag, [ag_thr_base + ${(mt * k_iters + ki) * frag_stride_bytes}]; + ld.weak.global.${pftype} a_frag, [ag_thr_base + ${(mt * k_tiles + ki) * frag_stride_bytes}]; % for nt in range(nn): mma.sync.aligned.m8n8k4.row.col.${pftype}.${pftype}.${pftype}.${pftype} {c0_${nt}_${mt}, c1_${nt}_${mt}}, diff --git a/gimmik/kernels/ptx/dense-mma-smem-gA.mako b/gimmik/kernels/ptx/dense-mma-smem-gA.mako index d1b72a8..4516831 100644 --- a/gimmik/kernels/ptx/dense-mma-smem-gA.mako +++ b/gimmik/kernels/ptx/dense-mma-smem-gA.mako @@ -214,7 +214,7 @@ $L_LOOP: % endfor % endfor -% for ki in range(k_iters): +% for ki in range(k_tiles): % for nt in range(nn): <% pvb = f'pvalid_bcol_{nt}' if not n_col_aligned else None @@ -240,7 +240,7 @@ $L_LOOP: } % endfor % for mt in range(m_tiles): - ld.shared.${pftype} a_frag, [as_thr_base + ${(mt * k_iters + ki) * frag_stride_bytes}]; + ld.shared.${pftype} a_frag, [as_thr_base + ${(mt * k_tiles + ki) * frag_stride_bytes}]; % for nt in range(nn): mma.sync.aligned.m8n8k4.row.col.${pftype}.${pftype}.${pftype}.${pftype} {c0_${nt}_${mt}, c1_${nt}_${mt}}, diff --git a/gimmik/kernels/ptx/dense-mma-ws.mako b/gimmik/kernels/ptx/dense-mma-ws.mako index a7c9f88..de4314f 100644 --- a/gimmik/kernels/ptx/dense-mma-ws.mako +++ b/gimmik/kernels/ptx/dense-mma-ws.mako @@ -91,13 +91,13 @@ $L_WAIT_BRDY: .reg .${pftype} a_f; % for mt in range(m_tiles): -% for kt in range(k_iters): +% for kt in range(k_tiles): <% k_tail = (k_rem != 0 and loop.last) %> { .reg .b32 a_a; - add.u32 a_a, a_thr_a, ${(kt * 32 + mt * 32 * k_iters) * dwidth_i}; + add.u32 a_a, a_thr_a, ${(kt * 32 + mt * 32 * k_tiles) * dwidth_i}; ld.shared.${pftype} a_f, [a_a]; % if k_tail: .reg .pred pbrow_${mt}_${kt}; diff --git a/gimmik/ptx.py b/gimmik/ptx.py index 1f46384..a9c049d 100644 --- a/gimmik/ptx.py +++ b/gimmik/ptx.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- -import struct import numpy as np @@ -12,13 +11,12 @@ class PTXMatMul(MatMul): basemeta = {'block': (128, 1, 1), 'width': 1, 'shared': 0, 'dynamic_shared': 0} - DENSE_SMEM_MAX = 200*1024 PTX_SM = {(8, 0), (9, 0), (10, 0), (10, 3), (12, 0), (12, 1)} @classmethod def is_sparse_suitable(cls, arr, cc): - nnz = int(np.count_nonzero(arr)) - nuq = int(len(np.unique(np.abs(arr)))) + nnz = np.count_nonzero(arr) + nuq = len(np.unique(np.abs(arr))) density = nnz / arr.size return ((nuq <= 28) or (density <= 0.15)) and cc in cls.PTX_SM @@ -32,9 +30,12 @@ def is_dense_suitable(cls, arr, cc): def is_suitable(cls, arr, cc): return cls.is_sparse_suitable(arr, cc) or cls.is_dense_suitable(arr, cc) - def _kernel_generators(self, dtype, dsize, *, compute_capability=None): + def _kernel_generators(self, dtype, dsize, *, compute_capability=None, + smem_info=None): cc = compute_capability or (0, 0) + smem_info = smem_info or (48*1024, 48*1024) base_args = {'cc': cc, + 'smem_info': smem_info, 'pred_emit': self._pred_emit, 'pftype': 'f32' if dtype == 'float' else 'f64', 'dwidth_i': 4 if dtype == 'float' else 8, @@ -53,8 +54,6 @@ def _kernel_generators(self, dtype, dsize, *, compute_capability=None): def _sparse_kernel_generators(self, dtype, dsize, base_args): # Sparse-shared template constants base_args = base_args | { - 'bix_list': list(self.bix), - 'bix_pos': self.bix, 'has_zero_rows': bool(self.has_zero_rows), 'row_nz': [[(kx, self.A[j, kx]) for kx in range(self.k) if self.A[j, kx] != 0] for j in range(self.m)], @@ -126,10 +125,10 @@ def _dense_kernel_generators(self, dtype, dsize, base_args): ] for tpl, nn, w in dense_configs: - blkx = 32 * w if (n_per_cta := 8 * nn * w) > self.n: continue setup = self._dense_mma_setup(nn=nn, warps_per_cta=w) + blkx = 32 * w args = (base_args | {'warps_per_cta': w, 'nn': nn, 'block_stealing': block_steal} | setup) meta = { @@ -144,6 +143,8 @@ def _dense_kernel_generators(self, dtype, dsize, base_args): yield from self._dense_ws_kernel_generators(dtype, dsize, base_args) def _dense_ws_kernel_generators(self, dtype, dsize, base_args): + static_max, dynamic_max = base_args['smem_info'] + # (nn, compute) -- block has compute + 2 warps (producer, stealer) ws_configs = [(1, 4), (2, 4), (4, 4)] for nn, w in ws_configs: @@ -153,12 +154,13 @@ def _dense_ws_kernel_generators(self, dtype, dsize, base_args): setup = self._dense_mma_setup(nn=nn, warps_per_cta=w) ws_setup = self._dense_ws_setup(setup, n_comp_warps=w) - if ws_setup['dynm_total_bytes'] > self.DENSE_SMEM_MAX: + if ws_setup['dynm_total_bytes'] > dynamic_max: continue + blkx = 32 * (w + 2) args = base_args | {'nn': nn} | setup | ws_setup meta = { - 'block': (32 * (w + 2), 1, 1), + 'block': (blkx, 1, 1), 'grid': (-(-self.n // n_per_cta), 1, 1), 'desc': f'dense-mma-ws/nn{nn}-w{w}', 'ws_b_tile': (n_per_cta, setup['k_pad']), @@ -209,23 +211,17 @@ def _dense_ws_setup(cls, setup, *, n_comp_warps): def _dense_mma_setup(self, *, nn, warps_per_cta): a = self.A m, k = a.shape - m_tiles = -(-m // 8) - k_iters = -(-k // 4) + m_tiles = (m + 7) // 8 + k_tiles = (k + 3) // 4 k_rem = k % 4 - # A in fragment layout: lane l -> A[mt*8 + l//4][kt*4 + l%4] - # DMMA tiles are 8x8x4 so this loads a 8x4 tile, flattens it and - # and packs it as uint64 as a more robust way of storing the values in - # the template - a_u64 = [] - for mt in range(m_tiles): - for kt in range(k_iters): - for lane in range(32): - i = mt * 8 + lane // 4 - j = kt * 4 + lane % 4 - v = float(a[i, j]) if (i < m and j < k) else 0.0 - u, = struct.unpack(' A[mt*8 + l//4][kt*4 + l%4] + # i.e. an (m_tiles, k_tiles) grid of row-major 8x4 tiles, packed as + # uint64 + a_pad = np.zeros((m_tiles*8, k_tiles*4), dtype=np.float64) + a_pad[:m, :k] = a + tiles = a_pad.reshape(m_tiles, 8, k_tiles, 4).transpose(0, 2, 1, 3) + a_u64 = [f'0x{u:016x}' for u in tiles.view(np.uint64).ravel()] n_per_warp = 8 * nn n_per_cta = warps_per_cta * n_per_warp @@ -237,12 +233,12 @@ def pm_runtime(mt): return { 'm_tiles': m_tiles, - 'k_iters': k_iters, + 'k_tiles': k_tiles, 'k_rem': k_rem, 'm_pad': m_tiles * 8, - 'k_pad': k_iters * 4, + 'k_pad': k_tiles * 4, 'a_u64': a_u64, - 'a_elems': m_tiles * k_iters * 32, + 'a_elems': m_tiles * k_tiles * 32, 'n_per_warp': n_per_warp, 'n_per_cta': n_per_cta, 'frag_stride_bytes': 32 * 8, From 0e86053d3de9c70cfbcf5b6c7c85d3912aea771e Mon Sep 17 00:00:00 2001 From: WillTrojak Date: Thu, 21 May 2026 09:26:39 -0700 Subject: [PATCH 09/11] Fixed missing import --- gimmik/cuda.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/gimmik/cuda.py b/gimmik/cuda.py index e40179c..8afc755 100644 --- a/gimmik/cuda.py +++ b/gimmik/cuda.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- +import numpy as np + from gimmik.base import MatMul From 1f62b5f4aa09a5b77973b34ad64ad9251bb55135 Mon Sep 17 00:00:00 2001 From: WillTrojak Date: Thu, 21 May 2026 09:30:40 -0700 Subject: [PATCH 10/11] Fixed additional args --- gimmik/cuda.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gimmik/cuda.py b/gimmik/cuda.py index 8afc755..9e1da43 100644 --- a/gimmik/cuda.py +++ b/gimmik/cuda.py @@ -17,7 +17,8 @@ def is_suitable(arr): density = nnz / arr.size return (nuq <= 28) or (density <= 0.15) - def _kernel_generators(self, dtype, dsize, *, compute_capability=None): + def _kernel_generators(self, dtype, dsize, *, compute_capability=None, + **kwargs): # B loading, C streaming kernel yield ('cstream', {}, {}) From 79f41cb9a689da9d2f89c8d85143911e4eccf0cf Mon Sep 17 00:00:00 2001 From: WillTrojak Date: Fri, 22 May 2026 07:15:58 -0700 Subject: [PATCH 11/11] Cleanup and added PTX Version to handle older drivers. --- gimmik/kernels/ptx/base.mako | 2 +- gimmik/kernels/ptx/dense-mma-gAd.mako | 2 +- gimmik/kernels/ptx/dense-mma-smem-gA.mako | 14 +-- gimmik/kernels/ptx/dense-mma-ws.mako | 6 +- gimmik/ptx.py | 104 +++++++++++++--------- 5 files changed, 75 insertions(+), 53 deletions(-) diff --git a/gimmik/kernels/ptx/base.mako b/gimmik/kernels/ptx/base.mako index 71eb414..dbd8433 100644 --- a/gimmik/kernels/ptx/base.mako +++ b/gimmik/kernels/ptx/base.mako @@ -1,4 +1,4 @@ -.version 8.6 +.version ${ptx[0]}.${ptx[1]} .target sm_${cc[0]}${cc[1]}${'a' if cc[0] >= 9 else ''} .address_size 64 ${next.body()} diff --git a/gimmik/kernels/ptx/dense-mma-gAd.mako b/gimmik/kernels/ptx/dense-mma-gAd.mako index 0996a17..3df43c0 100644 --- a/gimmik/kernels/ptx/dense-mma-gAd.mako +++ b/gimmik/kernels/ptx/dense-mma-gAd.mako @@ -1,6 +1,6 @@ <%inherit file='base'/> -.global .align 16 .b64 ${kname}_Ag[${a_elems}] = { +.global .align 16 .b64 ${kname}_Ag[${m_tiles*k_tiles*32}] = { ${', '.join(a_u64)} }; diff --git a/gimmik/kernels/ptx/dense-mma-smem-gA.mako b/gimmik/kernels/ptx/dense-mma-smem-gA.mako index 4516831..9a88b64 100644 --- a/gimmik/kernels/ptx/dense-mma-smem-gA.mako +++ b/gimmik/kernels/ptx/dense-mma-smem-gA.mako @@ -3,8 +3,8 @@ <% # Cooperative-copy params (gA-only) blockx = 32 * warps_per_cta -a_pairs = a_elems // 2 -a_pairs_tail = a_elems % 2 +a_pairs = m_tiles*k_tiles*32 // 2 +a_pairs_tail = m_tiles*k_tiles*32 % 2 copy_v2_iters = (a_pairs + blockx - 1) // blockx bs = bool(block_stealing) %> @@ -13,10 +13,10 @@ bs = bool(block_stealing) .shared .align 8 .b64 ${kname}_mbar; .shared .align 16 .b8 ${kname}_workid[16]; % endif -.global .align 16 .b64 ${kname}_Ag[${a_elems}] = { +.global .align 16 .b64 ${kname}_Ag[${m_tiles*k_tiles*32}] = { ${', '.join(a_u64)} }; -.shared .align 16 .b64 ${kname}_As[${a_elems}]; +.shared .align 16 .b64 ${kname}_As[${m_tiles*k_tiles*32}]; .visible .entry ${kname}(.param .u64 _b, .param .u64 _c) @@ -95,14 +95,14 @@ bs = bool(block_stealing) } % endfor % if a_pairs_tail: - // Tail element (only when a_elems is odd) + // Tail element (only when m_tiles*k_tiles*32 is odd) { .reg .pred plast; .reg .u64 gaddr, saddr; .reg .${pftype} v; setp.eq.u32 plast, tid, 0; - add.u64 gaddr, a_glb_base, ${(a_elems-1) * dwidth_i}; - add.u64 saddr, a_smem_base, ${(a_elems-1) * dwidth_i}; + add.u64 gaddr, a_glb_base, ${(m_tiles*k_tiles*32-1) * dwidth_i}; + add.u64 saddr, a_smem_base, ${(m_tiles*k_tiles*32-1) * dwidth_i}; @plast ld.weak.global.cg.${pftype} v, [gaddr]; @plast st.shared.${pftype} [saddr], v; } diff --git a/gimmik/kernels/ptx/dense-mma-ws.mako b/gimmik/kernels/ptx/dense-mma-ws.mako index de4314f..e151372 100644 --- a/gimmik/kernels/ptx/dense-mma-ws.mako +++ b/gimmik/kernels/ptx/dense-mma-ws.mako @@ -10,11 +10,11 @@ mov.u64 a_glb, ${kname}_Ag; cvta.to.global.u64 a_glb, a_glb; @p_warp_lead cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes - [a_smem], [a_glb], ${a_elems * 8}, [tma_mbar]; + [a_smem], [a_glb], ${m_tiles*k_tiles*32 * 8}, [tma_mbar]; @p_warp_lead cp.async.bulk.tensor.2d.shared::cta.global.tile.mbarrier::complete_tx::bytes [b1_smem], [bdesc_addr, {n_start0, 0}], [tma_mbar]; @p_warp_lead mbarrier.expect_tx.relaxed.cta.shared::cta.b64 - [tma_mbar], ${b_tile_bytes + a_elems * 8}; + [tma_mbar], ${b_tile_bytes + m_tiles*k_tiles*32 * 8}; bar.warp.sync 0xffffffff; .reg .b64 state; .reg .pred p1; @@ -321,7 +321,7 @@ $L_WAIT_WUSED: $L_AFTER_CTRL: -.global .align 16 .b64 ${kname}_Ag[${a_elems}] = { +.global .align 16 .b64 ${kname}_Ag[${m_tiles*k_tiles*32}] = { ${', '.join(a_u64)} }; .extern .shared .align 128 .b8 ${kname}_dynm[]; diff --git a/gimmik/ptx.py b/gimmik/ptx.py index a9c049d..f7be8f4 100644 --- a/gimmik/ptx.py +++ b/gimmik/ptx.py @@ -1,6 +1,3 @@ -# -*- coding: utf-8 -*- - - import numpy as np from gimmik.base import MatMul @@ -8,10 +5,16 @@ class PTXMatMul(MatMul): platform = 'ptx' - basemeta = {'block': (128, 1, 1), 'width': 1, 'shared': 0, - 'dynamic_shared': 0} + basemeta = { + 'block': (128, 1, 1), + 'width': 1, + 'shared': 0, + 'dynamic_shared': 0 + } - PTX_SM = {(8, 0), (9, 0), (10, 0), (10, 3), (12, 0), (12, 1)} + # Map Supported CC -> Minimum PTX version + PTX_SM = {(8, 0): (7, 0), (9, 0): (8, 0), (10, 0): (8, 7), (10, 3): (8, 7), + (12, 0): (8, 7), (12, 1): (8, 7)} @classmethod def is_sparse_suitable(cls, arr, cc): @@ -33,17 +36,20 @@ def is_suitable(cls, arr, cc): def _kernel_generators(self, dtype, dsize, *, compute_capability=None, smem_info=None): cc = compute_capability or (0, 0) + ptx = self.PTX_SM.get(cc, (0, 0)) smem_info = smem_info or (48*1024, 48*1024) - base_args = {'cc': cc, - 'smem_info': smem_info, - 'pred_emit': self._pred_emit, - 'pftype': 'f32' if dtype == 'float' else 'f64', - 'dwidth_i': 4 if dtype == 'float' else 8, - 'fzero': ('0f00000000' if dtype == 'float' - else '0d0000000000000000'), - 'beta_zero': self.beta == 0, - 'mbar_maxwait': '0x989680' - } + base_args = { + 'ptx': ptx, + 'cc': cc, + 'smem_info': smem_info, + 'pred_emit': self._pred_emit, + 'pftype': 'f32' if dtype == 'float' else 'f64', + 'dwidth_i': 4 if dtype == 'float' else 8, + 'fzero': ('0f00000000' if dtype == 'float' + else '0d0000000000000000'), + 'beta_zero': self.beta == 0, + 'mbar_maxwait': '0x989680', + } if self.is_sparse_suitable(self.A, cc): yield from self._sparse_kernel_generators(dtype, dsize, base_args) @@ -68,24 +74,32 @@ def _sparse_kernel_generators(self, dtype, dsize, base_args): # Four-way m-split B streaming, C accumulation kernel ms, bsz, blkx = 4, 24, 32 args = base_args | {'msplit': ms, 'bsz': bsz, 'blockx': blkx} - meta = {'block': (blkx, ms, 1), 'shared': 2*bsz*blkx*dsize, - 'desc': f'bstream-msplit/m{ms}-b{bsz}-x{blkx}'} + meta = { + 'block': (blkx, ms, 1), + 'shared': 2*bsz*blkx*dsize, + 'desc': f'bstream-msplit/m{ms}-b{bsz}-x{blkx}', + } yield ('bstream-msplit', args, meta) # Single-warp LDGSTS variant for medium-M beta=0 large-K cases if self.beta == 0 and self.m <= 320 and len(self.bix) >= 64: ms, bsz, blkx = 1, 32, 64 args = base_args | {'msplit': ms, 'bsz': bsz, 'blockx': blkx} - meta = {'block': (blkx, ms, 1), - 'shared': 2*bsz*blkx*dsize, - 'desc': f'bstream-msplit/m{ms}-b{bsz}-x{blkx}'} + meta = { + 'block': (blkx, ms, 1), + 'shared': 2*bsz*blkx*dsize, + 'desc': f'bstream-msplit/m{ms}-b{bsz}-x{blkx}', + } yield ('bstream-msplit', args, meta) # Two-way k-split B loading, C streaming kernel ks, csz, blkx = 2, 24, 32 args = base_args | {'ksplit': ks, 'csz': csz, 'blockx': blkx} - meta = {'block': (blkx, ks, 1), 'shared': (ks - 1)*csz*blkx*dsize, - 'desc': f'cstream-ksplit/k{ks}-c{csz}-x{blkx}'} + meta = { + 'block': (blkx, ks, 1), + 'shared': (ks - 1)*csz*blkx*dsize, + 'desc': f'cstream-ksplit/k{ks}-c{csz}-x{blkx}', + } yield ('cstream-ksplit', args, meta) # Four-way k-split for large K @@ -93,9 +107,11 @@ def _sparse_kernel_generators(self, dtype, dsize, base_args): if K_used > 500: ks, csz, blkx = 4, 20, 32 args = base_args | {'ksplit': ks, 'csz': csz, 'blockx': blkx} - meta = {'block': (blkx, ks, 1), - 'shared': (ks - 1)*csz*blkx*dsize, - 'desc': f'cstream-ksplit/k{ks}-c{csz}-x{blkx}'} + meta = { + 'block': (blkx, ks, 1), + 'shared': (ks - 1)*csz*blkx*dsize, + 'desc': f'cstream-ksplit/k{ks}-c{csz}-x{blkx}', + } yield ('cstream-ksplit', args, meta) # Width-2 vector cstream for fp64 small-K @@ -104,8 +120,11 @@ def _sparse_kernel_generators(self, dtype, dsize, base_args): and (self.aligne is None or self.aligne % 2 == 0)): blkx = 128 args = base_args | {'blockx': blkx} - meta = {'block': (blkx, 1, 1), 'width': 2, - 'desc': f'cstream-w2/x{blkx}'} + meta = { + 'block': (blkx, 1, 1), + 'width': 2, + 'desc': f'cstream-w2/x{blkx}', + } yield ('cstream-w2', args, meta) def _dense_kernel_generators(self, dtype, dsize, base_args): @@ -127,10 +146,9 @@ def _dense_kernel_generators(self, dtype, dsize, base_args): for tpl, nn, w in dense_configs: if (n_per_cta := 8 * nn * w) > self.n: continue - setup = self._dense_mma_setup(nn=nn, warps_per_cta=w) + setup = self._dense_mma_setup(nn, w, block_steal) blkx = 32 * w - args = (base_args | {'warps_per_cta': w, 'nn': nn, - 'block_stealing': block_steal} | setup) + args = base_args | setup meta = { 'block': (blkx, 1, 1), 'grid': (-(-self.n // n_per_cta), 1, 1), @@ -151,14 +169,14 @@ def _dense_ws_kernel_generators(self, dtype, dsize, base_args): if (n_per_cta := 8 * nn * w) > self.n: continue - setup = self._dense_mma_setup(nn=nn, warps_per_cta=w) - ws_setup = self._dense_ws_setup(setup, n_comp_warps=w) + setup = self._dense_mma_setup(nn, w, True) + ws_setup = self._dense_ws_setup(setup, w) if ws_setup['dynm_total_bytes'] > dynamic_max: continue blkx = 32 * (w + 2) - args = base_args | {'nn': nn} | setup | ws_setup + args = base_args | setup | ws_setup meta = { 'block': (blkx, 1, 1), 'grid': (-(-self.n // n_per_cta), 1, 1), @@ -184,11 +202,11 @@ def _dsmem_alloc(regions, mbars, align=16): return out, total @classmethod - def _dense_ws_setup(cls, setup, *, n_comp_warps): + def _dense_ws_setup(cls, setup, n_comp_warps): n_per_cta = setup['n_per_cta'] b_tile_bytes = setup['k_pad'] * n_per_cta * 8 c_tile_bytes = setup['m_pad'] * n_per_cta * 8 - a_bytes = setup['a_elems'] * 8 + a_bytes = setup['m_tiles'] * setup['k_tiles'] * 32 * 8 regions = [('b1', b_tile_bytes), ('b2', b_tile_bytes), ('c', c_tile_bytes), ('a', a_bytes), ('wid', 16)] @@ -196,7 +214,7 @@ def _dense_ws_setup(cls, setup, *, n_comp_warps): 'steal', 'wid_new', 'wid_used') offsets, dynm_total_bytes = cls._dsmem_alloc(regions, mbars) - return offsets | { + args = { 'n_comp_warps': n_comp_warps, 'blockx_total': 32 * (n_comp_warps + 2), 'prod_warp': n_comp_warps, @@ -208,7 +226,9 @@ def _dense_ws_setup(cls, setup, *, n_comp_warps): 'dynm_total_bytes': dynm_total_bytes, } - def _dense_mma_setup(self, *, nn, warps_per_cta): + return offsets | args + + def _dense_mma_setup(self, nn, warps_per_cta, block_steal): a = self.A m, k = a.shape m_tiles = (m + 7) // 8 @@ -218,9 +238,9 @@ def _dense_mma_setup(self, *, nn, warps_per_cta): # A in DMMA-fragment layout: lane l -> A[mt*8 + l//4][kt*4 + l%4] # i.e. an (m_tiles, k_tiles) grid of row-major 8x4 tiles, packed as # uint64 - a_pad = np.zeros((m_tiles*8, k_tiles*4), dtype=np.float64) + a_pad = np.zeros((m_tiles*8, k_tiles*4)) a_pad[:m, :k] = a - tiles = a_pad.reshape(m_tiles, 8, k_tiles, 4).transpose(0, 2, 1, 3) + tiles = a_pad.reshape(m_tiles, 8, k_tiles, 4).swapaxes(1, 2) a_u64 = [f'0x{u:016x}' for u in tiles.view(np.uint64).ravel()] n_per_warp = 8 * nn @@ -232,13 +252,14 @@ def pm_runtime(mt): return (mt + 1) * 8 > m return { + 'warps_per_cta': warps_per_cta, + 'nn': nn, 'm_tiles': m_tiles, 'k_tiles': k_tiles, 'k_rem': k_rem, 'm_pad': m_tiles * 8, 'k_pad': k_tiles * 4, 'a_u64': a_u64, - 'a_elems': m_tiles * k_tiles * 32, 'n_per_warp': n_per_warp, 'n_per_cta': n_per_cta, 'frag_stride_bytes': 32 * 8, @@ -248,6 +269,7 @@ def pm_runtime(mt): 'c_ntile_stride': 8 * 8, 'n_col_aligned': n_col_aligned, 'pm_runtime': pm_runtime, + 'block_stealing': block_steal, } @staticmethod