Skip to content

Commit a80f3db

Browse files
committed
Fix more bugs; make work
1 parent 10c57dc commit a80f3db

2 files changed

Lines changed: 21 additions & 24 deletions

File tree

tensorforge/backend/instructions/compute/multilinear.py

Lines changed: 8 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010
from tensorforge.common.basic_types import Datatype
1111
from tensorforge.backend.writer import Writer
1212

13-
from .primitives import nvidia as nv
13+
from .primitives import nvidia as nvidia
1414
from .primitives import amd as amd
1515

1616
class MultilinearInstruction(ComputeInstruction):
@@ -66,10 +66,6 @@ def __init__(self,
6666

6767
self._analyze()
6868

69-
def _choose_lead_dim(self):
70-
self._shm_volume = 0
71-
pass
72-
7369
def _analyze(self):
7470
targetrank = 0
7571
for i, op in enumerate(self._ops):
@@ -207,7 +203,7 @@ def nonlead_writer(varlist):
207203
write_loops(self._context, writer, loopstack, nonlead_writer)
208204

209205
def _nonleading_dim_test(self, writer: Writer):
210-
can_use = self._context.get_vm().get_hw_descr().vendor == 'amd'
206+
can_use = self._context.get_vm().get_hw_descr().vendor in ['amd']
211207
can_use &= len(self._ops) == 2
212208

213209
if can_use:
@@ -222,6 +218,7 @@ def _nonleading_dim_test(self, writer: Writer):
222218
N *= mx - mi
223219

224220
M *= -(-(self._ns[0][1] - self._ns[0][0]) // self._num_threads)
221+
Mx = (self._ns[0][1] - self._ns[0][0])
225222

226223
def unwindJ(j):
227224
idx = [None]
@@ -236,12 +233,15 @@ def unwindI(i):
236233
idx = [LeadIndex(i % size + self._ns[0][0] // self._num_threads, self._num_threads, 1)]
237234
return idx
238235

236+
# TODO: remove
237+
kx = self._ks[0][0]
238+
239239
def unwindK(k, full):
240240
size = self._ks[0][1] - self._ks[0][0]
241241
if full:
242242
idx = [k % size + self._ks[0][0]]
243243
else:
244-
sizeL = -(-size // self._num_threads)
244+
sizeL = -(-(size + kx) // self._num_threads)
245245
idx = [LeadIndex(k % sizeL + self._ks[0][0] // self._num_threads, self._num_threads, 1)]
246246
k //= size
247247
for mi, mx in self._ks[1:]:
@@ -250,9 +250,6 @@ def unwindK(k, full):
250250
k //= size
251251
return idx
252252

253-
# TODO: remove
254-
kx = self._ks[0][0]
255-
256253
def unwindOp(i, j, k, opid, full):
257254
iidx = unwindI(i)
258255
jidx = unwindJ(j)
@@ -301,7 +298,7 @@ def A(writer, var, i, k):
301298
if self._context.get_vm().get_hw_descr().vendor == 'amd':
302299
amd.matmul(writer, C, A, B, M, N, K, kx, self._num_threads, self._dest.datatype, sparse, self._context)
303300
elif self._context.get_vm().get_hw_descr().vendor == 'nvidia':
304-
return nvidia.matmul(writer, C, A, B, M, N, K, kx, self._num_threads, self._dest.datatype, sparse, self._context, 'TODO', 0)
301+
return nvidia.matmul(writer, C, A, B, Mx, N, K, kx, self._num_threads, self._dest.datatype, sparse, self._context, 'TODO', 0)
305302
return True
306303
return False
307304

tensorforge/backend/instructions/compute/primitives/nvidia.py

Lines changed: 13 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -447,7 +447,7 @@ def threadrange(start, size):
447447
B(writer, f'{Breg}_{k//threads}_{jj}', j + jj, k // threads)
448448
for jj in range(min(atom.n, N - j), atom.n):
449449
writer(f'{atom.d.ctype()} {Breg}_{k//threads}_{jj}{"{}"};')
450-
for ix in range(0, M):
450+
for i in range(0, M, threads):
451451
with writer.AnonymousScope():
452452
writer(f'{atom.d.ctype()} {Creg}[{cregs}][{threads // atom.m}]{"{}"};')
453453
for k in range(0, K, threads):
@@ -471,40 +471,40 @@ def threadrange(start, size):
471471
writer(f'{atom.d.ctype()} {Breg2}_{kkk + jj * kregs} = {shmptr}[{boffs} + (threadIdx.x % {ktile}) + (threadIdx.x / {ktile} + {jj * ntile}) * {atom.k} + {kkk * ktile}];')
472472

473473
for kkk in range(0, min(atom.k, K - k - kk)):
474-
A(writer, f'{Areg}_{kkk}', ix, k + kk + kkk)
474+
A(writer, f'{Areg}_{kkk}', i // threads, k + kk + kkk)
475475
for kkk in range(min(atom.k, K - k - kk), atom.k):
476476
writer(f'{atom.d.ctype()} {Areg}_{kkk}{"{}"};')
477477

478-
for iix in range(0, threads, atom.m):
478+
for ii in range(0, min(threads, M - i), atom.m):
479479
with writer.AnonymousScope():
480480
writer('__syncwarp();')
481-
with threadrange(iix, atom.m):
481+
with threadrange(ii, atom.m):
482482
for kkk in range(0, atom.k):
483-
writer(f'{shmptr}[{aoffs} + (threadIdx.x - {iix}) % {atom.m} + {kkk * atom.m}] = {Areg}_{kkk};')
483+
writer(f'{shmptr}[{aoffs} + (threadIdx.x - {ii}) % {atom.m} + {kkk * atom.m}] = {Areg}_{kkk};')
484484
writer('__syncwarp();')
485485

486486
for kk in range(0, kregs):
487-
for ii in range(0, mregs):
488-
writer(f'{atom.d.ctype()} {Areg2}_{ii + kk * mregs} = {shmptr}[{aoffs} + (threadIdx.x / {ktile}) + (threadIdx.x % {ktile} + {kk * ktile}) * {atom.m} + {ii * mtile}];')
487+
for iii in range(0, mregs):
488+
writer(f'{atom.d.ctype()} {Areg2}_{iii + kk * mregs} = {shmptr}[{aoffs} + (threadIdx.x / {ktile}) + (threadIdx.x % {ktile} + {kk * ktile}) * {atom.m} + {iii * mtile}];')
489489

490-
atom.generate(writer, ctx, [f'{Areg2}_{i}' for i in range (aregs)], [f'{Breg2}_{i}' for i in range (bregs)], [f'{Creg}[{i}][{iix // atom.m}]' for i in range (cregs)])
490+
atom.generate(writer, ctx, [f'{Areg2}_{i}' for i in range (aregs)], [f'{Breg2}_{i}' for i in range (bregs)], [f'{Creg}[{i}][{ii // atom.m}]' for i in range (cregs)])
491491

492492
for jj in range(0, atom.n):
493493
writer(f'{atom.d.ctype()} {Creg}_{jj}{"{}"};')
494494

495-
for i in range(0, threads, atom.m):
495+
for ii in range(0, threads, atom.m):
496496
with writer.AnonymousScope():
497497
for jj in range(0, nregs * 2):
498-
for ii in range(0, mregs):
499-
writer(f'{shmptr}[{coffs} + threadIdx.x * 2 + {ii} + {jj * 64}] = {Creg}[{ii + mregs * jj}][{i // atom.m}];')
498+
for iii in range(0, mregs):
499+
writer(f'{shmptr}[{coffs} + threadIdx.x * 2 + {iii} + {jj * 64}] = {Creg}[{iii + mregs * jj}][{ii // atom.m}];')
500500

501501
writer('__syncwarp();')
502-
with threadrange(i, atom.m):
502+
with threadrange(ii, atom.m):
503503
for jj in range(0, atom.n):
504504
writer(f'{Creg}_{jj} = {shmptr}[{coffs} + (threadIdx.x % {atom.m}) * {atom.n} + {jj}];')
505505
writer('__syncwarp();')
506506

507507
for jj in range(0, min(atom.n, N - j)):
508-
C(writer, f'{Creg}_{jj}', ix, j + jj)
508+
C(writer, f'{Creg}_{jj}', i // threads, j + jj)
509509

510510
return True

0 commit comments

Comments
 (0)