@@ -447,7 +447,7 @@ def threadrange(start, size):
447447 B (writer , f'{ Breg } _{ k // threads } _{ jj } ' , j + jj , k // threads )
448448 for jj in range (min (atom .n , N - j ), atom .n ):
449449 writer (f'{ atom .d .ctype ()} { Breg } _{ k // threads } _{ jj } { "{}" } ;' )
450- for ix in range (0 , M ):
450+ for i in range (0 , M , threads ):
451451 with writer .AnonymousScope ():
452452 writer (f'{ atom .d .ctype ()} { Creg } [{ cregs } ][{ threads // atom .m } ]{ "{}" } ;' )
453453 for k in range (0 , K , threads ):
@@ -471,40 +471,40 @@ def threadrange(start, size):
471471 writer (f'{ atom .d .ctype ()} { Breg2 } _{ kkk + jj * kregs } = { shmptr } [{ boffs } + (threadIdx.x % { ktile } ) + (threadIdx.x / { ktile } + { jj * ntile } ) * { atom .k } + { kkk * ktile } ];' )
472472
473473 for kkk in range (0 , min (atom .k , K - k - kk )):
474- A (writer , f'{ Areg } _{ kkk } ' , ix , k + kk + kkk )
474+ A (writer , f'{ Areg } _{ kkk } ' , i // threads , k + kk + kkk )
475475 for kkk in range (min (atom .k , K - k - kk ), atom .k ):
476476 writer (f'{ atom .d .ctype ()} { Areg } _{ kkk } { "{}" } ;' )
477477
478- for iix in range (0 , threads , atom .m ):
478+ for ii in range (0 , min ( threads , M - i ) , atom .m ):
479479 with writer .AnonymousScope ():
480480 writer ('__syncwarp();' )
481- with threadrange (iix , atom .m ):
481+ with threadrange (ii , atom .m ):
482482 for kkk in range (0 , atom .k ):
483- writer (f'{ shmptr } [{ aoffs } + (threadIdx.x - { iix } ) % { atom .m } + { kkk * atom .m } ] = { Areg } _{ kkk } ;' )
483+ writer (f'{ shmptr } [{ aoffs } + (threadIdx.x - { ii } ) % { atom .m } + { kkk * atom .m } ] = { Areg } _{ kkk } ;' )
484484 writer ('__syncwarp();' )
485485
486486 for kk in range (0 , kregs ):
487- for ii in range (0 , mregs ):
488- writer (f'{ atom .d .ctype ()} { Areg2 } _{ ii + kk * mregs } = { shmptr } [{ aoffs } + (threadIdx.x / { ktile } ) + (threadIdx.x % { ktile } + { kk * ktile } ) * { atom .m } + { ii * mtile } ];' )
487+ for iii in range (0 , mregs ):
488+ writer (f'{ atom .d .ctype ()} { Areg2 } _{ iii + kk * mregs } = { shmptr } [{ aoffs } + (threadIdx.x / { ktile } ) + (threadIdx.x % { ktile } + { kk * ktile } ) * { atom .m } + { iii * mtile } ];' )
489489
490- atom .generate (writer , ctx , [f'{ Areg2 } _{ i } ' for i in range (aregs )], [f'{ Breg2 } _{ i } ' for i in range (bregs )], [f'{ Creg } [{ i } ][{ iix // atom .m } ]' for i in range (cregs )])
490+ atom .generate (writer , ctx , [f'{ Areg2 } _{ i } ' for i in range (aregs )], [f'{ Breg2 } _{ i } ' for i in range (bregs )], [f'{ Creg } [{ i } ][{ ii // atom .m } ]' for i in range (cregs )])
491491
492492 for jj in range (0 , atom .n ):
493493 writer (f'{ atom .d .ctype ()} { Creg } _{ jj } { "{}" } ;' )
494494
495- for i in range (0 , threads , atom .m ):
495+ for ii in range (0 , threads , atom .m ):
496496 with writer .AnonymousScope ():
497497 for jj in range (0 , nregs * 2 ):
498- for ii in range (0 , mregs ):
499- writer (f'{ shmptr } [{ coffs } + threadIdx.x * 2 + { ii } + { jj * 64 } ] = { Creg } [{ ii + mregs * jj } ][{ i // atom .m } ];' )
498+ for iii in range (0 , mregs ):
499+ writer (f'{ shmptr } [{ coffs } + threadIdx.x * 2 + { iii } + { jj * 64 } ] = { Creg } [{ iii + mregs * jj } ][{ ii // atom .m } ];' )
500500
501501 writer ('__syncwarp();' )
502- with threadrange (i , atom .m ):
502+ with threadrange (ii , atom .m ):
503503 for jj in range (0 , atom .n ):
504504 writer (f'{ Creg } _{ jj } = { shmptr } [{ coffs } + (threadIdx.x % { atom .m } ) * { atom .n } + { jj } ];' )
505505 writer ('__syncwarp();' )
506506
507507 for jj in range (0 , min (atom .n , N - j )):
508- C (writer , f'{ Creg } _{ jj } ' , ix , j + jj )
508+ C (writer , f'{ Creg } _{ jj } ' , i // threads , j + jj )
509509
510510 return True
0 commit comments