"""Op tests for the tinygrad assembly backend (reconstructed from a newline-stripped patch).

Mirrors the structure of tinygrad's test_ops.py: each test compares a tinygrad
op against the torch reference via helper_test_op, or against numpy directly.
"""
import time, math, unittest, functools, os, torch
import numpy as np
from typing import List, Callable
import warnings
# NOTE(review): the original had two overlapping `tinygrad.helpers` import lines
# (Context was imported twice); merged into one, keeping every imported name.
from tinygrad.helpers import (DISABLE_COMPILER_CACHE, getenv, IMAGE, DEBUG, CI, Context,
                              TRANSCENDENTAL, DEVECTORIZE, OSX, AMD_LLVM)
from tinygrad import Tensor, Device, dtypes
from tinygrad.tensor import _to_np_dtype
from tinygrad.device import is_dtype_supported
from tinygrad.renderer.asm import Arch

def skipU(flag: str):
  """Run the decorated test only when the environment variable *flag* is set; skip otherwise."""
  if os.environ.get(flag):
    return lambda func: func
  return unittest.skip(f"set {flag} to run")  # was an empty reason; name the gate

if getenv("TINY_BACKEND"):
  import tinygrad.frontend.torch  # noqa: F401 # pylint: disable=unused-import
  torch.set_default_device("tiny")

FORWARD_ONLY = getenv("FORWARD_ONLY", 1)
PRINT_TENSORS = getenv("PRINT_TENSORS", 0)

def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, grad_atol=1e-4, grad_rtol=1e-3,
                   forward_only=False, vals=None, low=-2, high=2):
  """Run torch_fxn and tinygrad_fxn on identical random (or literal `vals`) inputs and
  compare forward outputs, and gradients unless forward_only/FORWARD_ONLY."""
  if tinygrad_fxn is None: tinygrad_fxn = torch_fxn
  ts, tst = prepare_test_op(low, high, shps, vals, forward_only)

  st = time.monotonic()
  out = torch_fxn(*ts)
  torch_fp = time.monotonic() - st

  # move inputs to a different device, test the device of intermediate tensors are correct
  #if mt:=getenv("MOVE_TENSOR", ""): for t in tst: t.to_(mt)

  st = time.monotonic()
  ret = tinygrad_fxn(*tst).realize()
  tinygrad_fp = time.monotonic() - st

  def compare(s, tinygrad_output, torch_output, atol, rtol):
    if PRINT_TENSORS: print(s, tinygrad_output, torch_output)
    try:
      assert tinygrad_output.shape == torch_output.shape, f"shape mismatch: tinygrad={tinygrad_output.shape} | torch={torch_output.shape}"
      assert tinygrad_output.dtype == torch_output.dtype, f"dtype mismatch: tinygrad={tinygrad_output.dtype} | torch={torch_output.dtype}"
      if np.issubdtype(tinygrad_output.dtype, np.floating):
        np.testing.assert_allclose(tinygrad_output, torch_output, atol=atol, rtol=rtol)
      else:
        np.testing.assert_equal(tinygrad_output, torch_output)
    except Exception as e:
      # chain the cause so the numpy mismatch detail is preserved in the traceback
      raise Exception(f"{s} failed shape {tinygrad_output.shape}: {e}") from e

  if DEBUG >= 6:
    np.set_printoptions(linewidth=200, suppress=True)
    print(ret.numpy())
    print(out.detach().cpu().numpy())
  compare("forward pass", ret.numpy(), out.detach().cpu().numpy(), atol=atol, rtol=rtol)

  torch_fbp, tinygrad_fbp = np.nan, np.nan
  if not forward_only and not FORWARD_ONLY and ts and tst:
    st = time.monotonic()
    torch_grads = torch.autograd.grad(torch_fxn(*ts).sum(), ts)
    torch_fbp = time.monotonic() - st

    st = time.monotonic()
    # NOTE: we now have to recompute the forward pass since we realized it
    tiny_grads = tinygrad_fxn(*tst).sum().gradient(*tst)
    Tensor.realize(*tiny_grads)
    tinygrad_fbp = time.monotonic() - st

    for i, (t, torch_grad) in enumerate(zip(tiny_grads, torch_grads)):
      compare(f"backward pass tensor {i}", t.numpy(), torch_grad.detach().cpu().numpy(), atol=grad_atol, rtol=grad_rtol)

  if not CI:
    print("\ntesting %40r torch/tinygrad fp: %.2f / %.2f ms bp: %.2f / %.2f ms " % \
      (shps, torch_fp*1000, tinygrad_fp*1000, torch_fbp*1000, tinygrad_fbp*1000), end="")

def prepare_test_op(low, high, shps, vals, forward_only=False):
  """Build matching torch/tinygrad input tensors: random uniform in [low, high) for each
  shape in `shps`, or from the literal python values in `vals` when shps is None."""
  if shps is None:
    ts = [torch.tensor(x, requires_grad=(not forward_only)) for x in vals]
  else:
    np.random.seed(0)
    np_data = [np.random.uniform(low=low, high=high, size=size).astype(_to_np_dtype(dtypes.default_float)) for size in shps]
    if os.environ.get("INPUT_BYTES"):
      print(f"{np_data=}")
      b = np_data[0].tobytes()
      for _b in b: print(f"{_b:#x}", end=", ")
      print()
    ts = [torch.tensor(data, requires_grad=(not forward_only)) for data in np_data]
  for i in range(len(ts)):
    # NOTE: torch default int64 for python ints input
    if ts[i].dtype == torch.int64: ts[i] = ts[i].type(torch.int32)
  tst = [Tensor(x.detach().cpu().numpy(), requires_grad=(not forward_only and not FORWARD_ONLY)) for x in ts]
  return ts, tst


class TestOps(unittest.TestCase):
  def test_full(self):
    a = Tensor.full((4, 4), 20, dtype=dtypes.int).contiguous().realize()
    np.testing.assert_equal(a.numpy(), np.full((4,4), 20))
  def test_full_int64(self):
    a = Tensor.full((4, 4), 20, dtype=dtypes.int64).contiguous().realize()
    np.testing.assert_equal(a.numpy(), np.full((4,4), 20, dtype=np.int64))
  def test_zeros(self):
    a = Tensor.zeros(4, 4, dtype=dtypes.int32).contiguous().realize()
    np.testing.assert_equal(a.numpy(), np.zeros((4,4), dtype=np.int32))
  def test_full_float32(self):
    a = Tensor.full((4,4), 20.0, dtype=dtypes.float32).contiguous().numpy()
    np.testing.assert_equal(a, np.full((4,4), 20.0, dtype=np.float32))

  @unittest.skip("need to handle MOD")
  def test_eye(self):
    print(Tensor.eye(10).numpy())

  def test_split(self):
    tensors = Tensor.arange(16).reshape((4,4)).split((2,2))
    ret = tensors[1].tolist()
    assert ret == [
      [8, 9, 10, 11],
      [12, 13, 14, 15]
    ]

  def test_chunk(self):
    t = Tensor.arange(13).repeat((8, 1))
    print(f"{t.shape=}")
    ts = t.chunk(6, 1)
    for _t in ts: print(f"{_t.shape=}")
    print(ts[0].numpy())

  def test_meshgrid(self):
    x, y = Tensor([1, 2, 3]), Tensor([4, 5, 6])
    grid_x, grid_y = x.meshgrid(y)
    grid_x, grid_y = x.meshgrid(y, indexing="ij")
    print(grid_x.numpy())
    #print(grid_y.numpy())

  def test_arange(self):
    helper_test_op([], lambda: torch.arange(10, dtype=torch.int32), lambda: Tensor.arange(10), forward_only=True)
    helper_test_op([], lambda: torch.arange(36, dtype=torch.int32), lambda: Tensor.arange(36), forward_only=True)
    helper_test_op([], lambda: torch.arange(5, 10, 3, dtype=torch.int32), lambda: Tensor.arange(5, 10, 3), forward_only=True)
    helper_test_op([], lambda: torch.arange(10, 5, -3, dtype=torch.int32), lambda: Tensor.arange(10, 5, -3), forward_only=True)
    helper_test_op([], lambda: torch.arange(11, 5, -3, dtype=torch.int32), lambda: Tensor.arange(11, 5, -3), forward_only=True)
    helper_test_op([], lambda: torch.arange(1, 78, 2, dtype=torch.int32), lambda: Tensor.arange(1, 78, 2), forward_only=True)

  def test_arange_float(self):
    helper_test_op([], lambda: torch.arange(5.5, 175.5, 2.5), lambda: Tensor.arange(5.5, 175.5, 2.5), forward_only=True)
    end = 164 #164 would fail, 163 passes
    helper_test_op([], lambda: torch.arange(5.5, end, 2.5),
                   lambda: Tensor.arange(5.5, end, 2.5), forward_only=True)

  def test_argmax(self):
    # check if it returns the first index for multiple occurences
    helper_test_op(None, lambda x: x.argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True, vals=[[2, 2]])
    helper_test_op(None, lambda x: x.argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True, vals=[[1, 2, 2]])
    np.testing.assert_equal(Tensor([2,2]).argmax().numpy(), 0)
    np.testing.assert_equal(Tensor([1,2,2]).argmax().numpy(), 1)
    helper_test_op([(10,20)], lambda x: x.argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True)
    helper_test_op([(10,20)], lambda x: x.argmax(0, False).type(torch.int32), lambda x: x.argmax(0, False), forward_only=True)
    helper_test_op([(10,20)], lambda x: x.argmax(1, False).type(torch.int32), lambda x: x.argmax(1, False), forward_only=True)
    helper_test_op([(10,20)], lambda x: x.argmax(1, True).type(torch.int32), lambda x: x.argmax(1, True), forward_only=True)
    helper_test_op(None, lambda x: x.argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True, vals=[[0, -2**31]])
    helper_test_op(None, lambda x: x.argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True, vals=[[-2**31, 0]])
    helper_test_op(None, lambda x: x.type(torch.int32).argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True, vals=[[False, True]])
    helper_test_op(None, lambda x: x.type(torch.int32).argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True, vals=[[True, False]])
    helper_test_op(None, lambda x: (~x).argmax().type(torch.int32), lambda x: (~x).argmax(), forward_only=True, vals=[[2, 2]])

  def test_max(self):
    helper_test_op([(45,3)], lambda x: x.max())
    helper_test_op([(45,3)], lambda x: x.max().mul(0.5))
    helper_test_op(None, lambda x: x.max().mul(0.5), vals=[[[1.0,1.0,0.0,1.0]],])
    helper_test_op(None, lambda x: x.max().mul(0.5), vals=[[[1.0,1.0,0.0,1.0]],])
    helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: x.max(axis=1))
    helper_test_op([()], lambda x: x.max())
    helper_test_op(None, lambda x: x.max(), forward_only=True, vals=[[False, True]])
    helper_test_op(None, lambda x: x.max(), forward_only=True, vals=[[True, False]])
    helper_test_op(None, lambda x: x.max(), forward_only=True, vals=[[0, -2**31]])
    helper_test_op(None, lambda x: x.max(), forward_only=True, vals=[[-2**31, 0]])

  def test_argsort(self):
    for dim in [-1, 0, 1]:
      for descending in [True, False]:
        helper_test_op([(8,8,6)], lambda x: torch.argsort(x, dim=dim, descending=descending, stable=True).type(torch.int32),
                       lambda x: x.argsort(dim, descending), forward_only=True)

  # NOTE(review): method name keeps the original "linespace" typo so existing
  # test selections by name keep working; the op under test is linspace.
  def test_linespace(self):
    print(Tensor.linspace(5, 10, 3).numpy())

  def test_abs(self):
    with Context(NOOPT=1): helper_test_op([(2,2)], torch.abs, Tensor.abs)
    with Context(NOOPT=0): helper_test_op([(8,8)], torch.abs, Tensor.abs)
    helper_test_op([(45,65)], torch.abs, Tensor.abs)

  def _test_abs(self, data, dtype):
    a = Tensor(data, dtype=dtype, device="asm")
    np.testing.assert_equal(a.abs().numpy(), np.abs(np.array(data).astype(_to_np_dtype(dtype))))

  def test_abs_f32(self):
    self._test_abs([-1, 0, 2, -4], dtypes.float32)
    with Context(NOOPT=1): self._test_abs([-1, 0, 2, -4], dtypes.float32)
  def test_abs_f64(self):
    self._test_abs([-1, 0, 2, -4], dtypes.float64)
    with Context(NOOPT=1): self._test_abs([-1, 0, 2, -4], dtypes.float64)
  def test_abs_i32(self):
    self._test_abs([-1, 0, 2, -4], dtypes.int32)
    with Context(NOOPT=1): self._test_abs([-1, 0, 2, -4], dtypes.int32)
  def test_abs_i64(self):
    self._test_abs([-1, 0, 2, -4], dtypes.int64)
    with Context(NOOPT=1): self._test_abs([-1, 0, 2, -4], dtypes.int64)

  def test_acos_noopt(self):
    with Context(NOOPT=1):
      helper_test_op([(2,2)], lambda x: x.acos(), low=-1, high=1)
  def test_acos(self):
    helper_test_op([(4,)], lambda x: x.acos(), low=-1, high=1)
    helper_test_op([(45,65)], lambda x: x.acos(), low=-1, high=1)
  def test_acos_large(self):
    helper_test_op([(20,)], lambda x: x.acos(), low=-300, high=-297)
    helper_test_op([(45,65)], lambda x: x.acos(), low=-300, high=-297)
    helper_test_op([(45,65)], lambda x: x.acos(), low=300, high=303)

  def test_sum(self):
    np_ret = np.ones((16, 16)).sum(axis=1)
    tg_ret = Tensor.ones(16, 16, dtype=dtypes.int).sum(axis=1).numpy()
    np.testing.assert_equal(tg_ret, np_ret)

  def test_where(self):
    a = Tensor([1, 2, 3])
    b = (a > 2).where(8, 9)
    assert b.tolist() == [9, 9, 8]

  def test_matmul_int64(self):
    with Context(DEBUG=0):
      a = Tensor.arange(12, device="PYTHON", dtype=dtypes.int64).reshape((3,4)).realize()
      b = Tensor.arange(8, device="PYTHON", dtype=dtypes.int64).reshape((4,2)).realize()
    a = a.to(Device.DEFAULT)
    b = b.to(Device.DEFAULT)
    c = a.dot(b)
    print(c.numpy())
    np.testing.assert_equal(c.numpy(), np.array([
      [28, 34],
      [76, 98],
      [124, 162]
    ], dtype=np.int64))
  def test_matmul_int64_noopt(self):
    with Context(DEBUG=0):
      a = Tensor.arange(12, device="PYTHON", dtype=dtypes.int64).reshape((3,4)).realize()
      b = Tensor.arange(8, device="PYTHON", dtype=dtypes.int64).reshape((4,2)).realize()
    a = a.to(Device.DEFAULT)
    b = b.to(Device.DEFAULT)
    c = a.dot(b)
    with Context(NOOPT=1):
      np.testing.assert_equal(c.numpy(), np.array([
        [28, 34],
        [76, 98],
        [124, 162]
      ], dtype=np.int64))
  def test_matmul_int32(self):
    with Context(DEBUG=0):
      a = Tensor.arange(12, device="PYTHON", dtype=dtypes.int32).reshape((3,4)).realize()
      b = Tensor.arange(8, device="PYTHON", dtype=dtypes.int32).reshape((4,2)).realize()
    a = a.to(Device.DEFAULT)
    b = b.to(Device.DEFAULT)
    c = a.dot(b)
    np.testing.assert_equal(c.numpy(), np.array([
      [28, 34],
      [76, 98],
      [124, 162]
    ], dtype=np.int32))
  def test_matmul_int32_noopt(self):
    with Context(DEBUG=0):
      a = Tensor.arange(12, device="PYTHON", dtype=dtypes.int32).reshape((3,4)).realize()
      b = Tensor.arange(8, device="PYTHON", dtype=dtypes.int32).reshape((4,2)).realize()
    a = a.to(Device.DEFAULT)
    print(f"{a.dtype=} {a.dtype.itemsize=}")
    b = b.to(Device.DEFAULT)
    c = a.dot(b)
    with Context(NOOPT=1):
      np.testing.assert_equal(c.numpy(), np.array([
        [28, 34],
        [76, 98],
        [124, 162]
      ], dtype=np.int32))

  def test_matmul_f32_noopt(self):
    with Context(DEBUG=0):
      a = Tensor.arange(12, device="PYTHON", dtype=dtypes.float32).reshape((3,4)).realize()
      b = Tensor.arange(8, device="PYTHON", dtype=dtypes.float32).reshape((4,2)).realize()
    a = a.to(Device.DEFAULT)
    b = b.to(Device.DEFAULT)
    c = a.dot(b)
    with Context(NOOPT=1):
      np.testing.assert_equal(c.numpy(), np.array([
        [28, 34],
        [76, 98],
        [124, 162]
      ], dtype=np.float32))

  def test_matmul_f32(self):
    # FIX: the original wrapped the check in `with Context(NOOPT=1):`, making this
    # an exact duplicate of test_matmul_f32_noopt; the optimized path is tested here.
    with Context(DEBUG=0):
      a = Tensor.arange(12, device="PYTHON", dtype=dtypes.float32).reshape((3,4)).realize()
      b = Tensor.arange(8, device="PYTHON", dtype=dtypes.float32).reshape((4,2)).realize()
    a = a.to(Device.DEFAULT)
    b = b.to(Device.DEFAULT)
    c = a.dot(b)
    np.testing.assert_equal(c.numpy(), np.array([
      [28, 34],
      [76, 98],
      [124, 162]
    ], dtype=np.float32))

  def _test_matmul_f32_rand(self, shape_a, shape_b):
    """Compare random-float matmul across numpy reference, cpu (clang) and asm backends."""
    np.random.seed(0)
    a = np.random.rand(*shape_a).astype(np.float32)
    b = np.random.rand(*shape_b).astype(np.float32)
    a_t = Tensor(a)
    b_t = Tensor(b)
    np_res = np.matmul(a, b)
    with Context(NOOPT=1, DEBUG=0):
      clang_res = a_t.to("cpu").dot(b_t.to("cpu")).numpy()
    # FIX: np_res was computed but never checked in the original
    np.testing.assert_allclose(clang_res, np_res, rtol=1e-5)
    with Context(NOOPT=1):
      asm_res = a_t.to("asm").dot(b_t.to("asm")).numpy()
    np.testing.assert_allclose(clang_res, asm_res)
    with Context(NOOPT=0):
      asm_res_2 = a_t.to("asm").dot(b_t.to("asm")).numpy()
    np.testing.assert_allclose(clang_res, asm_res_2)

  def test_matmul_f32_rand_3_2_4(self):
    self._test_matmul_f32_rand((3,4), (4,2))

  def test_matmul_f32_rand_3_2_8(self):
    self._test_matmul_f32_rand((3,8), (8,2))

  def test_matmul_f32_rand_3_2_16(self):
    self._test_matmul_f32_rand((3,16), (16,2))

  @unittest.skipUnless(os.environ.get("MANUAL"), "")
  def test_matmul_f32_rand_small(self):
    np.random.seed(0)
    a = np.random.rand(2, 2).astype(np.float32)
    np.save("../tg-dev/matmul6/a.npy", a)
    b = np.random.rand(2, 2).astype(np.float32)
    np.save("../tg-dev/matmul6/b.npy", b)
    print(f"{a=}")
    print(f"{b=}")
    a_t = Tensor(a)
    b_t = Tensor(b)
    np_res = np.matmul(a, b)
    np.save("../tg-dev/matmul6/np_res.npy", np_res)
    with Context(NOOPT=1, DEBUG=5, SHOW_DISASM=1, CLANG_O_LEVEL=0, FFP=0):
      clang_res = a_t.to("cpu").dot(b_t.to("cpu")).numpy()
    np.save("../tg-dev/matmul6/clang_res.npy", clang_res)
    with Context(NOOPT=1):
      asm_res = a_t.to("asm").dot(b_t.to("asm")).numpy()
    np.save("../tg-dev/matmul6/asm_res.npy", asm_res)
    np.testing.assert_equal(clang_res, asm_res)

  @unittest.skipUnless(os.environ.get("MANUAL"), "speed test")
  def test_matmul_f32_speed(self):
    """
    (1024,512) @ (512, 256)
    atol=1e-04 x86 asm:2.893 clang0:2.606 clang1:0.988 clang2:0.999
               arm asm:6.19  clang0:5.08  clang1:0.935 clang2:1.04
    """
    with Context(DEBUG=0):
      np.random.seed(0)
      a = np.random.rand(1024, 512).astype(np.float32)
      b = np.random.rand(512, 256).astype(np.float32)

    with Context(NOOPT=1, BEAM=0, FFP=0):
      with Context(CLANG_O_LEVEL=1, SHOW_DISASM=0):
        repeats = 10
        a = Tensor(a)
        b = Tensor(b)
        with Context(DEBUG=2):
          c_cpu = speedrun("clang", a.to("cpu").dot(b.to("cpu")), repeats)
        with Context(DEBUG=6):
          c_asm = speedrun("asm", a.to("asm").dot(b.to("asm")), repeats)
        np.testing.assert_equal(c_asm, c_cpu)

  def test_idiv(self):
    helper_test_op(None, functools.partial(torch.div, rounding_mode="trunc"), Tensor.idiv, forward_only=True,
                   vals=[[-4, 7, 5, 4, -7, 8], [2, -3, 8, -2, 3, 5]])

  def test_acosh(self):
    helper_test_op([(2,)], lambda x: x.acosh(), grad_atol=1e-6)
    helper_test_op([(45,65)], lambda x: x.acosh(), grad_atol=1e-3, grad_rtol=1e-2, low=-300, high=-297)

  def test_acosh_high(self):
    helper_test_op([(45,65)], lambda x: x.acosh(), grad_atol=1e-6, low=300, high=303)

  def test_log(self):
    helper_test_op([(45,65)], lambda x: x.log(), grad_atol=1e-6, low=300, high=303)

  def test_log2(self):
    with Context(NOOPT=0):
      helper_test_op([(1,)], lambda x: x.log2(), grad_atol=1e-6)
    with Context(NOOPT=1):
      helper_test_op([(45,65)], lambda x: x.log2(), grad_atol=1e-6)

  def test_log2_unroll(self):
    with Context(NOOPT=0):
      helper_test_op([(4,)], lambda x: x.log2(), grad_atol=1e-6)

  def test_recip(self):
    with Context(NOOPT=1, TRANSCENDENTAL=1):
      helper_test_op([(4,)], lambda x: x.reciprocal(), grad_atol=1e-6)
    with Context(NOOPT=0, TRANSCENDENTAL=1):
      helper_test_op([(4,)], lambda x: x.reciprocal(), grad_atol=1e-6)

  def test_and(self):
    data = [[1,-8,1],[32,1,6]]
    tor = torch.tensor(data, dtype=torch.int)
    ten = Tensor(data, dtype=dtypes.int32)
    helper_test_op([], lambda: tor&tor, lambda: ten&ten, forward_only=True)
    helper_test_op([], lambda: tor&0x1337, lambda: ten&0x1337, forward_only=True)
    helper_test_op([], lambda: 0x1337&tor, lambda: 0x1337&ten, forward_only=True)

    data = [[True, True, False, False], [True, False, True, False]]
    tor0, tor1 = torch.tensor(data[0], dtype=torch.bool), torch.tensor(data[1], dtype=torch.bool)
    ten0, ten1 = Tensor(data[0], dtype=dtypes.bool), Tensor(data[1], dtype=dtypes.bool)
    helper_test_op([], lambda: tor0&tor1, lambda: ten0&ten1, forward_only=True)

    helper_test_op(None, lambda x: (1 < x) & (x < 2), forward_only=True, vals=[[1.2, 1.2, 1.2, 3.2]])

  # FIX: test_all was defined twice with identical bodies; the duplicate silently
  # shadowed this one. Kept a single definition.
  def test_all(self):
    helper_test_op([(3,4,5,6)], lambda x: x.all(), forward_only=True)
    helper_test_op(None, lambda x: x.all(), vals=[[True, True]], forward_only=True)
    helper_test_op(None, lambda x: x.all(), vals=[[True, False]], forward_only=True)
    helper_test_op(None, lambda x: x.all(), vals=[[False, False]], forward_only=True)
    helper_test_op([()], lambda x: x.all(), forward_only=True)

  def test_avg_pool2d(self):
    shape = (32,2,111,28)
    for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]:
      with self.subTest(kernel_size=ksz):
        helper_test_op([shape],
          lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz),
          lambda x: Tensor.avg_pool2d(x, kernel_size=ksz), rtol=1e-5)

    # regression test for https://github.com/tinygrad/tinygrad/pull/7581
    helper_test_op([(1,1,8,8)],
      lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=(1,2), padding=(0,1), stride=(5,1)),
      lambda x: Tensor.avg_pool2d(x, kernel_size=(1,2), padding=(0,1), stride=(5,1)), rtol=1e-5)

  def test_avg_pool2d_ceil_mode(self):
    shape = (1,1,6,6)
    for ksz in [(3,3), 3, (3,2), 4]:
      with self.subTest(kernel_size=ksz):
        helper_test_op([shape],
          lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz, padding=1, stride=3, ceil_mode=True),
          lambda x: Tensor.avg_pool2d(x, kernel_size=ksz, padding=1, stride=3, ceil_mode=True), rtol=1e-5)

  def test_avg_pool2d_padding(self):
    shape = (32,2,111,28)
    for ksz in [(2,2), (3,3), 2, 3, (3,2)]:
      for p in [1, (1,0), (0,1)]:
        with self.subTest(kernel_size=ksz, padding=p):
          helper_test_op([shape],
            lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz, padding=p),
            lambda x: Tensor.avg_pool2d(x, kernel_size=ksz, padding=p), rtol=1e-5)
    with self.assertRaises(ValueError):
      Tensor.avg_pool2d(Tensor.randn((32,2,111,28)), kernel_size=(2,2), padding=(1,1,1))

  #no candidates left
  def test_pool_sum(self):
    shape = [(1,1,16,16,16)]
    # FIX: original passed True as `vals` (4th positional arg); forward_only is the 5th
    x, x2 = prepare_test_op(-2, 2, shape, None, True)
    x2 = x2[0]
    padding = [1,1,1,1,1,1]
    axis = (-3, -2, -1)
    kernel_size = (8,8,8)
    stride = 5
    dilation = 1
    x2.ones_like().pad(padding)._pool(kernel_size, stride, dilation).sum(axis).realize()
    #y = Tensor.avg_pool2d(x2, kernel_size=(8,8,8), stride=5, padding=1, count_include_pad=False)
    #y.realize()

  def test_avg_pool3d_failure(self):
    # NOTE(review): intentionally drives tinygrad's avg_pool2d with a 3d kernel
    # against torch's avg_pool3d — presumably avg_pool2d is dimension-generic; confirm.
    with Context(NOOPT=0):
      helper_test_op([(1,1,16,16,16)],
        lambda x: torch.nn.functional.avg_pool3d(x, kernel_size=(8,8,8), stride=5, padding=1, count_include_pad=False),
        lambda x: Tensor.avg_pool2d(x, kernel_size=(8,8,8), stride=5, padding=1, count_include_pad=False), rtol=1e-5, forward_only=True)

  def test_sigmoid(self):
    helper_test_op([(45,65)], torch.sigmoid, Tensor.sigmoid)
  def test_sigmoid_extreme(self):
    helper_test_op([(45,65)], torch.sigmoid, Tensor.sigmoid, low=300, high=400)
    helper_test_op([(45,65)], torch.sigmoid, Tensor.sigmoid, low=-400, high=-300)
    x = Tensor([300.0])
    self.assertAlmostEqual(x.sigmoid()[0].gradient(x)[0].item(), 0.0)
    x = Tensor([-300.0])
    self.assertAlmostEqual(x.sigmoid()[0].gradient(x)[0].item(), 0.0)

  def test_sigmoid_alt_extreme(self):
    def sigmoid(x:Tensor): return x.exp() / (1 + x.exp())
    x = Tensor([300.0])
    self.assertAlmostEqual(sigmoid(x)[0].gradient(x)[0].item(), 0.0)
    x = Tensor([-300.0])
    self.assertAlmostEqual(sigmoid(x)[0].gradient(x)[0].item(), 0.0)

  def test_logsigmoid(self):
    helper_test_op([(45,65)], torch.nn.functional.logsigmoid, Tensor.logsigmoid)
    helper_test_op([()], torch.nn.functional.logsigmoid, Tensor.logsigmoid)

  def test_hardsigmoid(self):
    helper_test_op([(45,65)], torch.nn.functional.hardsigmoid, Tensor.hardsigmoid)
    helper_test_op([()], torch.nn.functional.hardsigmoid, Tensor.hardsigmoid)
  def test_hardsigmoid_extreme(self):
    # FIX: original tested plain sigmoid here (copy-paste); test hardsigmoid as named
    helper_test_op([(45,65)], torch.nn.functional.hardsigmoid, Tensor.hardsigmoid, low=300, high=400)
    helper_test_op([(45,65)], torch.nn.functional.hardsigmoid, Tensor.hardsigmoid, low=-400, high=-300)
  def test_softplus(self):
    helper_test_op([(45,65)], torch.nn.functional.softplus, Tensor.softplus, grad_atol=1e-6)
    helper_test_op([(45,65)], lambda t: torch.nn.functional.softplus(t, beta=3), lambda t: Tensor.softplus(t, beta=3), grad_atol=1e-6)
    helper_test_op([(45,65)], lambda t: torch.nn.functional.softplus(t, beta=1/3), lambda t: Tensor.softplus(t, beta=1/3), grad_atol=1e-6)
    # # TODO: support threshold and enable this
    # helper_test_op([(45,65)], torch.nn.functional.softplus, Tensor.softplus, grad_atol=1e-6, low=300, high=400)
    helper_test_op([(45,65)], torch.nn.functional.softplus, Tensor.softplus, grad_atol=1e-6, low=-400, high=-300)
    helper_test_op([()], torch.nn.functional.softplus, Tensor.softplus, grad_atol=1e-6)

  def test_cross_entropy_1(self):
    r = "none"
    #shape = [(32, 10), (32, 10)]
    shape = [(5,4), (5,4)]
    # FIX: original passed True as `vals`; forward_only is the 5th positional arg
    x1, x2 = prepare_test_op(-2, 2, shape, None, True)
    x, y = x2
    with Context(NOOPT=0):
      x.sigmoid().binary_crossentropy(y.clip(0,1), reduction=r).realize()
    helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy(x.sigmoid(),y.clip(0,1)),
                   lambda x,y: x.sigmoid().binary_crossentropy(y.clip(0,1)))

  def test_binary_crossentropy(self):
    helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy(x.sigmoid(),y.clip(0,1)),
                   lambda x,y: x.sigmoid().binary_crossentropy(y.clip(0,1)))
    helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy_with_logits(x,y.clip(0,1)),
                   lambda x,y: x.binary_crossentropy_logits(y.clip(0,1)))
    helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy_with_logits(x,y.clip(0,1)),
                   lambda x,y: x.sigmoid().binary_crossentropy(y.clip(0,1)))
    helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy(x.sigmoid(),y.clip(0,1)),
                   lambda x,y: x.binary_crossentropy_logits(y.clip(0,1)))
  def test_binary_crossentropy_reductions(self):
    for r in ("mean", "sum", "none"):
      helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy(x.sigmoid(), y.clip(0,1), reduction=r),
                     lambda x,y: x.sigmoid().binary_crossentropy(y.clip(0,1), reduction=r))
      helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy_with_logits(x, y.clip(0,1), reduction=r),
                     lambda x,y: x.binary_crossentropy_logits(y.clip(0,1), reduction=r))
  def test_binary_crossentropy_logits_pos_weights(self):
    pos_weight = [0.25, 0.5, 0.75, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
    helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy_with_logits(x,y.clip(0,1),
                                                                                                        pos_weight=torch.tensor(pos_weight)),
                   lambda x,y: x.binary_crossentropy_logits(y.clip(0,1),pos_weight=Tensor(pos_weight)))
  def test_cross_entropy_class_probabilities(self):
    helper_test_op([(32,), (32,)], lambda x,y: torch.nn.functional.cross_entropy(x, y), lambda x,y: x.cross_entropy(y))
    helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.cross_entropy(x, y), lambda x,y: x.cross_entropy(y))
    helper_test_op([(32,4,4,4), (32,4,4,4)], lambda x,y: torch.nn.functional.cross_entropy(x, y), lambda x,y: x.cross_entropy(y))

  def test_cross_entropy_2(self):
    r = "none"
    shape = [(32, 10), (32, 10)]
    shape = [(5, 4), (5, 4)]
    # FIX: original passed True as `vals`; forward_only is the 5th positional arg
    x1, x2 = prepare_test_op(-2, 2, shape, None, True)
    x, y = x2
    with Context(NOOPT=0):
      x.sigmoid().binary_crossentropy(y.clip(0,1)).realize()

  def helper_test_exception(self, shps, torch_fxn, tinygrad_fxn, expected, forward_only=False, exact=False, vals=None, low=-1.5, high=1.5):
    """Assert that both torch_fxn and tinygrad_fxn raise `expected` on the same inputs."""
    if getenv("MOCKGPU") and Device.DEFAULT == "NV": self.skipTest('helper_test_exception fails in CI CUDA')
    ts, tst = prepare_test_op(low, high, shps, vals, forward_only)
    with self.assertRaises(expected) as torch_cm:
      torch_fxn(*ts)
    with self.assertRaises(expected) as tinygrad_cm:
      tinygrad_fxn(*tst)
    if exact: self.assertEqual(str(torch_cm.exception), str(tinygrad_cm.exception))
    if not CI: print("\ntesting %40r torch/tinygrad exception: %s / %s" % (shps, torch_cm.exception, tinygrad_cm.exception), end="")

  # FIX: test_pad_reflect_mode was defined twice; the first copy (a strict subset,
  # missing the 5d case) was shadowed and never ran. Kept the superset definition.
  def test_pad_reflect_mode(self):
    helper_test_op([(1,1,5,5)], lambda x: torch.nn.functional.pad(x, (0,2,3,2), mode="reflect"), lambda x: x.pad((0,2,3,2), mode="reflect"))
    helper_test_op([(5,5,5)], lambda x: torch.nn.functional.pad(x, (0,2), mode="reflect"), lambda x: x.pad((0,2), mode="reflect"))
    helper_test_op([(1,1,5,5,5)], lambda x: torch.nn.functional.pad(x, (1,2,3,4,1,2), mode="reflect"),
                   lambda x: x.pad((1,2,3,4,1,2), mode="reflect"))
    helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (-1,2,2,-1), mode="reflect"), lambda x: x.pad((-1,2,2,-1), mode="reflect"))
    helper_test_op([(1,1,5,5)], lambda x: torch.nn.functional.pad(x, (3,-3,0,-3), mode="reflect"), lambda x: x.pad((3,-3,0,-3), mode="reflect"))
    helper_test_op([(1,1,5,5)], lambda x: torch.nn.functional.pad(x, (3,-5,1,-5), mode="reflect"), lambda x: x.pad((3,-5,1,-5), mode="reflect"))
    helper_test_op([(1,1,5,5)], lambda x: torch.nn.functional.pad(x, (0,0,0,-5), mode="reflect"), lambda x: x.pad((0,0,0,-5), mode="reflect"))

    # max pad size for reflect is exactly once: pad < input size
    helper_test_op([(1,1,5,5)], lambda x: torch.nn.functional.pad(x, (4,4,0,4), mode="reflect"), lambda x:x.pad((4,4,0,4),mode="reflect"))
    # raise error for reflection padding when: pad >= input size
    self.helper_test_exception([(1,1,5,5)],
      lambda x: torch.nn.functional.pad(x, (3,5,0,0),mode="reflect"), lambda x: x.pad((3,5,0,0),mode="reflect"),
      expected=(RuntimeError, ValueError))

  # FIX: renamed from test_manual — a second test_manual later in the class shadowed
  # this one, so it could never run even with MANUAL set.
  @skipU("MANUAL")
  def test_manual_pad_reflect_5d(self):
    #shape = (1,1,5,5,5)
    #np_data = np.random.uniform(low=-2, high=2, size=shape).astype(_to_np_dtype(dtypes.default_float))
    #x = Tensor(np_data).reshape((1,1,5,5,5)).pad((1,2,3,4,1,2), mode="reflect")
    #print(x.numpy())
    #return
    helper_test_op([(1,1,5,5,5)], lambda x: torch.nn.functional.pad(x, (1,2,3,4,1,2), mode="reflect"),
                   lambda x: x.pad((1,2,3,4,1,2), mode="reflect"))

  def test_all_axis(self):
    helper_test_op([(3,4,5,6)], lambda x: x.all(axis=(1,2)), forward_only=True)

  def test_broadcast_full(self):
    for torch_op, tinygrad_op in [(torch.add, Tensor.add), (torch.sub, Tensor.sub), (torch.mul, Tensor.mul),
                                  (torch.div, Tensor.div), (torch.pow, Tensor.pow)]:
      for shapes in [((5,13,24,16), (5,1,24,1)), ((1,3,1,7,1), (2,1,5,1,8))]:
        print(f"{tinygrad_op=} {shapes=}")
        with self.subTest(op=torch_op.__name__, shapes=shapes):
          if tinygrad_op != Tensor.pow:
            helper_test_op(shapes, torch_op, tinygrad_op)
          else:
            # pow needs non-negative bases for well-defined gradients
            helper_test_op(shapes, torch_op, tinygrad_op, low=0, high=3)

  def test_broadcast_pow(self):
    tinygrad_op = Tensor.pow
    torch_op = torch.pow
    shapes = ((5, 13, 24, 16), (5, 1, 24, 1))
    s = 20
    shapes = ((5, 13, s, 16), (5, 1, s, 1))
    helper_test_op(shapes, torch_op, tinygrad_op, low=0, high=3)

  def test_cast(self):
    helper_test_op([(3, 3)], lambda x: x.float())
    helper_test_op(None, lambda x: x.float(), vals=[[0, 1, 2, 3]], forward_only=True)
    helper_test_op(None, lambda x: x.float(), vals=[[True, False]], forward_only=True)
    helper_test_op([(3, 3)], lambda x: x.int(), forward_only=True)
    helper_test_op([(3, 3)], lambda x: x.bool(), forward_only=True)

  # NOTE(review): a second, byte-identical test_all definition sat here in the
  # original and shadowed the first; the duplicate was removed.

  def test_cmp_lt_backwards(self):
    Tensor.manual_seed(0)
    tt = Tensor.randn(4, requires_grad=True)
    (tt*(tt < 0)).sum().backward()
    print(f"tinygrad: {tt.grad.numpy()=}")
    t = torch.tensor(tt.numpy(), requires_grad=True)
    (t*(t < 0)).sum().backward()
    print(f"torch: {t.grad.cpu().numpy()=}")
    np.testing.assert_allclose(tt.grad.numpy(), t.grad.cpu().numpy(), rtol=1e-5)
  def test_cmp_ne_backwards(self):
    # new grad zeroes these out
    """
    t1 = torch.ones(4, requires_grad=True)
    t2 = torch.ones(4, requires_grad=True)
    self.assertRaises(RuntimeError, (t1 != t2).sum().backward)
    tt1 = Tensor.ones(4, requires_grad=True)
    tt2 = Tensor.ones(4, requires_grad=True)
    self.assertRaises(RuntimeError, (tt1 != tt2).sum().backward)
    """
    Tensor.manual_seed(0)
    tt = Tensor.randn(4, requires_grad=True)
    (tt*(tt != 0)).sum().backward()
    t = torch.tensor(tt.numpy(), requires_grad=True)
    (t*(t != 0)).sum().backward()
    np.testing.assert_allclose(tt.grad.numpy(), t.grad.cpu().numpy(), rtol=1e-5)

  def test_logical_not(self):
    helper_test_op(None, torch.logical_not, Tensor.logical_not, vals=[[True, False, True]], forward_only=True)
    helper_test_op(None, torch.logical_not, Tensor.logical_not,
                   vals=[[1.,2.,0.,0.5]], forward_only=True)

  def test_min(self):
    helper_test_op(None,
      lambda x: x.type(torch.uint8).min(),
      lambda x: x.cast(dtypes.uint8).min(), forward_only=True, vals=[[[0, 1, 2], [3, 4, 5]]])
    helper_test_op(None,
      lambda x: x.type(torch.uint8).min(),
      lambda x: x.cast(dtypes.uint8).min(), forward_only=True, vals=[[0, 128, 255, 64, 32, 16]])

  def test_scatter_reduce(self):
    b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
    a = Tensor(b.detach().cpu().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
    for reduce in ("sum", "prod", "mean", "amin", "amax"):
      for dim in (-1,1,-3):
        helper_test_op([(4,5,6), (4,5,6)],
          lambda x,src: x.scatter_reduce(dim=dim, index=b, src=src, reduce=reduce),
          lambda x,src: x.scatter_reduce(dim=dim, index=a, src=src, reduce=reduce), forward_only=True)
        helper_test_op([(4,5,6), (4,5,6)],
          lambda x,src: x.scatter_reduce(dim=dim, index=b, src=src, reduce=reduce, include_self=False),
          lambda x,src: x.scatter_reduce(dim=dim, index=a, src=src, reduce=reduce, include_self=False), forward_only=True)

  @skipU("MANUAL")
  def test_scatter_reduce2(self):
    torch.manual_seed(0)
    b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
    a = Tensor(b.detach().cpu().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
    reduce = "mean"
    dim = -1
    #helper_test_op([(4,5,6), (4,5,6)],
    #  lambda x,src: x.scatter_reduce(dim=dim, index=b, src=src, reduce=reduce),
    #  lambda x,src: x.scatter_reduce(dim=dim, index=a, src=src, reduce=reduce), forward_only=True)
    helper_test_op([(4,5,6), (4,5,6)],
      lambda x,src: x.scatter_reduce(dim=dim, index=b, src=src, reduce=reduce, include_self=False),
      lambda x,src: x.scatter_reduce(dim=dim, index=a, src=src, reduce=reduce, include_self=False), forward_only=True)

  def test_scatter_reduce_small(self):
    torch.manual_seed(0)
    b = torch.tensor([[0,1,0], [1,0,1]], dtype=torch.int64, requires_grad=False)
    a = Tensor(b.detach().cpu().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
    reduce = "mean"
    dim = -1
    shape = (2,3)
    helper_test_op([shape, shape],
      lambda x,src: x.scatter_reduce(dim=dim, index=b, src=src, reduce=reduce, include_self=False),
      lambda x,src: x.scatter_reduce(dim=dim, index=a, src=src, reduce=reduce, include_self=False), forward_only=True)

  def test_softmax(self):
    helper_test_op([(45,65)], torch.nn.Softmax(dim=1), Tensor.softmax, atol=1e-7, grad_atol=1e-7)
    helper_test_op([(9,)], torch.nn.Softmax(dim=0), Tensor.softmax, atol=1e-7, grad_atol=1e-7)
    helper_test_op([()], torch.nn.Softmax(dim=0), Tensor.softmax, atol=1e-7, grad_atol=1e-7)
    helper_test_op([()], torch.nn.Softmax(dim=-1), Tensor.softmax, atol=1e-7, grad_atol=1e-7)

  def test_scaled_dot_product_attention_causal(self):
    helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64)],
      lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z,is_causal=True),
      lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z,is_causal=True))

  def test_params_6(self):
    t1 = Tensor(np.array([10], dtype=np.int32), dtype=dtypes.int32)
    t2 = Tensor(np.array([10], dtype=np.int32), dtype=dtypes.int32)
    t3 = Tensor(np.array([10], dtype=np.int32), dtype=dtypes.int32)
    t4 = Tensor(np.array([10], dtype=np.int32), dtype=dtypes.int32)
    t5 = Tensor(np.array([10], dtype=np.int32), dtype=dtypes.int32)
    t6 = Tensor(np.array([20], dtype=np.int32), dtype=dtypes.int32)
    t7 = Tensor(np.array([30], dtype=np.int32), dtype=dtypes.int32)
    t = (t1 + t2 + t3 + t4 + t5 + t6 + t7).numpy()
    # explicit elementwise check instead of `assert t == [100]` on an ndarray
    np.testing.assert_equal(t, [100])
    print(t)

  def test_interpolate_bilinear2(self):
    out_sz = (2, 1)
    helper_test_op([(1,1,1,4)],
      lambda x: torch.nn.functional.interpolate((10*x).relu().type(torch.uint8), size=out_sz, mode="bilinear"),
      lambda x: Tensor.interpolate((10*x).relu().cast('uint8'), size=out_sz, mode="linear"), forward_only=True)

  def test_interpolate_bilinear(self):
    out_sz = (10, 10)
    helper_test_op([(2,3,64,64)],
      lambda x: torch.nn.functional.interpolate((10*x).relu().type(torch.uint8), size=out_sz, mode="bilinear"),
      lambda x: Tensor.interpolate((10*x).relu().cast('uint8'), size=out_sz, mode="linear"), forward_only=True)

  @skipU("MANUAL")
  def test_gemm_fp16(self):
    helper_test_op([(64,64), (64,64)], lambda x,y: x.half().matmul(y.half()), atol=5e-3, rtol=5e-3)
  @skipU("MANUAL")
  def test_manual(self):
    helper_test_op([(64,64), (64,64)], lambda x,y: x.half().matmul(y.half()), atol=5e-3, rtol=5e-3)

def speedrun(name: str, c: Tensor, repeat: int) -> np.ndarray:
  """Realize `c` once for the result, then time `repeat` re-realizations."""
  res = c.clone().numpy()
  t0 = time.time()
  for _ in range(repeat):
    c.clone().realize()
  t1 = time.time()
  print(f"Took {name} {(t1-t0)}s")
  return res

@unittest.skipUnless(os.environ.get("MANUAL"), "")
class TestMatmul(unittest.TestCase):
  def _setup_data(self, shapeA, shapeB):
    with Context(DEBUG=0):
      np.random.seed(0)
      a = np.random.rand(*shapeA).astype(np.float32)
      b = np.random.rand(*shapeB).astype(np.float32)
      return Tensor(a, device="cpu"), Tensor(b, device="cpu")

  def _clang(self, a: Tensor, b: Tensor):
    with Context(DEBUG=4):
      c_cpu = a.to("cpu").dot(b.to("cpu")).numpy()
    return c_cpu
  def _asm(self, a: Tensor, b: Tensor):
    with Context(DEBUG=6):
      c_asm = a.to("asm").dot(b.to("asm")).numpy()
    return c_asm
  def test_(self):
    with Context():
      K = 2
      a, b = self._setup_data((3, K), (K, 4))
      if os.environ.get("MANUAL_CLANG"):
        clang_res = self._clang(a, b)
      if os.environ.get("MANUAL_ASM"):
        asm_res = self._asm(a, b)
      #np.testing.assert_equal(clang_res, asm_res)

if __name__ == "__main__":
  unittest.main()

# NOTE(review): the original paste continued with a second, truncated patch adding
# test/test_ops_sass.py (cut off mid-import); it is not reconstructible from this view.
# ---- test/test_ops_sass.py (new file) ----
import time, math, unittest, functools, os, torch
import numpy as np
from typing import List, Callable
import warnings
# FIX: the original had two overlapping `tinygrad.helpers` import lines that
# redefined most names (Context was even listed twice on one line); merged into a
# single import with each name listed exactly once. No name was removed.
from tinygrad.helpers import (DISABLE_COMPILER_CACHE, getenv, IMAGE, DEBUG, CI, Context,
                              TRANSCENDENTAL, DEVECTORIZE, OSX, AMD_LLVM)
from tinygrad import Tensor, Device, dtypes
from tinygrad.tensor import _to_np_dtype
from tinygrad.device import is_dtype_supported

def skipU(flag: str):
  """Decorator factory: run the decorated test only when env var `flag` is set,
  otherwise skip it."""
  if os.environ.get(flag):
    return lambda func: func
  return unittest.skip("")

if getenv("TINY_BACKEND"):
  import tinygrad.frontend.torch  # noqa: F401 # pylint: disable=unused-import
  torch.set_default_device("tiny")

FORWARD_ONLY = getenv("FORWARD_ONLY", 1)
PRINT_TENSORS = getenv("PRINT_TENSORS", 0)

def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, grad_atol=1e-4, grad_rtol=1e-3,
                   forward_only=False, vals=None, low=-2, high=2):
  """Run `torch_fxn` and `tinygrad_fxn` on identical inputs and compare forward
  (and, unless forward_only/FORWARD_ONLY, backward) outputs within tolerances.

  Inputs come from `prepare_test_op`: explicit `vals` or seeded uniform random
  data of the given shapes. Timing of both frameworks is printed when not in CI.
  """
  if tinygrad_fxn is None: tinygrad_fxn = torch_fxn
  ts, tst = prepare_test_op(low, high, shps, vals, forward_only)

  st = time.monotonic()
  out = torch_fxn(*ts)
  torch_fp = time.monotonic() - st

  # move inputs to a different device, test the device of intermediate tensors are correct
  #if mt:=getenv("MOVE_TENSOR", ""): for t in tst: t.to_(mt)

  st = time.monotonic()
  ret = tinygrad_fxn(*tst).realize()
  tinygrad_fp = time.monotonic() - st

  def compare(s, tinygrad_output, torch_output, atol, rtol):
    # Shape and dtype must match exactly; values are allclose for floats,
    # exactly equal for integer/bool outputs.
    if PRINT_TENSORS: print(s, tinygrad_output, torch_output)
    try:
      assert tinygrad_output.shape == torch_output.shape, f"shape mismatch: tinygrad={tinygrad_output.shape} | torch={torch_output.shape}"
      assert tinygrad_output.dtype == torch_output.dtype, f"dtype mismatch: tinygrad={tinygrad_output.dtype} | torch={torch_output.dtype}"
      if np.issubdtype(tinygrad_output.dtype, np.floating):
        np.testing.assert_allclose(tinygrad_output, torch_output, atol=atol, rtol=rtol)
      else:
        np.testing.assert_equal(tinygrad_output, torch_output)
    except Exception as e:
      raise Exception(f"{s} failed shape {tinygrad_output.shape}: {e}")

  if DEBUG >= 6:
    np.set_printoptions(linewidth=200, suppress=True)
    print(ret.numpy())
    print(out.detach().cpu().numpy())
  compare("forward pass", ret.numpy(), out.detach().cpu().numpy(), atol=atol, rtol=rtol)

  torch_fbp, tinygrad_fbp = np.nan, np.nan
  if not forward_only and not FORWARD_ONLY and ts and tst:
    st = time.monotonic()
    torch_grads = torch.autograd.grad(torch_fxn(*ts).sum(), ts)
    torch_fbp = time.monotonic() - st

    st = time.monotonic()
    # NOTE: we now have to recompute the forward pass since we realized it
    tiny_grads = tinygrad_fxn(*tst).sum().gradient(*tst)
    Tensor.realize(*tiny_grads)
    tinygrad_fbp = time.monotonic() - st

    for i, (t, torch_grad) in enumerate(zip(tiny_grads, torch_grads)):
      compare(f"backward pass tensor {i}", t.numpy(), torch_grad.detach().cpu().numpy(), atol=grad_atol, rtol=grad_rtol)

  if not CI:
    print("\ntesting %40r torch/tinygrad fp: %.2f / %.2f ms bp: %.2f / %.2f ms " % \
      (shps, torch_fp*1000, tinygrad_fp*1000, torch_fbp*1000, tinygrad_fbp*1000), end="")

def prepare_test_op(low, high, shps, vals, forward_only=False):
  """Build matching torch and tinygrad inputs: from explicit `vals` when shps is
  None, else seeded uniform random arrays of the given shapes.

  int64 torch tensors are downcast to int32 (tinygrad comparisons use int32);
  INPUT_BYTES=1 dumps the raw bytes of the first array for manual debugging.
  """
  if shps is None:
    ts = [torch.tensor(x, requires_grad=(not forward_only)) for x in vals]
  else:
    np.random.seed(0)
    np_data = [np.random.uniform(low=low, high=high, size=size).astype(_to_np_dtype(dtypes.default_float)) for size in shps]
    if os.environ.get("INPUT_BYTES"):
      print(f"{np_data=}")
      b = np_data[0].tobytes()
      for _b in b: print(f"{_b:#x}", end=", ")
      print()
    ts = [torch.tensor(data, requires_grad=(not forward_only)) for data in np_data]
  for i in range(len(ts)):
    # NOTE: torch default int64 for python ints input
    if ts[i].dtype == torch.int64: ts[i] = ts[i].type(torch.int32)
  tst = [Tensor(x.detach().cpu().numpy(), requires_grad=(not forward_only and not FORWARD_ONLY)) for x in ts]
  return ts, tst

class TestOps(unittest.TestCase):
  def test_add(self):
    """Elementwise add (and mul) of two realized fp32 tensors matches numpy."""
    with Context(DEBUG=0):
      _a = np.array([1.2, 1.3, 1.4]).astype(np.float32)
      _b = np.array([1.2, 1.3, 1.4]).astype(np.float32)
      a = Tensor(_a).realize()
      b = Tensor(_b).realize()
      # FIX: despite its name, the original only checked multiplication; keep
      # that check and add the addition check the name promises.
      np.testing.assert_equal((a + b).numpy(), _a + _b)
      np.testing.assert_equal((a * b).numpy(), _a * _b)

# ---- tinygrad/helpers.py (hunk from the original paste) ----
# The hunk changed only one line:
#   -DISABLE_COMPILER_CACHE = ContextVar("DISABLE_COMPILER_CACHE", 0)
#   +DISABLE_COMPILER_CACHE = ContextVar("DISABLE_COMPILER_CACHE", 1)
# NOTE(review): flipping this default disables the compiler cache for every user
# of the library, not just these tests; keep the default at 0 and opt in via the
# DISABLE_COMPILER_CACHE env var in the SASS test harness instead.

# ---- tinygrad/renderer/sass.py (new file; the remainder of this class, including
# ---- render()'s large hard-coded SASS/ELF text, continues in the next chunk and
# ---- is deliberately left untouched here) ----
from typing import cast, Callable
import struct
from collections import defaultdict
from tinygrad.codegen.opt import tc
from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp
from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import CUDARenderer
from tinygrad.dtype import dtypes, DType, PtrDType, AddrSpace
from tinygrad.helpers import flatten, get_single_element, prod


class SASSRenderer(Renderer):
  device = "CUDA"
  suffix = "SASS"  # value continues from the next chunk of the paste
"SASS" + global_max, local_max, shared_max = CUDARenderer.global_max, CUDARenderer.local_max, CUDARenderer.shared_max + tc_sm80 = [x for x in tc.cuda_sm80 if x.dtype_in in [dtypes.half, dtypes.float]] + def __init__(self, arch:str, device="CUDA"): + self.device, self.arch = device, arch + def __reduce__(self): return self.__class__, (self.arch, self.device) + + def render(self, uops:list[UOp]) -> str: + return """ +// --------------------- FileHeader -------------------------- + // All file header info is kept as is (unless offset/size attributes) + // The original header flags is not complete, thus discarded. + // .headerflags @"EF_CUDA_TEXMODE_UNIFIED EF_CUDA_64BIT_ADDRESS EF_CUDA_SM86 EF_CUDA_VIRTUAL_SM(EF_CUDA_SM86)" + // .elftype @"ET_EXEC" + // + // + .__elf_ident_osabi 51 + .__elf_ident_abiversion 7 + .__elf_type ET_EXEC + .__elf_machine EM_CUDA + .__elf_version 129 // CUDA toolkit version + .__elf_entry 0 // entry point address + .__elf_phoff 0xa00 // program header offset, maybe updated by assembler + .__elf_shoff 0x700 // section header offset, maybe updated by assembler + .__elf_flags 0x560556 // Flags, SM_86(0x56), COMPUTE_86(0x56) + .__elf_ehsize 64 // elf header size + .__elf_phentsize 56 // program entry size + .__elf_phnum 3 // number of program entries + .__elf_shentsize 64 // section entry size + .__elf_shnum 12 // number of sections, currently no sections can be appended/removed + .__elf_shstrndx 1 // Section name string table index + + + //------------------------------------------------- + //------------ END of FileHeader ------------------ + //------------------------------------------------- + + + +// --------------------- -------------------------- + // there will always be an empty section at index 0 + .section "", 0, SHT_NULL + .__section_name 0x0 // offset in .shstrtab + .__section_type SHT_NULL + .__section_flags 0x0 + .__section_addr 0x0 + .__section_offset 0x0 // maybe updated by assembler + .__section_size 0x0 // maybe updated by 
assembler + .__section_link 0 + .__section_info 0x0 + .__section_entsize 0 + .align 0 // equivalent to set sh_addralign + +// --------------------- .shstrtab -------------------------- + .section ".shstrtab", 0, SHT_STRTAB + // all strings in .shstrtab section will be kept as is. + .__section_name 0x1 // offset in .shstrtab + .__section_type SHT_STRTAB + .__section_flags 0x0 + .__section_addr 0x0 + .__section_offset 0x40 // maybe updated by assembler + .__section_size 0xdb // maybe updated by assembler + .__section_link 0 + .__section_info 0x0 + .__section_entsize 0 + .align 1 // equivalent to set sh_addralign + // .shstrtab[0] = b'\x00' + /*0000*/ .byte 0x00 + + // .shstrtab[1] = b'.shstrtab\x00' + /*0001*/ .byte 0x2e, 0x73, 0x68, 0x73, 0x74, 0x72, 0x74, 0x61 + /*0009*/ .byte 0x62, 0x00 + + // .shstrtab[2] = b'.strtab\x00' + /*000b*/ .byte 0x2e, 0x73, 0x74, 0x72, 0x74, 0x61, 0x62, 0x00 + + // .shstrtab[3] = b'.symtab\x00' + /*0013*/ .byte 0x2e, 0x73, 0x79, 0x6d, 0x74, 0x61, 0x62, 0x00 + + // .shstrtab[4] = b'.symtab_shndx\x00' + /*001b*/ .byte 0x2e, 0x73, 0x79, 0x6d, 0x74, 0x61, 0x62, 0x5f + /*0023*/ .byte 0x73, 0x68, 0x6e, 0x64, 0x78, 0x00 + + // .shstrtab[5] = b'.nv.info\x00' + /*0029*/ .byte 0x2e, 0x6e, 0x76, 0x2e, 0x69, 0x6e, 0x66, 0x6f + /*0031*/ .byte 0x00 + + // .shstrtab[6] = b'.text.E_3\x00' + /*0032*/ .byte 0x2e, 0x74, 0x65, 0x78, 0x74, 0x2e, 0x45, 0x5f + /*003a*/ .byte 0x33, 0x00 + + // .shstrtab[7] = b'.nv.info.E_3\x00' + /*003c*/ .byte 0x2e, 0x6e, 0x76, 0x2e, 0x69, 0x6e, 0x66, 0x6f + /*0044*/ .byte 0x2e, 0x45, 0x5f, 0x33, 0x00 + + // .shstrtab[8] = b'.nv.shared.E_3\x00' + /*0049*/ .byte 0x2e, 0x6e, 0x76, 0x2e, 0x73, 0x68, 0x61, 0x72 + /*0051*/ .byte 0x65, 0x64, 0x2e, 0x45, 0x5f, 0x33, 0x00 + + // .shstrtab[9] = b'.nv.constant0.E_3\x00' + /*0058*/ .byte 0x2e, 0x6e, 0x76, 0x2e, 0x63, 0x6f, 0x6e, 0x73 + /*0060*/ .byte 0x74, 0x61, 0x6e, 0x74, 0x30, 0x2e, 0x45, 0x5f + /*0068*/ .byte 0x33, 0x00 + + // .shstrtab[10] = b'.rel.nv.constant0.E_3\x00' + /*006a*/ 
.byte 0x2e, 0x72, 0x65, 0x6c, 0x2e, 0x6e, 0x76, 0x2e + /*0072*/ .byte 0x63, 0x6f, 0x6e, 0x73, 0x74, 0x61, 0x6e, 0x74 + /*007a*/ .byte 0x30, 0x2e, 0x45, 0x5f, 0x33, 0x00 + + // .shstrtab[11] = b'.debug_frame\x00' + /*0080*/ .byte 0x2e, 0x64, 0x65, 0x62, 0x75, 0x67, 0x5f, 0x66 + /*0088*/ .byte 0x72, 0x61, 0x6d, 0x65, 0x00 + + // .shstrtab[12] = b'.rel.debug_frame\x00' + /*008d*/ .byte 0x2e, 0x72, 0x65, 0x6c, 0x2e, 0x64, 0x65, 0x62 + /*0095*/ .byte 0x75, 0x67, 0x5f, 0x66, 0x72, 0x61, 0x6d, 0x65 + /*009d*/ .byte 0x00 + + // .shstrtab[13] = b'.rela.debug_frame\x00' + /*009e*/ .byte 0x2e, 0x72, 0x65, 0x6c, 0x61, 0x2e, 0x64, 0x65 + /*00a6*/ .byte 0x62, 0x75, 0x67, 0x5f, 0x66, 0x72, 0x61, 0x6d + /*00ae*/ .byte 0x65, 0x00 + + // .shstrtab[14] = b'.nv.callgraph\x00' + /*00b0*/ .byte 0x2e, 0x6e, 0x76, 0x2e, 0x63, 0x61, 0x6c, 0x6c + /*00b8*/ .byte 0x67, 0x72, 0x61, 0x70, 0x68, 0x00 + + // .shstrtab[15] = b'.nv.prototype\x00' + /*00be*/ .byte 0x2e, 0x6e, 0x76, 0x2e, 0x70, 0x72, 0x6f, 0x74 + /*00c6*/ .byte 0x6f, 0x74, 0x79, 0x70, 0x65, 0x00 + + // .shstrtab[16] = b'.nv.rel.action\x00' + /*00cc*/ .byte 0x2e, 0x6e, 0x76, 0x2e, 0x72, 0x65, 0x6c, 0x2e + /*00d4*/ .byte 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x00 + + +// --------------------- .strtab -------------------------- + .section ".strtab", 0, SHT_STRTAB + // all strings in .strtab section will be kept as is. 
+ .__section_name 0xb // offset in .shstrtab + .__section_type SHT_STRTAB + .__section_flags 0x0 + .__section_addr 0x0 + .__section_offset 0x11b // maybe updated by assembler + .__section_size 0xdf // maybe updated by assembler + .__section_link 0 + .__section_info 0x0 + .__section_entsize 0 + .align 1 // equivalent to set sh_addralign + // .strtab[0] = b'\x00' + /*0000*/ .byte 0x00 + + // .strtab[1] = b'.shstrtab\x00' + /*0001*/ .byte 0x2e, 0x73, 0x68, 0x73, 0x74, 0x72, 0x74, 0x61 + /*0009*/ .byte 0x62, 0x00 + + // .strtab[2] = b'.strtab\x00' + /*000b*/ .byte 0x2e, 0x73, 0x74, 0x72, 0x74, 0x61, 0x62, 0x00 + + // .strtab[3] = b'.symtab\x00' + /*0013*/ .byte 0x2e, 0x73, 0x79, 0x6d, 0x74, 0x61, 0x62, 0x00 + + // .strtab[4] = b'.symtab_shndx\x00' + /*001b*/ .byte 0x2e, 0x73, 0x79, 0x6d, 0x74, 0x61, 0x62, 0x5f + /*0023*/ .byte 0x73, 0x68, 0x6e, 0x64, 0x78, 0x00 + + // .strtab[5] = b'.nv.info\x00' + /*0029*/ .byte 0x2e, 0x6e, 0x76, 0x2e, 0x69, 0x6e, 0x66, 0x6f + /*0031*/ .byte 0x00 + + // .strtab[6] = b'.text.E_3\x00' + /*0032*/ .byte 0x2e, 0x74, 0x65, 0x78, 0x74, 0x2e, 0x45, 0x5f + /*003a*/ .byte 0x33, 0x00 + + // .strtab[7] = b'.nv.info.E_3\x00' + /*003c*/ .byte 0x2e, 0x6e, 0x76, 0x2e, 0x69, 0x6e, 0x66, 0x6f + /*0044*/ .byte 0x2e, 0x45, 0x5f, 0x33, 0x00 + + // .strtab[8] = b'.nv.shared.E_3\x00' + /*0049*/ .byte 0x2e, 0x6e, 0x76, 0x2e, 0x73, 0x68, 0x61, 0x72 + /*0051*/ .byte 0x65, 0x64, 0x2e, 0x45, 0x5f, 0x33, 0x00 + + // .strtab[9] = b'.rel.nv.constant0.E_3\x00' + /*0058*/ .byte 0x2e, 0x72, 0x65, 0x6c, 0x2e, 0x6e, 0x76, 0x2e + /*0060*/ .byte 0x63, 0x6f, 0x6e, 0x73, 0x74, 0x61, 0x6e, 0x74 + /*0068*/ .byte 0x30, 0x2e, 0x45, 0x5f, 0x33, 0x00 + + // .strtab[10] = b'.nv.constant0.E_3\x00' + /*006e*/ .byte 0x2e, 0x6e, 0x76, 0x2e, 0x63, 0x6f, 0x6e, 0x73 + /*0076*/ .byte 0x74, 0x61, 0x6e, 0x74, 0x30, 0x2e, 0x45, 0x5f + /*007e*/ .byte 0x33, 0x00 + + // .strtab[11] = b'.debug_frame\x00' + /*0080*/ .byte 0x2e, 0x64, 0x65, 0x62, 0x75, 0x67, 0x5f, 0x66 + /*0088*/ .byte 0x72, 0x61, 
0x6d, 0x65, 0x00 + + // .strtab[12] = b'.rel.debug_frame\x00' + /*008d*/ .byte 0x2e, 0x72, 0x65, 0x6c, 0x2e, 0x64, 0x65, 0x62 + /*0095*/ .byte 0x75, 0x67, 0x5f, 0x66, 0x72, 0x61, 0x6d, 0x65 + /*009d*/ .byte 0x00 + + // .strtab[13] = b'.rela.debug_frame\x00' + /*009e*/ .byte 0x2e, 0x72, 0x65, 0x6c, 0x61, 0x2e, 0x64, 0x65 + /*00a6*/ .byte 0x62, 0x75, 0x67, 0x5f, 0x66, 0x72, 0x61, 0x6d + /*00ae*/ .byte 0x65, 0x00 + + // .strtab[14] = b'.nv.callgraph\x00' + /*00b0*/ .byte 0x2e, 0x6e, 0x76, 0x2e, 0x63, 0x61, 0x6c, 0x6c + /*00b8*/ .byte 0x67, 0x72, 0x61, 0x70, 0x68, 0x00 + + // .strtab[15] = b'.nv.prototype\x00' + /*00be*/ .byte 0x2e, 0x6e, 0x76, 0x2e, 0x70, 0x72, 0x6f, 0x74 + /*00c6*/ .byte 0x6f, 0x74, 0x79, 0x70, 0x65, 0x00 + + // .strtab[16] = b'.nv.rel.action\x00' + /*00cc*/ .byte 0x2e, 0x6e, 0x76, 0x2e, 0x72, 0x65, 0x6c, 0x2e + /*00d4*/ .byte 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x00 + + // .strtab[17] = b'E_3\x00' + /*00db*/ .byte 0x45, 0x5f, 0x33, 0x00 + + +// --------------------- .symtab -------------------------- + .section ".symtab", 0, SHT_SYMTAB + // all symbols in .symtab sections will be kept + // but the symbol size may be changed accordingly + .__section_name 0x13 // offset in .shstrtab + .__section_type SHT_SYMTAB + .__section_flags 0x0 + .__section_addr 0x0 + .__section_offset 0x200 // maybe updated by assembler + .__section_size 0xa8 // maybe updated by assembler + .__section_link 2 + .__section_info 0x6 + .__section_entsize 24 + .align 8 // equivalent to set sh_addralign + // Symbol[0] "": Container({'st_name': 0, 'st_info': Container({'bind': 'STB_LOCAL', 'type': 'STT_NOTYPE'}), 'st_other': Container({'local': 0, 'visibility': 'STV_DEFAULT'}), 'st_shndx': 'SHN_UNDEF', 'st_value': 0, 'st_size': 0}) + /*0000*/ .byte 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 + /*0008*/ .byte 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 + /*0010*/ .byte 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 + + // Symbol[1] ".text.E_3": Container({'st_name': 50, 
'st_info': Container({'bind': 'STB_LOCAL', 'type': 'STT_SECTION'}), 'st_other': Container({'local': 0, 'visibility': 'STV_DEFAULT'}), 'st_shndx': 11, 'st_value': 0, 'st_size': 0}) + /*0018*/ .byte 0x32, 0x00, 0x00, 0x00, 0x03, 0x00, 0x0b, 0x00 + /*0020*/ .byte 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 + /*0028*/ .byte 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 + + // Symbol[2] ".nv.constant0.E_3": Container({'st_name': 110, 'st_info': Container({'bind': 'STB_LOCAL', 'type': 'STT_SECTION'}), 'st_other': Container({'local': 0, 'visibility': 'STV_DEFAULT'}), 'st_shndx': 10, 'st_value': 0, 'st_size': 0}) + /*0030*/ .byte 0x6e, 0x00, 0x00, 0x00, 0x03, 0x00, 0x0a, 0x00 + /*0038*/ .byte 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 + /*0040*/ .byte 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 + + // Symbol[3] ".debug_frame": Container({'st_name': 128, 'st_info': Container({'bind': 'STB_LOCAL', 'type': 'STT_SECTION'}), 'st_other': Container({'local': 0, 'visibility': 'STV_DEFAULT'}), 'st_shndx': 4, 'st_value': 0, 'st_size': 0}) + /*0048*/ .byte 0x80, 0x00, 0x00, 0x00, 0x03, 0x00, 0x04, 0x00 + /*0050*/ .byte 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 + /*0058*/ .byte 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 + + // Symbol[4] ".nv.callgraph": Container({'st_name': 176, 'st_info': Container({'bind': 'STB_LOCAL', 'type': 'STT_SECTION'}), 'st_other': Container({'local': 0, 'visibility': 'STV_DEFAULT'}), 'st_shndx': 7, 'st_value': 0, 'st_size': 0}) + /*0060*/ .byte 0xb0, 0x00, 0x00, 0x00, 0x03, 0x00, 0x07, 0x00 + /*0068*/ .byte 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 + /*0070*/ .byte 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 + + // Symbol[5] ".nv.rel.action": Container({'st_name': 204, 'st_info': Container({'bind': 'STB_LOCAL', 'type': 'STT_SECTION'}), 'st_other': Container({'local': 0, 'visibility': 'STV_DEFAULT'}), 'st_shndx': 8, 'st_value': 0, 'st_size': 0}) + /*0078*/ .byte 0xcc, 0x00, 0x00, 0x00, 0x03, 0x00, 0x08, 0x00 + /*0080*/ .byte 0x00, 
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 + /*0088*/ .byte 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 + + // Symbol[6] "E_3": Container({'st_name': 219, 'st_info': Container({'bind': 'STB_GLOBAL', 'type': 'STT_FUNC'}), 'st_other': Container({'local': 0, 'visibility': 'STV_DEFAULT'}), 'st_shndx': 11, 'st_value': 0, 'st_size': 384}) + /*0090*/ .byte 0xdb, 0x00, 0x00, 0x00, 0x12, 0x10, 0x0b, 0x00 + /*0098*/ .byte 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 + /*00a0*/ .byte 0x80, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 + + +// --------------------- .debug_frame -------------------------- + .section .debug_frame,"",@progbits + .__section_name 0x80 // offset in .shstrtab + .__section_type SHT_PROGBITS + .__section_flags 0x0 + .__section_addr 0x0 + .__section_offset 0x2a8 // maybe updated by assembler + .__section_size 0x70 // maybe updated by assembler + .__section_link 0 + .__section_info 0x0 + .__section_entsize 0 + .align 1 // equivalent to set sh_addralign + .debug_frame: + /*0000*/ .byte 0xff, 0xff, 0xff, 0xff, 0x24, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff + /*0010*/ .byte 0xff, 0xff, 0xff, 0xff, 0x03, 0x00, 0x04, 0x7c, 0xff, 0xff, 0xff, 0xff, 0x0f, 0x0c, 0x81, 0x80 + /*0020*/ .byte 0x80, 0x28, 0x00, 0x08, 0xff, 0x81, 0x80, 0x28, 0x08, 0x81, 0x80, 0x80, 0x28, 0x00, 0x00, 0x00 + /*0030*/ .byte 0xff, 0xff, 0xff, 0xff, 0x34, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 + /*0040*/ .byte 0x00, 0x00, 0x00, 0x00 + /*0044*/ .dword E_3 + /*004c*/ .byte 0x80, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x00, 0x00, 0x00, 0x04, 0x2c, 0x00 + /*005c*/ .byte 0x00, 0x00, 0x0c, 0x81, 0x80, 0x80, 0x28, 0x00, 0x04, 0xfc, 0xff, 0xff, 0x3f, 0x00, 0x00, 0x00 + /*006c*/ .byte 0x00, 0x00, 0x00, 0x00 + + +// --------------------- .nv.info -------------------------- + .section .nv.info,"",@"SHT_CUDA_INFO" + .__section_name 0x29 // offset in .shstrtab + .__section_type 1879048192 + .__section_flags 0x0 + .__section_addr 0x0 + 
.__section_offset 0x318 // maybe updated by assembler + .__section_size 0x24 // maybe updated by assembler + .__section_link 3 + .__section_info 0x0 + .__section_entsize 0 + .align 4 // equivalent to set sh_addralign + .align 4 + + + //----- nvinfo : EIATTR_REGCOUNT + .align 4 + /*0000*/ .byte 0x04, 0x2f + /*0002*/ .short (.L_1 - .L_0) + .align 4 + .L_0: + /*0004*/ .word index@(E_3) + /*0008*/ .word 0x0000000c + + + //----- nvinfo : EIATTR_FRAME_SIZE + .align 4 + .L_1: + /*000c*/ .byte 0x04, 0x11 + /*000e*/ .short (.L_3 - .L_2) + .align 4 + .L_2: + /*0010*/ .word index@(E_3) + /*0014*/ .word 0x00000000 + + + //----- nvinfo : EIATTR_MIN_STACK_SIZE + .align 4 + .L_3: + /*0018*/ .byte 0x04, 0x12 + /*001a*/ .short (.L_5 - .L_4) + .align 4 + .L_4: + /*001c*/ .word index@(E_3) + /*0020*/ .word 0x00000000 + .L_5: + + +// --------------------- .nv.info.E_3 -------------------------- + .section .nv.info.E_3,"",@"SHT_CUDA_INFO" + .__section_name 0x3c // offset in .shstrtab + .__section_type 1879048192 + .__section_flags 0x40 + .__section_addr 0x0 + .__section_offset 0x33c // maybe updated by assembler + .__section_size 0x68 // maybe updated by assembler + .__section_link 3 + .__section_info 0xb + .__section_entsize 0 + .align 4 // equivalent to set sh_addralign + .sectionflags @"" + .align 4 + + + //----- nvinfo : EIATTR_CUDA_API_VERSION + .align 4 + /*0000*/ .byte 0x04, 0x37 + /*0002*/ .short (.L_7 - .L_6) + .L_6: + /*0004*/ .word 0x00000081 + + + //----- nvinfo : EIATTR_SW2861232_WAR + .align 4 + .L_7: + /*0008*/ .byte 0x01, 0x35 + .zero 2 + + + //----- nvinfo : EIATTR_PARAM_CBANK + .align 4 + /*000c*/ .byte 0x04, 0x0a + /*000e*/ .short (.L_9 - .L_8) + .align 4 + .L_8: + /*0010*/ .word index@(.nv.constant0.E_3) + /*0014*/ .short 0x0160 + /*0016*/ .short 0x0018 + + + //----- nvinfo : EIATTR_CBANK_PARAM_SIZE + .align 4 + .L_9: + /*0018*/ .byte 0x03, 0x19 + /*001a*/ .short 0x0018 + + + //----- nvinfo : EIATTR_KPARAM_INFO + .align 4 + /*001c*/ .byte 0x04, 0x17 + /*001e*/ 
.short (.L_11 - .L_10) + .L_10: + /*0020*/ .word 0x00000000 + /*0024*/ .short 0x0002 + /*0026*/ .short 0x0010 + /*0028*/ .byte 0x00, 0xf0, 0x21, 0x00 + + + //----- nvinfo : EIATTR_KPARAM_INFO + .align 4 + .L_11: + /*002c*/ .byte 0x04, 0x17 + /*002e*/ .short (.L_13 - .L_12) + .L_12: + /*0030*/ .word 0x00000000 + /*0034*/ .short 0x0001 + /*0036*/ .short 0x0008 + /*0038*/ .byte 0x00, 0xf0, 0x21, 0x00 + + + //----- nvinfo : EIATTR_KPARAM_INFO + .align 4 + .L_13: + /*003c*/ .byte 0x04, 0x17 + /*003e*/ .short (.L_15 - .L_14) + .L_14: + /*0040*/ .word 0x00000000 + /*0044*/ .short 0x0000 + /*0046*/ .short 0x0000 + /*0048*/ .byte 0x00, 0xf0, 0x21, 0x00 + + + //----- nvinfo : EIATTR_MAXREG_COUNT + .align 4 + .L_15: + /*004c*/ .byte 0x03, 0x1b + /*004e*/ .short 0x00ff + + + //----- nvinfo : EIATTR_EXIT_INSTR_OFFSETS + .align 4 + /*0050*/ .byte 0x04, 0x1c + /*0052*/ .short (.L_17 - .L_16) + + + // ....[0].... + .L_16: + /*0054*/ .word 0x000000b0 + + + //----- nvinfo : EIATTR_MAX_THREADS + .align 4 + .L_17: + /*0058*/ .byte 0x04, 0x05 + /*005a*/ .short (.L_19 - .L_18) + .L_18: + /*005c*/ .word 0x00000003 + /*0060*/ .word 0x00000001 + /*0064*/ .word 0x00000001 + .L_19: + + +// --------------------- .nv.callgraph -------------------------- + .section .nv.callgraph,"",@"SHT_CUDA_CALLGRAPH" + .__section_name 0xb0 // offset in .shstrtab + .__section_type 1879048193 + .__section_flags 0x0 + .__section_addr 0x0 + .__section_offset 0x3a4 // maybe updated by assembler + .__section_size 0x20 // maybe updated by assembler + .__section_link 3 + .__section_info 0x0 + .__section_entsize 8 + .align 4 // equivalent to set sh_addralign + .align 4 + .sectionentsize 8 + .align 4 + /*0000*/ .word 0x00000000 + .align 4 + /*0004*/ .word 0xffffffff + .align 4 + /*0008*/ .word 0x00000000 + .align 4 + /*000c*/ .word 0xfffffffe + .align 4 + /*0010*/ .word 0x00000000 + .align 4 + /*0014*/ .word 0xfffffffd + .align 4 + /*0018*/ .word 0x00000000 + .align 4 + /*001c*/ .word 0xfffffffc + + +// 
--------------------- .nv.rel.action -------------------------- + .section .nv.rel.action,"",@"SHT_CUDA_RELOCINFO" + .__section_name 0xcc // offset in .shstrtab + .__section_type 1879048203 + .__section_flags 0x0 + .__section_addr 0x0 + .__section_offset 0x3c8 // maybe updated by assembler + .__section_size 0x10 // maybe updated by assembler + .__section_link 0 + .__section_info 0x0 + .__section_entsize 8 + .align 8 // equivalent to set sh_addralign + .align 8 + .sectionentsize 8 + /*0000*/ .byte 0x73, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x11, 0x25, 0x00, 0x05, 0x36 + + +// --------------------- .rel.debug_frame -------------------------- + .section ".rel.debug_frame", 64, SHT_REL + // all relocation sections will be dynamically generated by assembler + // but most of the section header will be kept as is. + .__section_name 0x8d // offset in .shstrtab + .__section_type SHT_REL + .__section_flags 0x40 + .__section_addr 0x0 + .__section_offset 0x3d8 // maybe updated by assembler + .__section_size 0x10 // maybe updated by assembler + .__section_link 3 + .__section_info 0x4 + .__section_entsize 16 + .align 8 // equivalent to set sh_addralign + // Relocation[0] : E_3, Container({'r_offset': 68, 'r_info': 25769803778, 'r_info_sym': 6, 'r_info_type': 2}) + +// --------------------- .nv.constant0.E_3 -------------------------- + .section .nv.constant0.E_3,"a",@progbits + .__section_name 0x58 // offset in .shstrtab + .__section_type SHT_PROGBITS + .__section_flags 0x42 + .__section_addr 0x0 + .__section_offset 0x3e8 // maybe updated by assembler + .__section_size 0x178 // maybe updated by assembler + .__section_link 0 + .__section_info 0xb + .__section_entsize 0 + .align 4 // equivalent to set sh_addralign + .sectionflags @"" + .align 4 + .nv.constant0.E_3: + .zero 376 + + +// --------------------- .text.E_3 -------------------------- + .section .text.E_3,"ax",@progbits + .__section_name 0x32 // offset in .shstrtab + .__section_type SHT_PROGBITS + 
.__section_flags 0x6 + .__section_addr 0x0 + .__section_offset 0x580 // maybe updated by assembler + .__section_size 0x180 // maybe updated by assembler + .__section_link 3 + .__section_info 0xc000006 + .__section_entsize 0 + .align 128 // equivalent to set sh_addralign + .sectioninfo @"SHI_REGISTERS=12" + .align 128 + .global E_3 + .type E_3,@function + .size E_3,(.L_x_1 - E_3) + .other E_3,@"STO_CUDA_ENTRY STV_DEFAULT" + E_3: + .text.E_3: + [B------:R-:W-:-:S02] /*0000*/ MOV R1, c[0x0][0x28] ; + [B------:R-:W0:-:S01] /*0010*/ S2R R6, SR_TID.X ; + [B------:R-:W-:-:S01] /*0020*/ MOV R7, 0x4 ; + [B------:R-:W-:Y:S04] /*0030*/ ULDC.64 UR4, c[0x0][0x118] ; + [B0-----:R-:W-:Y:S04] /*0040*/ IMAD.WIDE R2, R6, R7, c[0x0][0x168] ; + [B------:R-:W-:-:S02] /*0050*/ IMAD.WIDE R4, R6.reuse, R7.reuse, c[0x0][0x170] ; + [B------:R-:W2:-:S04] /*0060*/ LDG.E R2, desc[UR4][R2.64] ; + [B------:R-:W2:-:S01] /*0070*/ LDG.E R5, desc[UR4][R4.64] ; + [B------:R-:W-:-:S01] /*0080*/ IMAD.WIDE R6, R6, R7, c[0x0][0x160] ; + [B--2---:R-:W-:Y:S05] /*0090*/ FMUL R9, R2, R5 ; + [B------:R-:W-:-:S01] /*00a0*/ STG.E desc[UR4][R6.64], R9 ; + [B------:R-:W-:-:S05] /*00b0*/ EXIT ; + .L_x_0: + [B------:R-:W-:Y:S00] /*00c0*/ BRA `(.L_x_0); + [B------:R-:W-:Y:S00] /*00d0*/ NOP; + [B------:R-:W-:Y:S00] /*00e0*/ NOP; + [B------:R-:W-:Y:S00] /*00f0*/ NOP; + [B------:R-:W-:Y:S00] /*0100*/ NOP; + [B------:R-:W-:Y:S00] /*0110*/ NOP; + [B------:R-:W-:Y:S00] /*0120*/ NOP; + [B------:R-:W-:Y:S00] /*0130*/ NOP; + [B------:R-:W-:Y:S00] /*0140*/ NOP; + [B------:R-:W-:Y:S00] /*0150*/ NOP; + [B------:R-:W-:Y:S00] /*0160*/ NOP; + [B------:R-:W-:Y:S00] /*0170*/ NOP; + .L_x_1: + + //------------------------------------------------- + //---------------- END of sections ---------------- + //------------------------------------------------- + + +// Program segment PT_PHDR, 5 + .__segment "PT_PHDR", 5 + .__segment_offset 0xa00 // maybe updated by assembler + .__segment_vaddr 0x0 // Seems always 0? 
+ .__segment_paddr 0x0 // ??? + .__segment_filesz 0xa8 // file size, maybe updated by assembler + .__segment_memsz 0xa8 // file size + nobits sections, maybe updated by assembler + .__segment_align 8 // + +// Program segment PT_LOAD, 5 + .__segment "PT_LOAD", 5 + .__segment_offset 0x3e8 // maybe updated by assembler + .__segment_vaddr 0x0 // Seems always 0? + .__segment_paddr 0x0 // ??? + .__segment_filesz 0x318 // file size, maybe updated by assembler + .__segment_memsz 0x318 // file size + nobits sections, maybe updated by assembler + .__segment_align 8 // + .__segment_startsection ".nv.constant0.E_3" // first section in this segment + .__segment_endsection ".text.E_3" // last section in this segment + +// Program segment PT_LOAD, 5 + .__segment "PT_LOAD", 5 + .__segment_offset 0xa00 // maybe updated by assembler + .__segment_vaddr 0x0 // Seems always 0? + .__segment_paddr 0x0 // ??? + .__segment_filesz 0xa8 // file size, maybe updated by assembler + .__segment_memsz 0xa8 // file size + nobits sections, maybe updated by assembler + .__segment_align 8 // + .__segment_startsection "@PROGRAM_HEADER" // first section in this segment + .__segment_endsection "@PROGRAM_HEADER" // last section in this segment + + + //------------------------------------------------- + //---------------- END of segments ---------------- + //------------------------------------------------- + + + """ diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py index 326ae84f01769..c13f239fa55de 100644 --- a/tinygrad/runtime/ops_cuda.py +++ b/tinygrad/runtime/ops_cuda.py @@ -1,11 +1,12 @@ from __future__ import annotations -import ctypes, ctypes.util, functools +import ctypes, ctypes.util, functools, os from tinygrad.helpers import DEBUG, getenv, mv_address, init_c_var, init_c_struct_t, suppress_finalizing from tinygrad.device import Compiled, BufferSpec, LRUAllocator from tinygrad.renderer.cstyle import CUDARenderer from tinygrad.renderer.ptx import PTXRenderer +from 
tinygrad.renderer.sass import SASSRenderer from tinygrad.runtime.autogen import cuda -from tinygrad.runtime.support.compiler_cuda import pretty_ptx, CUDACompiler, PTXCompiler, PTX +from tinygrad.runtime.support.compiler_cuda import SASSCompiler, SASSCompiler2, pretty_ptx, CUDACompiler, PTXCompiler, PTX if getenv("IOCTL"): import extra.nv_gpu_driver.nv_ioctl # noqa: F401 # pylint: disable=unused-import if MOCKGPU:=getenv("MOCKGPU"): from test.mockgpu.cuda import cuda # type: ignore # pylint: disable=reimported @@ -115,8 +116,18 @@ def __init__(self, device:str): CUDADevice.devices.append(self) from tinygrad.runtime.graph.cuda import CUDAGraph - super().__init__(device, CUDAAllocator(self), PTXRenderer(self.arch) if PTX else CUDARenderer(self.arch), - PTXCompiler(self.arch) if PTX else CUDACompiler(self.arch), functools.partial(CUDAProgram, self), None if MOCKGPU else CUDAGraph) + if os.environ.get("SASS"): + renderer = SASSRenderer(self.arch) + compiler = SASSCompiler(self.arch) + elif os.environ.get("SASS2"): + renderer = CUDARenderer(self.arch) + compiler = SASSCompiler2(self.arch) + else: + renderer = PTXRenderer(self.arch) if PTX else CUDARenderer(self.arch) + compiler = PTXCompiler(self.arch) if PTX else CUDACompiler(self.arch) + + super().__init__(device, CUDAAllocator(self), renderer, + compiler, functools.partial(CUDAProgram, self), None if MOCKGPU else CUDAGraph) def synchronize(self): check(cuda.cuCtxSetCurrent(self.context)) diff --git a/tinygrad/runtime/support/assembler/CuAsmLogger.py b/tinygrad/runtime/support/assembler/CuAsmLogger.py new file mode 100644 index 0000000000000..d367f27dfd569 --- /dev/null +++ b/tinygrad/runtime/support/assembler/CuAsmLogger.py @@ -0,0 +1,276 @@ +# -*- coding: utf-8 -*- + +import logging +import logging.handlers +import sys +import time +import os +import tempfile +import random + +class CuAsmLogger(object): + ''' A logger private to current module. 
+ + A customized logging style is used to show the progress better, + without affecting the logging of other modules. + + ''' + __LoggerRepos = {} + __CurrLogger = None + __LogFileRepos = {} + __IndentLevel = 0 + __IndentString = '' + + # Predefined levels: + # CRITICAL 50 + # ERROR 40 + # WARNING 30 + # INFO 20 + # DEBUG 10 + # NOTSET 0 + + # Custom log levels + + ENTRY = 35 # main entry of a module + PROCEDURE = 25 # procedures of some module + SUBROUTINE = 15 # some internal subroutines + + @staticmethod + def getDefaultLoggerFile(name): + ''' Default log file in temp dir. + + NOTE: this is not safe, since several instances may run simultaneously. + ''' + fpath = tempfile.gettempdir() + return os.path.join(fpath, name + '.log') + + @staticmethod + def getTemporaryLoggerFile(name): + ''' Temporary logfile in temp dir.''' + fpath = tempfile.gettempdir() + while True: + ttag = time.strftime('.%m%d-%H%M%S.', time.localtime()) + tmpname = ''.join(random.choices('abcdefghijklmnopqrstuvwxyz0123456789', k = 8)) + fname = os.path.join(fpath, name + ttag + tmpname + '.log') + if not os.path.exists(fname): + break + + return fname + + @staticmethod + def initLogger(log_file='', *, name='cuasm', file_level=logging.DEBUG, file_max_bytes=1<<30, file_backup_count=3, stdout_level=25): + ''' Init a logger with given name and logfile. + + log_file: set to None for no file log; + set to '' for default temporary log file; (DEFAULT) + set to filename for user specified log file; + + CuAsmLogger uses RotatingFileHandler for logging, thus if given log_file exists or file size exceeds the max_bytes, + it will roll over and rename previous files to logfile.log.1, logfile.log.2, etc... + + NOTE: Temporary logfiles will not be deleted automatically, since we usually need to check the log after running a program. + + name : logger instance name, default to 'cuasm' + several loggers may exist simultaneously, use setActiveLogger(name) to switch between them. 
+ file_level : log level of file + file_max_bytes: max size of logfile(in bytes), default to 1GB. + file_backup_count: number of maximum rolling over files, default to 3. + stdout_level: log level for standard output. + ''' + # if name in CuAsmLogger.__LoggerRepos: + # CuAsmLogger.__CurrLogger = CuAsmLogger.__LoggerRepos[name] + # print('CuAsmLogger %s already exists! Skipping init...' % name) + # return + + logger = logging.getLogger(name) + hs = [h for h in logger.handlers] + for h in hs: + logger.removeHandler(h) + + logger.setLevel(logging.DEBUG) + + fmt = logging.Formatter('%(asctime)s - %(message)s') + if log_file is not None: + if len(log_file) == 0: + full_log_file = CuAsmLogger.getTemporaryLoggerFile(name) + else: + # fpath, fbase = os.path.split(log_file) + + # if fbase.lower().endswith('.log'): + # full_log_file = os.path.join(fpath, name + '.' + fbase) + # else: + # full_log_file = os.path.join(fpath, name + '.' + fbase + '.log') + if log_file.endswith('.log'): + full_log_file = log_file + else: + full_log_file = log_file + '.log' + + # fh = logging.FileHandler(full_log_file, mode='a') + print(f'InitLogger({name}) with logfile "{full_log_file}"...') + + # once RotatingFileHandler is created, the log file will be created at the same time + # thus we need to detect whether the logfile needs to be rolled over before handler creation + needsRollOver = os.path.exists(full_log_file) + fh = logging.handlers.RotatingFileHandler(full_log_file, mode='a', maxBytes=file_max_bytes, backupCount=file_backup_count) + + # default mode is 'a', but we may want a new log for every run, but still keeping old logs as backup. + if needsRollOver: + print(f'Logfile {full_log_file} already exists! 
Rolling over...') + fh.doRollover() + + fh.setFormatter(fmt) + fh.setLevel(file_level) + + logger.addHandler(fh) + CuAsmLogger.__LogFileRepos[name] = full_log_file + else: + CuAsmLogger.__LogFileRepos[name] = None + + if stdout_level is not None: + sh = logging.StreamHandler(sys.stdout) + sh.setFormatter(fmt) + sh.setLevel(stdout_level) + logger.addHandler(sh) + + # + CuAsmLogger.__LoggerRepos[name] = logger + CuAsmLogger.__CurrLogger = logger + + @staticmethod + def setActiveLogger(name): + if name in CuAsmLogger.__LoggerRepos: + CuAsmLogger.__CurrLogger = CuAsmLogger.__LoggerRepos[name] + else: + print('CuAsmLogger %s does not exist! Keeping current logger...' % name) + + @staticmethod + def getCurrentLogFile(): + return CuAsmLogger.__LogFileRepos[CuAsmLogger.__CurrLogger.name] + + @staticmethod + def logDebug(msg, *args, **kwargs): + CuAsmLogger.__CurrLogger.debug(' DEBUG - ' + msg, *args, **kwargs) + + @staticmethod + def logInfo(msg, *args, **kwargs): + CuAsmLogger.__CurrLogger.info(' INFO - ' + msg, *args, **kwargs) + + @staticmethod + def logWarning(msg, *args, **kwargs): + CuAsmLogger.__CurrLogger.warning(' WARNING - ' + msg, *args, **kwargs) + + @staticmethod + def logError(msg, *args, **kwargs): + CuAsmLogger.__CurrLogger.error(' ERROR - ' + msg, *args, **kwargs) + + @staticmethod + def logCritical(msg, *args, **kwargs): + CuAsmLogger.__CurrLogger.critical('CRITICAL - ' + msg, *args, **kwargs) + + @staticmethod + def logEntry(msg, *args, **kwargs): + full_msg = ' ENTRY - ' + CuAsmLogger.__IndentString + msg + CuAsmLogger.__CurrLogger.log(CuAsmLogger.ENTRY, full_msg, *args, **kwargs) + + + @staticmethod + def logProcedure(msg, *args, **kwargs): + + full_msg = ' PROC - ' + CuAsmLogger.__IndentString + msg + CuAsmLogger.__CurrLogger.log(CuAsmLogger.PROCEDURE, full_msg, *args, **kwargs) + + + @staticmethod + def logSubroutine(msg, *args, **kwargs): + full_msg = ' SUB - ' + CuAsmLogger.__IndentString + msg + CuAsmLogger.__CurrLogger.log(CuAsmLogger.SUBROUTINE, 
full_msg, *args, **kwargs) + + + @staticmethod + def logLiteral(msg, *args, **kwargs): + full_msg = ' - ' + CuAsmLogger.__IndentString + msg + CuAsmLogger.__CurrLogger.log(CuAsmLogger.PROCEDURE, full_msg, *args, **kwargs) + + + @staticmethod + def log(level, msg, *args, **kwargs): + CuAsmLogger.__CurrLogger.log(level, msg, *args, **kwargs) + + + @staticmethod + def logTimeIt(func): + ''' Logging of a (usually) long running function. + + ''' + def wrapper(*args, **kwargs): + CuAsmLogger.logLiteral('Running %s...'%func.__qualname__) + CuAsmLogger.incIndent() + + t0 = time.time() + ret = func(*args, **kwargs) + t1 = time.time() + + CuAsmLogger.decIndent() + CuAsmLogger.logLiteral('Func %s completed! Time=%8.4f secs.'%(func.__qualname__, t1-t0)) + + return ret + + return wrapper + + @staticmethod + def logIndentIt(func): + ''' + ''' + def wrapper(*args, **kwargs): + CuAsmLogger.incIndent() + ret = func(*args, **kwargs) + CuAsmLogger.decIndent() + + return ret + + return wrapper + + @staticmethod + def logTraceIt(func): + ''' + ''' + def wrapper(*args, **kwargs): + CuAsmLogger.logLiteral('Running %s...'%func.__qualname__) + CuAsmLogger.incIndent() + + ret = func(*args, **kwargs) + CuAsmLogger.decIndent() + + return ret + + return wrapper + + @staticmethod + def incIndent(): + CuAsmLogger.__IndentLevel += 1 + CuAsmLogger.__IndentString = ' ' * CuAsmLogger.__IndentLevel + + @staticmethod + def decIndent(): + CuAsmLogger.__IndentLevel -= 1 + if CuAsmLogger.__IndentLevel < 0: + CuAsmLogger.__IndentLevel = 0 + CuAsmLogger.__IndentString = ' ' * CuAsmLogger.__IndentLevel + + @staticmethod + def resetIndent(val=0): + if val<0: + val = 0 + CuAsmLogger.__IndentLevel = val + CuAsmLogger.__IndentString = ' ' * CuAsmLogger.__IndentLevel + + @staticmethod + def setLevel(level): + CuAsmLogger.__CurrLogger.setLevel(level) + + @staticmethod + def disable(): + CuAsmLogger.__CurrLogger.setLevel(logging.ERROR) + +# Init a default logger when the module is imported 
+CuAsmLogger.initLogger(log_file=None) diff --git a/tinygrad/runtime/support/assembler/CuAsmParser.py b/tinygrad/runtime/support/assembler/CuAsmParser.py new file mode 100644 index 0000000000000..580b650a51555 --- /dev/null +++ b/tinygrad/runtime/support/assembler/CuAsmParser.py @@ -0,0 +1,2063 @@ +# -*- coding: utf-8 -*- + +import re +import os +from io import BytesIO +from collections import OrderedDict, defaultdict + +from elftools.elf.elffile import ELFFile + +from tinygrad.runtime.support.assembler.CuKernelAssembler import CuKernelAssembler +from tinygrad.runtime.support.assembler.CuInsAssemblerRepos import CuInsAssemblerRepos +from tinygrad.runtime.support.assembler.CuSMVersion import CuSMVersion +from tinygrad.runtime.support.assembler.CuNVInfo import CuNVInfo +from tinygrad.runtime.support.assembler.CuAsmLogger import CuAsmLogger +from tinygrad.runtime.support.assembler.CubinFile import PROGRAM_HEADER_TAG + +from tinygrad.runtime.support.assembler.config import Config +from tinygrad.runtime.support.assembler.common import splitAsmSection, alignTo, bytes2Asm +from tinygrad.runtime.support.assembler.CuControlCode import c_ControlCodesPattern + +def printb(b: bytes, s: int=-1, e: int=-1): + for i, _b in enumerate(b): + ret = f"{_b:02x}" + if i == s: ret = "[" + ret + elif i == e: ret = ret + "]" + else: ret = " " + ret + " " + print(f"{ret}", end=" ") + if i % 8 == 0: print() + print() + +m_hex = re.compile(r'\b0x[a-fA-F0-9]+\b') +m_int = re.compile(r'\b[0-9]+\b') +m_intval = re.compile(r'\b(0x[a-fA-F0-9]+)|([0-9]+)\b') + +def updateDictWithInput(din, dout, label='', kprefix=''): + ''' Update a dict with input from another dict. + + The key will be prefixed with kprefix. + the value will be converted to int if possible (for hex or dec int). + + label is only used for error tracing. 
+ ''' + for k,v in din.items(): + kp = kprefix + k + if kp not in dout: + # CuAsmLogger.logWarning('Unknown header attribute (%s) for %s!!!'%(k,label)) + pass + + if isinstance(v, str): + if m_hex.match(v): + vv = int(v, 16) + elif m_int.match(v): + vv = int(v) + else: + vv = v + else: + vv = v + + dout[kp] = vv + +def buildStringDict(bytelist): + ''' build strings dict from b'\x00' joined byte list. + + The dict key/value is just the offset/value of the string. + ''' + p = 0 + counter = 0 + + sdict = OrderedDict() + while True: + counter += 1 + pnext = bytelist.find(b'\x00', p) + if pnext<0: + break + + s = bytelist[p:pnext] # not include the ending b'\x00' + sdict[p] = s.decode() + p = pnext+1 + + return sdict + +class CuAsmSymbol(object): + ''' + typedef struct + { + Elf64_Word st_name; /* Symbol name */ + unsigned char st_info; /* Type and Binding attributes */ + unsigned char st_other; /* Reserved */ + Elf64_Half st_shndx; /* Section table index */ + Elf64_Addr st_value; /* Symbol value */ + Elf64_Xword st_size; /* Size of object (e.g., common) */ + } Elf64_Sym; + + // + typedef uint64_t Elf64_Addr; + typedef uint16_t Elf64_Half; + typedef uint64_t Elf64_Off; + typedef int32_t Elf64_Sword; + typedef int64_t Elf64_Sxword; + typedef uint32_t Elf64_Word; + typedef uint64_t Elf64_Lword; + typedef uint64_t Elf64_Xword; + + + All internal symbols should also be defined as labels. + The label offset is just the symbol value, and the section where the label + is defined will affect the behavior of jump/branch instructions. + + FIXME: Currently some attributes in st_other (such as "STO_CUDA_ENTRY") cannot be + recognized by pyelftools, thus may be lost if parsed and built again. 
+ ''' + + # TODO: Not implemented yet, just copied from cubin + SymbolTypes = {'@function' :0, + '@object' :1, + '@"STT_CUDA_TEXTURE"':2, + '@"STT_CUDA_SURFACE"':3} + + def __init__(self, name): + self.name = name + self.type = None + self.value = None + self.size = None + self.sizeval = None + self.other = None + self.index = None # + + self.entry = Config.defaultSymbol.copy() + + def __str__(self): + s = 'name=%s, type=%s, value=%s, size(%s)=%s'%( + self.name, self.type, self.value, self.sizeval, self.size) + return s + + def build(self): + ''' Build symbol entry. + + TODO: not implemented, symtab entries are copied from cubin + but value/size may be updated + ''' + return Config.CubinELFStructs.Elf_Sym.build(self.entry) + + @staticmethod + def buildSymbolDict(strtab, symbytes): + symdict = OrderedDict() + symsize = Config.CubinELFStructs.Elf_Sym.sizeof() + index = 0 + for p in range(0, len(symbytes), symsize): + sym = Config.CubinELFStructs.Elf_Sym.parse(symbytes[p:p+symsize]) + nameidx = sym['st_name'] + if nameidx not in strtab: + raise Exception('Unknown symbol @%#x with name string index 0x%x!'%(p, nameidx)) + + name = strtab[nameidx] + if name in symdict: + raise Exception('Duplicate symbol @%#x with name %s!', p, name) + symdict[name] = index, sym + index += 1 + + return symdict + + @staticmethod + def resetSymtabEntryValueSize(bio, base_offset, value, size): + ''' reset Symbol entry value/size in symtab byte stream. + + bio: BytesIO stream + base_offset: base offset of current entry + value/size: symbol value/size to be set + ''' + + p = bio.tell() # save current pos + bio.seek(base_offset + 8) # +8 is offset for the value + bio.write(int.to_bytes(value, 8, 'little')) + bio.write(int.to_bytes(size, 8, 'little')) + bio.seek(p) # restore pos + +class CuAsmLabel(object): + ''' A label is defined by "label:" + + Every symbol (non-external) is also a label, the symbol value is just the label offset. 
+ ''' + def __init__(self, name, section, offset, lineno): + self.name = name + self.section = section + self.offset = offset + self.lineno = lineno + CuAsmLogger.logSubroutine('Line %6d: New Label "%s" at section "%s":%#x'%(lineno, name, section.name, offset)) + + def __str__(self): + s = 'Label @Line %4d in section %s : %-#7x(%6d) %s'%(self.lineno, self.section.name, self.offset, self.offset, + self.name) + return s + +class CuAsmFixup(object): + ''' Fixups are a set of undetermined values during the first scan. + + Some fixups can be evaluated after first scan. Then the true values will be filled. + There are also some fixups cannot be determined during compile time, thus they will + go to relocations and the true values will be filled by the program loader. + ''' + + def __init__(self, section, offset, expr, dtype, lineno): + self.section = section + self.offset = offset + self.lineno = lineno + self.dtype = dtype + self.expr = expr + self.value = None + + CuAsmLogger.logSubroutine('Line %6d: New Fixup "%s" at section "%s":%#x'%(lineno, expr, section.name, offset)) + + def __str__(self): + s = 'section=%s, offset=%d, lineno=%d, dtype=%s, expr=%s, value=%s'%( + self.section.name, self.offset, self.lineno, self.dtype, self.expr, self.value) + return s + +class CuAsmSection(object): + ''' + Section header struct (Only ELF64 supported): + + typedef struct + { + Elf64_Word sh_name; /* Section name */ + Elf64_Word sh_type; /* Section type */ + Elf64_Xword sh_flags; /* Section attributes */ + Elf64_Addr sh_addr; /* Virtual address in memory */ + Elf64_Off sh_offset; /* Offset in file */ + Elf64_Xword sh_size; /* Size of section */ + Elf64_Word sh_link; /* Link to other section */ + Elf64_Word sh_info; /* Miscellaneous information */ + Elf64_Xword sh_addralign; /* Address alignment boundary */ + Elf64_Xword sh_entsize; /* Size of entries, if section has table */ + } Elf64_Shdr; + ''' + + def __init__(self, sname, stype, sflags): + '''Construct an ELF section. 
+ + Currently there are 3 systems for section headers. + 1. self.name/type/flags/... work for predefined directives, such as .section/.sectioninfo + 2. self.header['name']... work for supplementary directives, namely .section_* + 3. self.__mSectionHeader is the struct form for building header bytes + + Only 1 and 2 can be set in assembly, 1 has higher priority. + Information from 1 and 2 will be combined to form the final header. + + Surely there are redundencies here, but it's the safest way to keep some attributes + set by ptxas, yet still give user enough flexibility to modify them. + + ''' + self.name = sname + self.type = stype # “A” stands for SHF_ALLOC + # “W” for SHF_WRITE + # “X” for SHF_EXECINSTR + self.flags = [sflags] # some extra flags may be appended later + + self.info = [] + self.offset = None + self.size = None + self.addralign = None + self.entsize = 0 + + self.header = {} + self.extra = {} # barnum/regnum, only for update nvinfo + + # + self.padsize = 0 + self.padbytes = b'' + + self.__isTextSection = sname.startswith('.text') + self.__mSectionHeader = Config.defaultSectionHeader.copy() + self.__mData = BytesIO() + + def updateHeader(self): + '''Update section header with user inputs. + + TODO: currently only offset/size will be updated. + ''' + + updateDictWithInput(self.header, self.__mSectionHeader, + label='section %s'%self.name, kprefix = 'sh_') + + # maybe we can just update self.header? + if self.header['type'] == 'SHT_NULL': + self.__mSectionHeader['sh_offset'] = 0 + else: + self.__mSectionHeader['sh_offset'] = self.offset + self.__mSectionHeader['sh_size'] = self.getDataSize() #self.size + + def getHeaderStruct(self): + return self.__mSectionHeader + + def updateResourceInfo(self): + '''Update register/barrier number. 
+ + Examples: + .sectionflags @"SHF_BARRIERS=1" + .sectioninfo @"SHI_REGISTERS=12" + ''' + + # + p_regnum = re.compile(r'@"SHI_REGISTERS=(\d+)"') + p_barnum = re.compile(r'@"SHF_BARRIERS=(\d+)"') + + regnum = None + barnum = 0 # There may be no barrier used in a kernel + + for info in self.info: + res = p_regnum.match(info) + if res is not None: + regnum = int(res.groups()[0]) + + for flag in self.flags: + res = p_barnum.match(flag) + if res is not None: + barnum = int(res.groups()[0]) + + if regnum is None: + raise Exception("Unknown register number for section %s!"%self.name) + elif regnum > 255 or regnum<0: # TODO: use MAX_REG_COUNT instead? + raise Exception("Invalid register number %d for section %s!"%(regnum, self.name)) + else: + rinfo = self.header['info'] + self.header['info'] = (rinfo & 0x00ffffff) + (regnum<<24) + self.extra['regnum'] = regnum + + if barnum>15: # always rewrite bar number~ + raise Exception("Invalid barrier number %d for section %s!"%(barnum, self.name)) + else: + rflag = self.header['flags'] + self.header['flags'] = (rflag&0xff0fffff) + (barnum<<20) + self.extra['barnum'] = barnum + + def buildHeader(self): + ''' Build section header bytes with current header struct. ''' + + self.updateHeader() + # print(self.__mSectionHeader) + return Config.CubinELFStructs.Elf_Shdr.build(self.__mSectionHeader) + + def emitBytes(self, bs): + self.__mData.write(bs) + + def updateForFixup(self, offset, bs): + ''' Update corresponding bytes for fixup. + + Input: + offset the absolute w.r.t the beginning of the section + bs bytes to be updated + ''' + blen = len(bs) + + if (offset+blen) > self.getDataSize(): + raise Exception('Fixup out of boundary!') + + # save original pos + opos = self.tell() + self.__mData.seek(offset) + + # value is guaranteed within range during fixup evaluation. + self.__mData.write(bs) + self.__mData.seek(opos) + + def emitAlign(self, align): + ''' Set alignment of next bytes. 
+ + Note: When current position is section start, the alignment is the addralign of current section. + Then the padding is done to previous section. + ''' + + pos = self.tell() + if pos == 0: + self.addralign = align + self.header['addralign'] = align + else: + ppos, padsize = alignTo(pos, align) + if ppos > pos: # do padding with required 0-bytes/nops + self.emitBytes(b'\x00' * (ppos-pos)) + + def emitPadding(self, bs): + ''' This is only for .text sections. + + Emitting padding here will change the size of current text section. + For non-text sections, the padding should be done without changing the size. + ''' + pos = self.tell() + self.seek(0, 2) # seek to end + self.emitBytes(bs) + self.seek(pos) # restore original position + + def seek(self, pos, whence=0): + return self.__mData.seek(pos, whence) + + def tell(self): + return self.__mData.tell() + + def getData(self): + return self.__mData.getvalue() + + def writePaddedData(self, stream): + if self.header['type'] == 'SHT_NOBITS': # nobits sections will not write to file. + return + else: + stream.write(self.__mData.getvalue()) + stream.write(self.padbytes) + + def setData(self, bs): + ''' Update section data with given bytes. ''' + + self.__mData = BytesIO(bs) + self.size = len(bs) + + def getDataSize(self): + ''' Get memory size of current section. + + For section of type nobits, no actual file contents. 
+ ''' + return len(self.__mData.getvalue()) + + def getPaddedDataSize(self): + return self.getDataSize() + self.padsize + + def getRegNum(self): + return self.extra['regnum'] + + def __str__(self): + s = 'Section:\n' + s += ' name : %s\n' % self.name + s += ' type : %s\n' % self.type + s += ' flags : %s\n' % str(self.flags) + s += ' info : %s\n' % self.info + s += ' offset : %s\n' % self.offset + s += ' addralign : %s\n' % self.addralign + + return s + +class CuAsmSegment(object): + def __init__(self, p_type, p_flags): + self.header = {'type':p_type, 'flags':p_flags} + self.__mSegmentHeader = Config.defaultSegmentHeader.copy() + + def updateHeader(self): + ''' Update header with inputs''' + + updateDictWithInput(self.header, self.__mSegmentHeader, + label='segment', kprefix = 'p_') + + def getHeaderStruct(self): + return self.__mSegmentHeader + + def build(self): + return Config.CubinELFStructs.Elf_Phdr.build(self.__mSegmentHeader) + +class CuAsmRelocation(object): + ''' Relocation class. + + Relocation is a special section that may modify some contents of its linked section. + This procedure is generally done during loading, the modified contents are typically + the real memory address of some symbols. + + typedef struct + { + Elf64_Addr r_offset; /* Address of reference */ + Elf64_Xword r_info; /* Symbol index and type of relocation */ + } Elf64_Rel; + + typedef struct + { + Elf64_Addr r_offset; /* Address of reference */ + Elf64_Xword r_info; /* Symbol index and type of relocation */ + Elf64_Sxword r_addend; /* Constant part of expression */ + } Elf64_Rela; + + + Relocations are typically for some dynamic variables (symbols). + Sources of relocations: + 1. .dword/.word defined values in normal sections + 2. 
32lo@* or 32hi@* kind of operands in text sections + + such as : + /*0040*/ MOV R2, 32@lo(flist) ; + /*0060*/ MOV R3, 32@hi(flist) ; + + RELA is a relocation section with extra offsets, such as: + /*00f0*/ MOV R20, 32@lo((_Z4testPiS_S_ + .L_6@srel)) ; + /*0100*/ MOV R21, 32@hi((_Z4testPiS_S_ + .L_6@srel)) ; + + 3. `(symbol) in text sections (for symbols not defined in current section) + + ''' + + REL_TYPES = { + 'R_CUDA_32' : 1, + 'R_CUDA_64' : 2, + 'R_CUDA_G64' : 4, + 'R_CUDA_TEX_HEADER_INDEX' : 6, + 'R_CUDA_SURF_HEADER_INDEX': 52, + 'R_CUDA_ABS32_20' : 42, + 'R_CUDA_ABS32_LO_20' : 43, + 'R_CUDA_ABS32_HI_20' : 44, + 'R_CUDA_ABS32_LO_32' : 56, + 'R_CUDA_ABS32_HI_32' : 57, + 'R_CUDA_ABS47_34' : 58} + + def __init__(self, section, offset, relsymname, relsymid, reltype, reladd=None): + self.section = section + self.offset = offset + self.relsymname = relsymname + self.relsymid = relsymid + self.reltype = reltype + self.reladd = reladd # reladd=None means rel, otherwise rela + + CuAsmLogger.logSubroutine('New Relocation "%s" at section "%s":%#x'%(relsymname, section.name, offset)) + + def isRELA(self): + return self.reladd is not None + + def buildEntry(self): + ''' Build binary entry of current relocation. 
+ + Examples: + _Z4testPiS_S_, Container({'r_offset': 528, 'r_info': 124554051586, 'r_info_sym': 29, 'r_info_type': 2}) + _Z4testPiS_S_, Container({'r_offset': 2288, 'r_info': 124554051641, 'r_info_sym': 29, 'r_info_type': 57, 'r_addend': 2352}) + ''' + if self.isRELA(): # RELA + rela = Config.defaultRela.copy() + rela['r_offset'] = self.offset + rela['r_info_sym'] = self.relsymid + rela['r_info_type'] = self.REL_TYPES[self.reltype] + rela['r_info'] = (rela['r_info_sym']<<32) + rela['r_info_type'] + rela['r_addend'] = self.reladd + # print(rela) + return Config.CubinELFStructs.Elf_Rela.build(rela) + + else: # REL + rel = Config.defaultRel.copy() + rel['r_offset'] = self.offset + rel['r_info_sym'] = self.relsymid + rel['r_info_type'] = self.REL_TYPES[self.reltype] + rel['r_info'] = (rel['r_info_sym']<<32) + rel['r_info_type'] + # print(rel) + return Config.CubinELFStructs.Elf_Rel.build(rel) + + def __str__(self): + s = '@section %s: offset=%s, relsym=%d(%s), reltype=%s, reladd=%s'%( + self.section.name, + self.offset, + self.relsymid, + self.relsymname, + self.reltype, + self.reladd) + return s + +class CuAsmFile(object): + + def __init__(self): + self.mSMVersion = None # sm version + + self.headerflags = None + self.elftype = None + + self.fileHeader = {} # unprocessed elf file header + self.__mFileHeader = Config.defaultCubinFileHeader.copy() + + self.__mSectionList = OrderedDict() + self.__mSegmentList = [] + + self.__mLastSection = None + self.__mCurrSection = None + + self.__mBuf = BytesIO() # global buffer for whole elf file, but without current section + + def buildFileHeader(self): + + self.__mFileHeader['e_ident']['EI_OSABI'] = self.fileHeader['ident_osabi'] + self.__mFileHeader['e_ident']['EI_ABIVERSION'] = self.fileHeader['ident_abiversion'] + self.__mFileHeader['e_type'] = self.fileHeader['type'] + self.__mFileHeader['e_machine'] = self.fileHeader['machine'] + self.__mFileHeader['e_version'] = self.fileHeader['version'] + self.__mFileHeader['e_entry'] = 
self.fileHeader['entry'] + self.__mFileHeader['e_phoff'] = self.fileHeader['phoff'] + self.__mFileHeader['e_shoff'] = self.fileHeader['shoff'] + self.__mFileHeader['e_flags'] = self.fileHeader['flags'] + self.__mFileHeader['e_ehsize'] = self.fileHeader['ehsize'] + self.__mFileHeader['e_phentsize'] = self.fileHeader['phentsize'] + self.__mFileHeader['e_phnum'] = self.fileHeader['phnum'] + self.__mFileHeader['e_shentsize'] = self.fileHeader['shentsize'] + self.__mFileHeader['e_shnum'] = self.fileHeader['shnum'] + self.__mFileHeader['e_shstrndx'] = self.fileHeader['shstrndx'] + + return Config.CubinELFStructs.Elf_Ehdr.build(self.__mFileHeader) + + def getFileHeaderStruct(self): + return self.__mFileHeader + + def emitAlign(self, align): + ''' padding last section to required alignments. + + Return the padded length. + ''' + + pos = self.tell() + ppos = align * ((pos+align-1) // align) + if ppos > pos: # do padding with required 0-bytes/nops + if self.__mLastSection is not None: + padbytes = self.__mLastSection.genSectionPaddingBytes(ppos - pos) + else: + padbytes = b'\x00' * (ppos - pos) + self.__mBuf.write(padbytes) + + return ppos-pos + + def seek(self, offset): + self.__mBuf.seek(offset) + + def tell(self): + return self.__mBuf.tell() + + def saveAsCubin(self, cubinname): + with open(cubinname, 'wb') as fout: + fout.write(self.__mBuf.getvalue()) + +class CuAsmParser(object): + ''' Parser for cuasm file.''' + +#### static variables, mostly re patterns + m_cppcomment = re.compile(r'//.*$') # cpp style line comments + m_ccomment = re.compile(r'\/\*.*?\*\/') # c style line + m_bracomment = re.compile(r'\(\*.*\*\)') # notes for bra targets in sm_5x/6x + # such as (*"INDIRECT_CALL"*) + + m_directive = re.compile(r'(\.[a-zA-Z0-9_]+)\s*(.*)') + m_label = re.compile(r'([a-zA-Z0-9._$@#]+?)\s*:\s*(.*)') # "#" for offset label auto rename + m_symbol = re.compile(r'[a-zA-Z0-9._$@]+') #??? 
+ + m_byte = re.compile(r'\b0x[a-fA-F0-9]{2}\b') + m_short = re.compile(r'\b0x[a-fA-F0-9]{4}\b') + m_word = re.compile(r'\b0x[a-fA-F0-9]{8}\b') + m_dword = re.compile(r'\b0x[a-fA-F0-9]{16}\b') # arch dependent? + m_zero = re.compile(r'\b[1-9][0-9]*\b') + + m_sufrel = re.compile(r'\[20@lo\(0x0\)=fun@R_CUDA_SURF_HEADER_INDEX\((\w+)\)\]') + m_texrel = re.compile(r'\[20@lo\(0x0\)=(\w+)\]') + + # dtype that may take relocation arguments. + rel_dtypes = {'dword':0, 'word' :1} + + dtype_pattern = {'byte' : (m_byte , 1), + 'short' : (m_short, 2), + 'word' : (m_word , 4), + 'dword' : (m_dword, 8)} + +#### constructors, and parsing entries + def __init__(self): + + self.__mCuInsAsmRepos = None + + # directive dict + self.__dirDict = { + # predefined directives in nvdisasm + '.headerflags' : self.__dir_headerflags, # set ELF header + '.elftype' : self.__dir_elftype, # set ELF type + '.section' : self.__dir_section, # declare a section + '.sectioninfo' : self.__dir_sectioninfo, # set section info + '.sectionflags' : self.__dir_sectionflags, # set section flags + '.sectionentsize' : self.__dir_sectionentsize, # set section entsize + '.align' : self.__dir_align, # set alignment + '.byte' : self.__dir_byte, # emit bytes + '.short' : self.__dir_short, # emit shorts + '.word' : self.__dir_word, # emit word (4B?) + '.dword' : self.__dir_dword, # emit dword (8B?) 
+ '.type' : self.__dir_type, # set symbol type + '.size' : self.__dir_size, # set symbol size + '.global' : self.__dir_global, # declare a global symbol + '.weak' : self.__dir_weak, # declare a weak symbol + '.zero' : self.__dir_zero, # emit zero bytes + '.other' : self.__dir_other, # set symbol other + # supplementary directives defined by cuasm + # all for setting some ELF/Section/Segment header attributes + # some may with same funtionality as predefined directives + # predefined directives of nvdisasm have higher priority + '.__elf_ident_osabi' : (lambda args: self.__dir_elfheader('ident_osabi' , args)), + '.__elf_ident_abiversion' : (lambda args: self.__dir_elfheader('ident_abiversion', args)), + '.__elf_type' : (lambda args: self.__dir_elfheader('type' , args)), + '.__elf_machine' : (lambda args: self.__dir_elfheader('machine' , args)), + '.__elf_version' : (lambda args: self.__dir_elfheader('version' , args)), + '.__elf_entry' : (lambda args: self.__dir_elfheader('entry' , args)), + '.__elf_phoff' : (lambda args: self.__dir_elfheader('phoff' , args)), + '.__elf_shoff' : (lambda args: self.__dir_elfheader('shoff' , args)), + '.__elf_flags' : (lambda args: self.__dir_elfheader('flags' , args)), + '.__elf_ehsize' : (lambda args: self.__dir_elfheader('ehsize' , args)), + '.__elf_phentsize' : (lambda args: self.__dir_elfheader('phentsize' , args)), + '.__elf_phnum' : (lambda args: self.__dir_elfheader('phnum' , args)), + '.__elf_shentsize' : (lambda args: self.__dir_elfheader('shentsize' , args)), + '.__elf_shnum' : (lambda args: self.__dir_elfheader('shnum' , args)), + '.__elf_shstrndx' : (lambda args: self.__dir_elfheader('shstrndx' , args)), + # + '.__section_name' : (lambda args: self.__dir_sectionheader('name' , args)), + '.__section_type' : (lambda args: self.__dir_sectionheader('type' , args)), + '.__section_flags' : (lambda args: self.__dir_sectionheader('flags' , args)), + '.__section_addr' : (lambda args: self.__dir_sectionheader('addr' , args)), + 
'.__section_offset' : (lambda args: self.__dir_sectionheader('offset' , args)), + '.__section_size' : (lambda args: self.__dir_sectionheader('size' , args)), + '.__section_link' : (lambda args: self.__dir_sectionheader('link' , args)), + '.__section_info' : (lambda args: self.__dir_sectionheader('info' , args)), + '.__section_entsize' : (lambda args: self.__dir_sectionheader('entsize' , args)), + # + '.__segment' : self.__dir_segment, + '.__segment_offset' : (lambda args: self.__dir_segmentheader('offset' , args)), + '.__segment_vaddr' : (lambda args: self.__dir_segmentheader('vaddr' , args)), + '.__segment_paddr' : (lambda args: self.__dir_segmentheader('paddr' , args)), + '.__segment_filesz' : (lambda args: self.__dir_segmentheader('filesz' , args)), + '.__segment_memsz' : (lambda args: self.__dir_segmentheader('memsz' , args)), + '.__segment_align' : (lambda args: self.__dir_segmentheader('align' , args)), + '.__segment_startsection' : (lambda args: self.__dir_segmentheader('startsection' , args)), + '.__segment_endsection' : (lambda args: self.__dir_segmentheader('endsection' , args))} + + def reset(self): + self.__mLineNo = 0 + self.__mInTextSection = False + + self.__mCurrSection = None + self.__mCurrSegment = None + self.__mCuAsmFile = CuAsmFile() + + self.__mSectionDict = OrderedDict() + self.__mSymbolDict = OrderedDict() + self.__mSegmentList = [] + self.__mFixupList = [] # Fixup values that should be modified + + self.__mLabelDict = OrderedDict() # labels + self.__mSecSizeLabel = OrderedDict() # labels that defined at last of one section + self.__mRelList = [] # relocations + + self.__mNVInfoOffsetLabels = {} # key:sectionname, value: tuple(NVInfo_Attr, prefix) + self.__mInsIndex = 0 # Current instruction index + self.m_Arch = None + + self.__mPadSizeBeforeSecHeader = 0 # number of padding bytes before section header + + # TODO: not implemented yet + # current all the entries are copied from cubin + # self.__mStrList = [] # string may have identical 
entries + # self.__mShstrDict = OrderedDict() # entries + + @CuAsmLogger.logTimeIt + def parse(self, fname): + ''' Parsing input file + + General parsing work flow: + - scan whole file, gathering file headers, section headers/contents, segment headers + build fixup lists, split kernel text sections for kernel assembler. + - build internal tables, such as .shstrtab, .strtab. .symtab (Currently just copied except symbol size) + - build kernel text sections, update .nv.info sections if necessary. + update relocations if there are any. + - evaluate fixups, patching the bytes of corresponding section data. + - build relocation sections + - layout sections, update file header, section header, segment header accordingly + - write to file/stream + ''' + self.reset() + + CuAsmLogger.logEntry('Parsing file %s'%fname) + + self.__mFilename = fname + if not os.path.isfile(fname): + raise self.__assert(False, "Cannot find input cuasm file %s!!!"%fname) + else: + with open(fname, 'r') as fin: + self.__mLines = fin.readlines() + + self.__preScan() + self.__gatherTextSectionSizeLabel() + + self.__buildInternalTables() + self.__evalFixups() # + self.__parseKernels() + # + self.__buildRelocationSections() + + # Section layouting should be called when all sizes of sections are determined. 
+ # But section contents can be modified (but not resized) + # + # The layout will also determine the size label of text sections + # which may affect the symbol size in symtab + self.__layoutSections() + + self.__updateSymtab() + + @CuAsmLogger.logTimeIt + def saveAsCubin(self, fstream): + if isinstance(fstream, str): + fout = open(fstream, 'wb') + needClose = True + CuAsmLogger.logEntry('Saving as cubin file %s...'%fstream) + else: + fout = fstream + needClose = False + CuAsmLogger.logEntry('Saving as cubin file to stream...') + + disppos = lambda s: CuAsmLogger.logSubroutine("%#08x(%08d) : %s"%(fout.tell(), fout.tell(), s)) + # write ELF file header + disppos('FileHeader') + fout.write(self.__mCuAsmFile.buildFileHeader()) + + # write section data + for sname,sec in self.__mSectionDict.items(): + print(f"{sname=}") + b = sec.getData() + printb(b) + disppos('SectionData %s'%sname) + sec.writePaddedData(fout) + + # write padding bytes before section header + if self.__mPadSizeBeforeSecHeader > 0: + disppos('Padding %d bytes before section header' % self.__mPadSizeBeforeSecHeader) + fout.write(b'\x00' * self.__mPadSizeBeforeSecHeader) + + # write section headers + for sname,sec in self.__mSectionDict.items(): + disppos('SectionHeader %s'%sname) + fout.write(sec.buildHeader()) + + # write segment headers + for seg in self.__mSegmentList: + disppos('Segment') + fout.write(seg.build()) + + if needClose: + fout.close() + + def setInsAsmRepos(self, fname, arch): + self.__mCuInsAsmRepos = CuInsAssemblerRepos(fname, arch=arch) + +#### Procedures, every function is a seperate parsing step. + @CuAsmLogger.logTraceIt + def __preScan(self): + ''' first scan to gather sections/symbol + + build all entries for labels. 
+ ''' + + for line in self.__mLines: + nline = CuAsmParser.stripComments(line).strip() + self.__mLineNo += 1 + + if len(nline)==0: # skip blank/all-comments lines + continue + + ltype = self.__getLineType(nline) + if ltype is None: + self.__assert(False, "Unreconized line contents:\n %s"%line) + + elif ltype == 'label': + res = self.m_label.match(nline) + rlabel = res.groups()[0] + pos = self.__tellLocal() + + label = self.__checkNVInfoOffsetLabels(self.__mCurrSection, rlabel, pos) + + if label not in self.__mLabelDict: + self.__mLabelDict[label] = CuAsmLabel(label, self.__mCurrSection, + pos, self.__mLineNo) + else: + v = self.__mLabelDict[label] + self.__assert(False, 'Redefinition of label %s! First occurrence in Line%d!'% + (v.name, v.lineno)) + + elif ltype == 'directive': + + res = self.m_directive.match(nline) + cmd = res.groups()[0] + + # print('Run directive %s @line %d.'%(cmd, self.__mLineNo)) + + self.__assert(cmd in self.__dirDict, 'Unknown directive %s!!!' %cmd) + + farg = res.groups()[1].strip() + if len(farg) == 0: + args = [] + else: + args = re.split(r'\s*,\s*', farg) + + # run the directive + self.__dirDict[cmd](args) + elif ltype == 'code': + # During prescan, write all zeros for placeholder + pos = self.m_Arch.getInsOffsetFromIndex(self.__mInsIndex) + self.__mCurrSection.seek(pos) + + # all contents of .text section will be re-written + self.__emitBytes(b'\x00'*self.m_Arch.getInstructionLength()) + self.__mInsIndex += 1 + + elif ltype == 'blank': + continue + + @CuAsmLogger.logTraceIt + def __gatherTextSectionSizeLabel(self): + self.__mSecSizeLabel = OrderedDict() + for label, labelobj in self.__mLabelDict.items(): + secname = labelobj.section.name + if not secname.startswith('.text'): + continue + + if labelobj.offset == self.__mSectionDict[secname].getDataSize(): + # print(f'Size label {label} for {secname}!') + self.__mSecSizeLabel[secname] = labelobj + + @CuAsmLogger.logTraceIt + def __parseKernels(self): + # scan text sections to assemble 
kernels + section_markers = splitAsmSection(self.__mLines) + regnumdict = {} + for secname in section_markers: + if secname.startswith('.text.'): + section = self.__mSectionDict[secname] + m0, m1 = section_markers[secname] + self.__mCurrSection = section + self.__parseKernelText(section, m0, m1) + section.updateResourceInfo() + kname = secname[6:] # strip ".text." + symidx = self.__getSymbolIdx(kname) + regnumdict[symidx] = section.extra['regnum'] + + sec = self.__mSectionDict['.nv.info'] + + # print(sec.getData().hex()) + nvinfo = CuNVInfo(sec.getData(), self.m_Arch) + self.m_Arch.setRegCountInNVInfo(nvinfo, regnumdict) + sec.setData(nvinfo.serialize()) + + @CuAsmLogger.logTraceIt + def __buildInternalTables(self): + ''' Build .shstrtab/.strtab/.symtab entries. + + ''' + self.__mShstrtabDict = buildStringDict(self.__mSectionDict['.shstrtab'].getData()) + self.__mStrtabDict = buildStringDict(self.__mSectionDict['.strtab'].getData()) + self.__mSymtabDict = CuAsmSymbol.buildSymbolDict(self.__mStrtabDict, + self.__mSectionDict['.symtab'].getData()) + + # @CuAsmLogger.logTraceIt + def __parseKernelText(self, section, line_start, line_end): + CuAsmLogger.logProcedure('Parsing kernel text of "%s"...'%section.name) + + kasm = CuKernelAssembler(ins_asm_repos=self.__mCuInsAsmRepos, version=self.m_Arch) + + p_textline = re.compile(r'\[([\w:-]+)\](.*)') + + ins_idx = 0 + for lineidx in range(line_start, line_end): + line = self.__mLines[lineidx] + + nline = CuAsmParser.stripComments(line).strip() + self.__mLineNo = lineidx + 1 + + if len(nline)==0 or (self.m_label.match(nline) is not None) or (self.m_directive.match(nline) is not None): + continue + + res = p_textline.match(nline) + if res is None: + self.__assert(False, 'Unrecognized code text!') + + ccode_s = res.groups()[0] + icode_s = res.groups()[1] + + if c_ControlCodesPattern.match(ccode_s) is None: + self.__assert(False, f'Illegal control code text "{ccode_s}"!') + + addr = self.m_Arch.getInsOffsetFromIndex(ins_idx) + 
c_icode_s = self.__evalInstructionFixup(section, addr, icode_s) + + print("Parsing %s : %s "%(ccode_s, c_icode_s)) + try: + kasm.push(addr, c_icode_s, ccode_s) + except Exception as e: + self.__assert(False, 'Error when assembling instruction "%s":\n %s'%(nline, e)) + + ins_idx += 1 + + # rewrite text sections + codebytes = kasm.genCode() + section.seek(0) + section.emitBytes(codebytes) + + # update offsets in NVInfo + kname = section.name[6:] # strip '.text.' + info_sec = self.__mSectionDict['.nv.info.' + kname] + + if kname in self.__mNVInfoOffsetLabels: + offset_label_dict = self.__mNVInfoOffsetLabels[kname] + offset_label_dict.update(kasm.m_ExtraInfo) + else: + offset_label_dict = kasm.m_ExtraInfo.copy() + + nvinfo = CuNVInfo(info_sec.getData(), self.m_Arch) + nvinfo.updateNVInfoFromDict(offset_label_dict) + info_sec.setData(nvinfo.serialize()) + + @CuAsmLogger.logTraceIt + def __sortSections(self): + ''' Sort the sections. (TODO: Not implemented yet, all sections are kept as is.) + + Some section orders may do not matter, but the ELF segments may have some requirements ??? (TODO: checkit.) 
+ This is a sample layout of sections: + + Index Offset Size ES Align Type Flags Link Info Name + 1 40 2d9 0 1 STRTAB 0 0 0 .shstrtab + 2 319 416 0 1 STRTAB 0 0 0 .strtab + 3 730 2e8 18 8 SYMTAB 0 2 10 .symtab + 4 a18 2a0 0 1 PROGBITS 0 0 0 .debug_frame + 5 cb8 b4 0 4 CUDA_INFO 0 3 0 .nv.info + 6 d6c 6c 0 4 CUDA_INFO 0 3 17 .nv.info._Z4testPiS_S_ + 7 dd8 40 0 4 CUDA_INFO 0 3 1b .nv.info._Z5childPii + 8 e18 40 0 4 CUDA_INFO 0 3 1c .nv.info._Z5stestfPf + 9 e58 4 0 4 CUDA_INFO 0 3 1a .nv.info._Z2f3ii + a e5c 4 0 4 CUDA_INFO 0 3 18 .nv.info._Z2f1ii + b e60 4 0 4 CUDA_INFO 0 3 19 .nv.info._Z2f2ii + c e68 40 10 8 REL 0 3 14 .rel.nv.constant0._Z4testPiS_S_ + d ea8 50 10 8 REL 0 3 17 .rel.text._Z4testPiS_S_ + e ef8 60 18 8 RELA 0 3 17 .rela.text._Z4testPiS_S_ + f f58 20 10 8 REL 0 3 1b .rel.text._Z5childPii + 10 f78 30 10 8 REL 0 3 1d .rel.nv.global.init + 11 fa8 60 10 8 REL 0 3 4 .rel.debug_frame + 12 1008 118 0 4 PROGBITS 2 0 0 .nv.constant3 + 13 1120 8 0 8 PROGBITS 2 0 17 .nv.constant2._Z4testPiS_S_ + 14 1128 188 0 4 PROGBITS 2 0 17 .nv.constant0._Z4testPiS_S_ + 15 12b0 16c 0 4 PROGBITS 2 0 1b .nv.constant0._Z5childPii + 16 141c 170 0 4 PROGBITS 2 0 1c .nv.constant0._Z5stestfPf + 17 1600 900 0 80 PROGBITS 6 3 18000011 .text._Z4testPiS_S_ + 18 1f00 80 0 80 PROGBITS 6 3 18000012 .text._Z2f1ii + 19 1f80 200 0 80 PROGBITS 6 3 18000013 .text._Z2f2ii + 1a 2180 200 0 80 PROGBITS 6 3 18000014 .text._Z2f3ii + 1b 2380 180 0 80 PROGBITS 6 3 a000016 .text._Z5childPii + 1c 2500 100 0 80 PROGBITS 6 3 8000017 .text._Z5stestfPf + 1d 2600 24 0 8 PROGBITS 3 0 0 .nv.global.init + 1e 2624 40 0 4 NOBITS 3 0 0 .nv.global + ''' + + # TODO: + # section_weights = ['.shstrtab', '.strtab', '.symtab', '.debug_frame', '.nv.info'] + + pass + + @CuAsmLogger.logTraceIt + def __buildRelocationSections(self): + + relSecDict = defaultdict(lambda : []) + + for rel in self.__mRelList: + if rel.isRELA(): + sname = '.rela' + rel.section.name + else: + sname = '.rel' + rel.section.name + + # FIXME: insert 
REL/RELA sections if necessary + relSecDict[sname].append(rel) + + # CHECK: The order of rel entries probably does not matter + # But to reduce unmatchness w.r.t. original cubin + # The order is reversed as the official toolkit does. + for sname in relSecDict: + section = self.__mSectionDict[sname] + rellist = relSecDict[sname] + nrel = len(rellist) + for i in range(nrel): + rel = rellist.pop() # FIFO of list + section.emitBytes(rel.buildEntry()) + + @CuAsmLogger.logTraceIt + def __evalFixups(self): + for i,fixup in enumerate(self.__mFixupList): + try: + # check relocation + # Relocation rules for fixups (NOT include the text section): + # 1. dtype in dword/word + # 2. expr is non-literal (0x**) + # 3. expr not started with index@, no @srel present + # + # CHECK: what if "Symbol + label@srel ? " + # seems still a relocation, but the value is the label value instead of zero. + + expr = fixup.expr + + if fixup.dtype not in self.rel_dtypes or expr.startswith('index@'): + val, _ = self.__evalExpr(expr) + fixup.value = val + self.__updateSectionForFixup(fixup) + else: # + # TODO: check other types of relocations + + # Check relocations for texture/surface references + if fixup.dtype == 'word': + res = self.m_texrel.match(expr) + if res is not None: + symname = res.groups()[0] + relsymid = self.__getSymbolIdx(symname) + reltype = 'R_CUDA_TEX_HEADER_INDEX' + + rel = CuAsmRelocation(fixup.section, fixup.offset, symname, relsymid, reltype=reltype, reladd=None) + self.__mRelList.append(rel) + continue # go process next fixup + + res2 = self.m_sufrel.match(expr) + if res2 is not None: + symname = res2.groups()[0] + relsymid = self.__getSymbolIdx(symname) + reltype = 'R_CUDA_SURF_HEADER_INDEX' + + rel = CuAsmRelocation(fixup.section, fixup.offset, symname, relsymid, reltype=reltype, reladd=None) + self.__mRelList.append(rel) + continue # go process next fixup + + # check explicit types of relocations + # Example : fun@R_CUDA_G64(C1) + # Seems only appear in debug version? 
+ p_rel = re.compile(r'fun@(\w+)\(([^\)])\)') + res_rel = p_rel.match(expr) + if res_rel: + reltype = res_rel.groups()[0] + symname = res_rel.groups()[1] + symidx = self.__getSymbolIdx(symname) + + rel = CuAsmRelocation(fixup.section, fixup.offset, symname, symidx, reltype=reltype, reladd=None) + self.__mRelList.append(rel) + + continue + + # check other types of relocations + val, vs = self.__evalExpr(expr) + if isinstance(vs[0], str): # symbol name in vs[0] + symname = vs[0] + relsymid = self.__getSymbolIdx(symname) # index of symbol + if fixup.dtype=='word': + reltype='R_CUDA_32' + elif fixup.dtype=='dword': + reltype='R_CUDA_64' + else: + self.__assert(False, 'Unknown data type for relocation: %s'%fixup.dtype) + + rel = CuAsmRelocation(fixup.section, fixup.offset, symname, relsymid, reltype=reltype, reladd=None) + self.__mRelList.append(rel) + + if val is not None: # symbol + label@srel, seems the label value is filled. + fixup.value = val + self.__updateSectionForFixup(fixup) + + except Exception as e: + self.__assert(False, 'Error when evaluating fixup @line%d: expr=%s, msg=%s' + %(fixup.lineno, fixup.expr, e)) + + @CuAsmLogger.logTraceIt + def __updateSymtab(self): + + bio = BytesIO(self.__mSectionDict['.symtab'].getData()) + symsize = Config.CubinELFStructs.Elf_Sym.sizeof() + + for i, s in enumerate(self.__mSymtabDict): + symid, syment = self.__mSymtabDict[s] + + if s in self.__mLabelDict: + syment['st_value'] = self.__mLabelDict[s].offset + + if s in self.__mSymbolDict: # symbols explicitly defined in assembly + symobj = self.__mSymbolDict[s] + symobj.value = self.__mLabelDict[s].offset + symobj.sizeval, _ = self.__evalExpr(symobj.size) + + syment['st_size'] = symobj.sizeval + + # print(syment) + CuAsmSymbol.resetSymtabEntryValueSize(bio, i*symsize, symobj.value, symobj.sizeval) + + else: # some symbols does not have corresponding labels, such as vprintf + pass + self.__mSectionDict['.symtab'].setData(bio.getvalue()) + + @CuAsmLogger.logTraceIt + def 
__layoutSections(self): + ''' Layout section data, do section padding if needed. Update section header.offset/size. + + Update segment range accordingly. + Update ELF file header accordingly. + ''' + + # initialize the offset as the ELF header size + elfheadersize = Config.CubinELFStructs.Elf_Ehdr.sizeof() + file_offset = elfheadersize + mem_offset = elfheadersize + prev_sec = None + + sh_edges = {} # key=secname, value = (file_start, file_end, mem_start, mem_end) + # First pass to get the size of every section + # NOTE: the size of current section depends the padding, which is determined by next section + # Seems only for text section? For other sections, padding will not count in size? + for secname, sec in self.__mSectionDict.items(): + if secname == '': + continue + + # print(secname) + align = sec.addralign + if prev_sec is not None and prev_sec.name.startswith('.text'): + align = 128 + file_offset, mem_offset = self.__updateSectionPadding(prev_sec, file_offset, mem_offset, align) + + sec.size = sec.getDataSize() + sec.offset = file_offset + + sec.header['size'] = sec.size + sec.header['offset'] = sec.offset + + prev_sec = sec + sh_edges[secname] = (file_offset, 0, mem_offset, 0) + + mem_offset += sec.size + if sec.header['type'] != 'SHT_NOBITS': + file_offset += sec.size + + # ??? + if prev_sec is not None and prev_sec.name.startswith('.text'): + file_offset, mem_offset = self.__updateSectionPadding(prev_sec, file_offset, mem_offset, 128) + + # Section pass to build the section edges, for locating segment range + for secname, sec in self.__mSectionDict.items(): + if secname == '': + continue + + sec.size = sec.getDataSize() + sec.header['size'] = sec.size + + if sec.header['type'] != 'SHT_NOBITS': + fsize = sec.size + msize = fsize + else: + fsize = 0 + msize = sec.size + + file_pos, _, mem_pos, _ = sh_edges[secname] + sh_edges[secname] = (file_pos, file_pos + fsize, mem_pos, mem_pos + msize) + + # FIXME: better alignment for headers ? 
+ file_offset, self.__mPadSizeBeforeSecHeader = alignTo(file_offset, 8) + + # Current only the normal order is support: + # ELFHeader -> SectionData -> SectionHeader -> SegmentHeader + # Other orders may be possible, but not supported yet. + + SecHeaderLen = len(self.__mSectionDict) * Config.CubinELFStructs.Elf_Shdr.sizeof() + + self.__mCuAsmFile.fileHeader['shoff'] = file_offset + + phoff = file_offset + SecHeaderLen + phlen = self.__mCuAsmFile.fileHeader['phentsize'] * self.__mCuAsmFile.fileHeader['phnum'] + self.__mCuAsmFile.fileHeader['phoff'] = phoff + + sh_edges[PROGRAM_HEADER_TAG] = phoff, phoff+phlen, phoff, phoff+phlen + + for seg in self.__mSegmentList: + if seg.header['type'] == 'PT_PHDR': + seg.header['offset'] = file_offset + SecHeaderLen + seg.header['filesz'] = Config.CubinELFStructs.Elf_Phdr.sizeof() * len(self.__mSegmentList) + seg.header['memsz'] = seg.header['filesz'] + + elif seg.header['type'] == 'PT_LOAD': + # if startsection is empty, this segment is empty + # Seems a convention of compiler? 
+ if seg.header['startsection'] != '' and seg.header['endsection'] != '': + file_start0, file_end0, mem_start0, mem_end0 = sh_edges[seg.header['startsection']] + file_start1, file_end1, mem_start1, mem_end1 = sh_edges[seg.header['endsection']] + + seg.header['offset'] = file_start0 + seg.header['filesz'] = file_end1 - file_start0 + seg.header['memsz'] = mem_end1 - mem_start0 + + else: + msg = 'Unknown segment type %s!'%seg.header['type'] + CuAsmLogger.logError(msg) + raise Exception(msg) + + # update header + seg.updateHeader() + +#### Directives + def __dir_headerflags(self, args): + self.__assertArgc('.headerflags', args, 1, allowMore=False) + self.__mCuAsmFile.headerflags = args[0] + + def __dir_elftype(self, args): + self.__assertArgc('.elftype', args, 1, allowMore=False) + self.__mCuAsmFile.elftype = args[0] + + def __dir_section(self, args): + self.__assertArgc('.section', args, 3, allowMore=False) + + # for implict sections, quotes are used for embracing the section name + # mainly for the NULL section with empty name "" + # thus the quotes will be stripped + secname = args[0].strip('"') + + self.__assert(secname not in self.__mSectionDict, 'Redefinition of section "%s"!'%secname) + self.__mCurrSection = CuAsmSection(secname, args[1], args[2]) + + CuAsmLogger.logSubroutine('Line %6d: New section "%s"'%(self.__mLineNo, secname)) + + self.__mSectionDict[secname] = self.__mCurrSection + + if args[0].startswith('.text.'): + self.__mInTextSection = True + self.__mInsIndex = 0 + else: + self.__mInTextSection = False + + def __dir_sectionflags(self, args): + self.__assertArgc('.sectionflags', args, 1, allowMore=False) + self.__mCurrSection.flags.append(args[0]) + + def __dir_sectionentsize(self, args): + self.__assertArgc('.sectionentsize', args, 1, allowMore=False) + self.__mCurrSection.entsize = int(args[0]) + + def __dir_sectioninfo(self, args): + self.__assertArgc('.sectioninfo', args, 1, allowMore=False) + self.__assert(self.__mCurrSection is not None, "No 
active section!") + + # TODO: parse info, check correctness + self.__mCurrSection.info.append(args[0]) + + def __dir_byte(self, args): + self.__assertArgc('.word', args, 1, allowMore=True) + self.__emitTypedBytes('byte', args) + + def __dir_dword(self, args): + ''' currently 1 dword = 8 bytes + + NOTE: .dword may reference a relocation symbol. + ''' + + self.__assertArgc('.dword', args, 1, allowMore=True) + self.__emitTypedBytes('dword', args) + + def __dir_align(self, args): + ''' .align directive may have different operations, depending on the context. + + Usually .align will pad current buffer with zeros/nops to required alignment. + But for the first .align directive of a section, it also sets the alignment + requirement of current section, which means the padding is done to last + section, thus will not affect the local offset of current section. + + For `.align` inside a section, the padding counts to the local offset, + thus will affect all the local fixup values. + ''' + + self.__assertArgc('.align', args, 1, allowMore=False) + try: + align = int(args[0]) + except: + self.__assert(False, ' unknown alignment (%s)!' % args[0]) + + self.__assert(align &(align-1) == 0, ' alignment(%d) should be power of 2!' % align) + self.__mCurrSection.emitAlign(align) + + def __dir_short(self, args): + self.__assertArgc('.short', args, 1, allowMore=True) + self.__emitTypedBytes('short', args) + + def __dir_word(self, args): + self.__assertArgc('.word', args, 1, allowMore=True) + self.__emitTypedBytes('word', args) + + def __dir_type(self, args): + ''' .type will define the symbol type. + + Example: .type flist ,@object + .type $str ,@object + .type vprintf,@function + ''' + + self.__assertArgc('.type', args, 2, allowMore=False) + symbol = args[0] + if symbol not in self.__mSymbolDict: + self.__mSymbolDict[symbol] = CuAsmSymbol(symbol) + + stype = args[1] + self.__assert(stype in CuAsmSymbol.SymbolTypes, + 'Unknown symbol type %s! 
Available: %s.'%(stype, str(CuAsmSymbol.SymbolTypes))) + self.__mSymbolDict[symbol].type = stype + + def __dir_size(self, args): + self.__assertArgc('.size', args, 2, allowMore=False) + symbol = args[0] + if symbol not in self.__mSymbolDict: + self.__mSymbolDict[symbol] = CuAsmSymbol(symbol) + + # NOTE: the size of a symbol is probably an expression + # this will be evaluted when generating symbol tables + self.__mSymbolDict[symbol].size = args[1] + + def __dir_global(self, args): + '''.global defines a global symbol. + + A global symbol is visible to linker. For a cubin, it can be accessed by + the driver api function `cuModuleGetGlobal`. + ''' + + self.__assertArgc('.global', args, 1, allowMore=False) + + symbol = args[0] + if symbol not in self.__mSymbolDict: + self.__mSymbolDict[symbol] = CuAsmSymbol(symbol) + + CuAsmLogger.logSubroutine('Line %6d global symbol %s'%(self.__mLineNo, symbol)) + + self.__mSymbolDict[symbol].isGlobal = True + + def __dir_weak(self, args): + '''.weak defines a weak symbol. + + A weak symbol is declared in current module, but may be overwritten by strong symbols. + + Currently no scope is implemented, thus + ''' + + self.__assertArgc('.weak', args, 1, allowMore=False) + + symbol = args[0] + if symbol not in self.__mSymbolDict: + self.__mSymbolDict[symbol] = CuAsmSymbol(symbol) + + CuAsmLogger.logWarning('Line %d: Weak symbol found! 
The implementation is not complete, please be cautious...'%self.__mLineNo) + CuAsmLogger.logSubroutine('Line %6d: New weak symbol "%s"'%(self.__mLineNo, symbol)) + + self.__mSymbolDict[symbol].isGlobal = True + + def __dir_zero(self, args): + '''.zero emit zeros of specified length (in bytes).''' + + self.__assertArgc('.zero', args, 1, allowMore=False) + try: + # .zero only accepts a literal, no fixup allowed + size = int(args[0]) + self.__emitBytes(b'\x00'*size) + except: + self.__assert(False, 'Unknown arg (%s) for .zero!'% args[0]) + + def __dir_other(self, args): + '''.other defines some properties of a symbol. + + Examples: + .other _Z4testPiS_S_, @"STO_CUDA_ENTRY STV_DEFAULT" + .other _Z5childPii , @"STO_CUDA_ENTRY STV_DEFAULT" + .other _Z5stestfPf , @"STO_CUDA_ENTRY STV_DEFAULT" + ''' + self.__assertArgc('.other', args, 2, allowMore=False) + + symbol = args[0] + if symbol not in self.__mSymbolDict: + #self.__mSymbolDict[symbol] = CuAsmSymbol() + self.__assert(False, 'Undefined symbol %s!!!'%symbol) + + self.__mSymbolDict[symbol].other = args[1] + + def __dir_elfheader(self, attrname, args): + self.__assertArgc('.__elf_'+attrname, args, 1, allowMore=False) + self.__mCuAsmFile.fileHeader[attrname] = self.__cvtValue(args[0]) + if attrname == 'flags': + flags = int(args[0], 16) + smversion = flags & 0xff + self.m_Arch = CuSMVersion(smversion) + + if (not hasattr(self, '__mCuInsAsmRepos') + or self.__mCuInsAsmRepos is None + or (self.__mCuInsAsmRepos.getSMVersion() != self.m_Arch) ): + + CuAsmLogger.logSubroutine('Setting CuInsAsmRepos to default dict...') + + self.__mCuInsAsmRepos = CuInsAssemblerRepos(arch=self.m_Arch) + self.__mCuInsAsmRepos.setToDefaultInsAsmDict() + + def __dir_sectionheader(self, attrname, args): + self.__assertArgc('.__section_'+attrname, args, 1, allowMore=False) + self.__mCurrSection.header[attrname] = self.__cvtValue(args[0]) + + def __dir_segment(self, args): + self.__assertArgc('.__segment', args, 2, allowMore=False) + segment = 
CuAsmSegment(args[0].strip('"'), args[1]) + self.__mSegmentList.append(segment) + self.__mCurrSegment = segment + self.__mCurrSection = None + + def __dir_segmentheader(self, attrname, args): + self.__assertArgc('.__segment_'+attrname, args, 1, allowMore=False) + self.__mCurrSegment.header[attrname] = self.__cvtValue(args[0]) + +#### Subroutines + def __assert(self, flag, msg=''): + if not flag: + full_msg = 'Assertion failed in:\n' + full_msg += f' File {self.__mFilename}:{self.__mLineNo} :\n' + full_msg += f' {self.__mLines[self.__mLineNo-1].strip()}\n' + full_msg += f' {msg}' + CuAsmLogger.logError(full_msg) + raise Exception(full_msg) + + def __assertArgc(self, cmd, args, argc, allowMore=True): + ''' Check the number of arguments.''' + if allowMore: + flag = len(args)>=argc + es = 'at least ' + else: + flag = len(args)==argc + es = '' + + self.__assert(flag, '%s requires %s%d args! %d given: %s.' + %(cmd, es, argc, len(args), str(args)) ) + + def __tellLocal(self): + ''' tell current pos inside current active section.''' + + if self.__mCurrSection is not None: + return self.__mCurrSection.tell() + else: + raise Exception("Cannot tell local pos without active section!") + + def __evalVar(self, var): + """Evaluate a single variable + + Args: + var ([string]): the variable expression + + Returns: + (value, is_sym) + """ + + # symbol + if var in self.__mSymtabDict: + is_sym = True + else: + is_sym = False + + # int literal + if m_intval.match(var): + return eval(var), is_sym + + if var.endswith('@srel'): + label = var.replace('@srel', '') + if label not in self.__mLabelDict: + raise Exception('Unknown expression %s'%var) + + return self.__mLabelDict[label].offset, is_sym + + if var in self.__mLabelDict: + return self.__mLabelDict[var].offset, is_sym + + raise Exception('Unknown expression %s'%var) + + def __evalExpr(self, expr): + ''' Evaluate the expression. + + value = value_a ((+|-) value_b)? 
        Return: Tuple(value, Tuple(value_a, op, value_b) )

        For symbol at position a, the original symbol string will be returned as value a.

        Examples:
            Expr              Value          Section
            index@(symbol)    symbol index   non-text
            (.Label)          label offset
            (.L0-.L1)

        NOTE: This subroutine has no context info, making it hard to interprete
              thus all exceptions should be captured in __evalFixups, showing the full context
        '''
        # NOTE(review): formatting reconstructed from a mangled diff chunk; tokens kept as shown.

        # For expr: index@(symbol)
        if expr.startswith('index@'): # index of symbol
            symname = expr[6:].strip(' ()')
            index = self.__getSymbolIdx(symname)
            if index is None:
                raise Exception('Unknown symbol "%s"!!!'%symname)
            return index, (index, None, None)

        rexpr = expr.strip('`() ')
        # Match "a", "a+b" or "a-b"; operands may contain . \w $ @ characters.
        res = re.match(r'([.\w$@]+)\s*(\+|-)*\s*([.\w$@]+)*', rexpr) # FIXME: what if the imme is negative???

        if res is None:
            raise Exception('Unknown expr %s !!!'%expr)
        else:
            a = res.groups()[0]
            op = res.groups()[1]
            b = res.groups()[2]

        aval, a_issym = self.__evalVar(a)

        if op is None: # only one var
            if a_issym: # one symbol, definitely a relocation
                return aval, (a , None, None)
            else: # one label
                return aval, (aval, None, None)
        else:
            bval, b_issym = self.__evalVar(b) # in general context, the second var should not be symbol?
                                              # but it's possible in size expression

            # For a symbol on the left, report the symbol *name*; for a label, its value.
            if a_issym:
                a_realval = a
            else:
                a_realval = aval

            if op == '+':
                return aval + bval, (a_realval, '+', bval)
            elif op == '-':
                return aval - bval, (a_realval, '-', bval)
            else: # never reach here, only +/- can be matched by re pattern.
                raise Exception('Unknown expr.op "%s"'%op)

    def __getSymbolIdx(self, symname):
        ''' Get symbol index in symtab, or None when the name is unknown. '''
        if symname in self.__mSymtabDict:
            return self.__mSymtabDict[symname][0]
        else:
            return None

    def __evalInstructionFixup(self, section, offset, s):
        ''' Check fixups inside an instruction.

        Examples:
            RET.REL.NODEC R20 `(_Z4testPiS_S_);
            BRA `(.L_14);
        Relocations:
            32@hi($str) => REL
            32@lo((_Z4testPiS_S_ + .L_8@srel)) => RELA
            `(vprintf) => REL

        TODO: How to determine the type of `(.LABEL) ???
              For symbol or label defined in the same section, it's a fixup
              Otherwise, it seems a relocation. (To be checked...)
        '''
        # 32-bit hi/lo relocations such as 32@hi(expr) / 32@lo(expr).
        p_ins_rel32 = re.compile(r'(32@hi|32@lo)\(([^\)]+)\)+')
        r1 = p_ins_rel32.search(s)
        if r1:
            expr = r1.groups()[1]
            val, val_sep = self.__evalExpr(expr)
            symname = val_sep[0]
            symidx = self.__getSymbolIdx(val_sep[0])
            relkey = r1.groups()[0]
            reltype = self.m_Arch.getInsRelocationType(relkey)

            # val_sep[1] (the op) is not None => expression with addend => RELA, else REL.
            if val_sep[1] is not None:
                rela = CuAsmRelocation(section, offset, symname, symidx, reltype=reltype, reladd=val_sep[2])
                self.__mRelList.append(rela)
            else:
                rel = CuAsmRelocation(section, offset, symname, symidx, reltype=reltype, reladd=None)
                self.__mRelList.append(rel)

            # Replace the relocated operand with a 0x0 placeholder in the instruction text.
            ns = p_ins_rel32.sub('0x0', s)
            return ns

        # Branch/call style targets: `(label_or_symbol)
        p_ins_label = re.compile(r'`\(([^\)]+)\)')
        r2 = p_ins_label.search(s)
        if r2:
            # print(s)
            label = r2.groups()[0]
            self.__assert((label in self.__mLabelDict) or (label in self.__mSymtabDict),
                          'Unknown label (%s) !!!'%label)

            # global symbols, no corresponding label (such as vprintf)
            if (label not in self.__mLabelDict) and (label in self.__mSymtabDict):
                # print(s)
                symname = label
                symidx = self.__getSymbolIdx(symname)
                reltype = self.m_Arch.getInsRelocationType('target')
                rel = CuAsmRelocation(section, offset, symname, symidx, reltype=reltype, reladd=None)
                self.__mRelList.append(rel)
                ns = p_ins_label.sub('0x0', s)
                return ns

            clabel = self.__mLabelDict[label]
            if section.name == clabel.section.name: # hardcoded target in current section
                val = clabel.offset
                ns = p_ins_label.sub('%#x'%val, s)
                return ns
            else: # relocations, since the target is in another section
                symname = label
                symidx = self.__getSymbolIdx(symname)
                reltype = self.m_Arch.getInsRelocationType('target')
                rel = CuAsmRelocation(section, offset, symname, symidx, reltype=reltype, reladd=None)
                self.__mRelList.append(rel)
                ns = p_ins_label.sub('0x0', s)
                return ns

        # No fixup patterns found
        return s

    def __updateSectionForFixup(self, fixup):
        ''' Update the corresponding section location for fixup.'''

        _, blen = self.dtype_pattern[fixup.dtype]
        bs = int.to_bytes(fixup.value, blen, 'little')
        fixup.section.updateForFixup(fixup.offset, bs)

        CuAsmLogger.logSubroutine('Eval fixup "%s" @line%d to %#x'%(fixup.expr, fixup.lineno, fixup.value))
        # print(fixup)

    def __emitBytes(self, bs):
        '''emit bytes to current section.'''
        self.__mCurrSection.emitBytes(bs)

    def __getLineType(self, line):
        '''There can be three line types:

        1. Directive: starts with ".\\w+", but no following ":"
        2. Label: label name followed by ":"
        3. Instruction text: only in section with name prefix ".text",
           and not a label line
        (4. Blank lines, skipped)

        **NOTE**: usually all blanks lines will be skipped by the parser
        '''

        if len(line)==0:
            return 'blank'
        elif self.m_label.match(line) is not None:
            return 'label'
        elif self.m_directive.match(line) is not None:
            return 'directive'
        elif self.__mInTextSection:
            return 'code'
        else:
            return None
            #raise Exception("Unrecognized line contents!")

    def __emitTypedBytes(self, dtype, args):
        # Emit each arg as dsize little-endian bytes; non-literal args become fixups.
        dp, dsize = self.dtype_pattern[dtype]

        for arg in args:
            # TODO: check contents of arg is really a fixup/relocation(may not defined yet!) ?
            #if dp.match(arg):
            #    self.__emitBytes(bytes.fromhex(arg[2:]))
            if arg.startswith('0x'):
                argv = int(arg, 16)
                arg_byte = argv.to_bytes(dsize, 'little')
                self.__emitBytes(arg_byte)
            else:
                # NOTE: currently all unknowns go to fixup list,
                #       fixup will handle the relocations if needed.

                # all fixup values will be updated by the assembler
                fixup = CuAsmFixup(self.__mCurrSection, self.__tellLocal(),
                                   arg, dtype, self.__mLineNo)

                self.__mFixupList.append(fixup)

                # emit zeros as placeholder
                self.__emitBytes(b'\x00'*dsize)

    def __cvtValue(self, s):
        ''' Convert input string to int if possible.'''
        # NOTE(review): eval() on assembly text — assumes input files are trusted; confirm.
        if m_intval.match(s):
            return eval(s)
        elif s.startswith('"') and s.endswith('"'):
            return s.strip('"')
        else:
            return s

    def __pushSectionSizeLabel(self):
        '''Identify the last label that marks the end of a text section.

        DEPRECATED !!!

        The text section size label will be gathered in the procedure __gatherTextSectionSizeLabel()
        '''
        if self.__mCurrSection is not None and self.__mCurrSection.name.startswith('.text') and self.__mLabelDict is not None:
            key, lastlabel = self.__mLabelDict.popitem()
            if self.__mCurrSection.name == lastlabel.section.name and lastlabel.offset == self.__mCurrSection.tell():
                self.__mSecSizeLabel[self.__mCurrSection.name] = lastlabel
                self.__mLabelDict[key] = lastlabel # push it back
            else:
                self.__mLabelDict[key] = lastlabel # push it back

    def __genSectionPaddingBytes(self, sec, size):
        '''Generate padding bytes for section with given size.'''
        if sec.name.startswith('.text'):
            padbytes = self.m_Arch.getPadBytes()
        else:
            padbytes = b'\x00'

        if size % len(padbytes) != 0:
            raise Exception('Invalid padding size for section %s'%sec.name)

        npad = size // len(padbytes)
        return npad * padbytes

    def __updateSectionPadding(self, sec, file_offset, mem_offset, align):
        ''' Update section padding with size.

        For text sections: padding to the original section data, update size
        For other sections: padding to seperate padbytes, keep size unchanged
        For nobits sections: do nothing.
        '''
        if sec is None:
            return file_offset, mem_offset

        if sec.name.startswith('.text'):
            align = max(align, sec.addralign)
            file_offset, fpadsize = alignTo(file_offset, align)
            mem_offset, mpadsize = alignTo(mem_offset, align)

            sec.emitPadding(self.__genSectionPaddingBytes(sec, fpadsize))

            # FIXME: This treatment is weird, but the text sections seems always aligned
            #        and last label of .text section seems to be the padded offset.
            #
            #        Update size label offset, it will be used in symbol size evaluation.
            #        I don't quite understand why it's this way, but let's just keep it as is.
            if sec.name in self.__mSecSizeLabel:
                sizelabel = self.__mSecSizeLabel[sec.name]
                # NOTE: donot use sec.size here
                sizelabel.offset = sec.getDataSize()
                CuAsmLogger.logSubroutine(f'Reset size label "{sizelabel.name}" of {sec.name} to {sec.getDataSize()}!')

        elif sec.header['type'] == 'SHT_NOBITS':
            # NOBITS (.bss-like) sections occupy memory but no file bytes.
            mem_offset, mpadsize = alignTo(mem_offset, align)
            sec.padsize = mpadsize
            sec.padbytes = mpadsize * b'\x00'
        else:
            file_offset, fpadsize = alignTo(file_offset, align)
            mem_offset, mpadsize = alignTo(mem_offset, align)

            sec.padsize = fpadsize
            sec.padbytes = fpadsize * b'\x00'

        sec.updateHeader()

        return file_offset, mem_offset

    def __calcSegmentRange(self, sec_start, sec_end):
        # Compute (offset, filesz, memsz) of the segment spanning sections
        # sec_start..sec_end (inclusive, in section-dict order).
        inRange = False
        seg_off = 0
        filesz = 0
        memsz = 0

        for sname, sec in self.__mSectionDict.items():
            if sname == sec_start:
                inRange = True
                seg_off = sec.offset
                f_off = seg_off
                m_off = seg_off

            if inRange:
                psize = sec.getPaddedDataSize()
                m_off += psize
                # NOBITS sections contribute to memory size only.
                if sec.header['type'] != 'SHT_NOBITS':
                    f_off += psize

                if sname == sec_end:
                    inRange = False
                    break

        filesz = f_off - seg_off
        memsz = m_off - seg_off

        return seg_off, filesz, memsz

    def __checkNVInfoOffsetLabels(self, section, labelname, offset):
        ''' Check whether the label is a NVInfoOffsetLabel, push to label offset dict if necessary.

        Valid offset label should be in form:
            .CUASM_OFFSET_LABEL.{SectionName}.{NVInfoAttributeName}.{Identifier}

        Identifier should be unique for every offset label (label cannot be defined twice).
        (A grammar sugar is to use "#", which will be replaced by "L+{LineNo}" such as "L000002f8"

        Example:
            .CUASM_OFFSET_LABEL._Z4testPiS_S_.EIATTR_EXIT_INSTR_OFFSETS.0:
            .CUASM_OFFSET_LABEL._Z4testPiS_S_.EIATTR_EXIT_INSTR_OFFSETS.#:

        Return: real label name

        '''

        # TODO: some offset labels (such as EXIT, CTAID.Z) may be detected automatically

        if not labelname.startswith('.CUASM_OFFSET_LABEL'):
            return labelname

        self.__assert(section.name.startswith('.text'), 'CUASM_OFFSET_LABEL should be defined in a text section!')

        # kernel name: strip the ".text." prefix from the section name
        kname = section.name[6:]
        vs = labelname[1:].split('.')
        self.__assert(len(vs)==4, 'Offset label should be in form: .CUASM_OFFSET_LABEL.{SectionName}.{NVInfoAttributeName}.{Identifier}')
        self.__assert(vs[1] == kname, 'CUASM_OFFSET_LABEL should include kernel name in second dot part!')

        if kname not in self.__mNVInfoOffsetLabels:
            self.__mNVInfoOffsetLabels[kname] = {}

        # .CUASM_OFFSET_LABEL._Z4testPiS_S_.EIATTR_EXIT_INSTR_OFFSETS.0:
        attr = vs[2]
        if attr in self.__mNVInfoOffsetLabels[kname]:
            self.__mNVInfoOffsetLabels[kname][attr].append(offset)
        else:
            self.__mNVInfoOffsetLabels[kname][attr] = [offset]

        # "#" sugar: synthesize a unique identifier from the current line number.
        if vs[3] == '#':
            lstr = 'L%08x'%self.__mLineNo
            return labelname[:-1] + lstr
        else:
            return labelname

#### Help functions to display some internal states.
    def dispFixupList(self):
        ''' Print all pending fixups, one per line. '''
        print('Fixup list:')
        # NOTE(review): if __mFixupList is None, the enumerate() below raises TypeError — confirm intended.
        if self.__mFixupList is None or len(self.__mFixupList)==0:
            print('    ' + str(self.__mFixupList))

        for i,f in enumerate(self.__mFixupList):
            print("Fixup %3d: %s"%(i, str(f)))

        print()

    def dispRelocationList(self):
        ''' Print all gathered relocations, one per line. '''
        print('Relocation list:')
        if self.__mRelList is None or len(self.__mRelList)==0:
            print('    No relocations.')

        for i,r in enumerate(self.__mRelList):
            print('Relocation %3d: %s'%(i, r))
        print()

    def dispSectionList(self):
        ''' Print a readelf-like table of all sections. '''
        print('Section list:')
        sdict = self.__mSectionDict
        if sdict is None or len(sdict) == 0:
            print('    No sections found.')
            return

        print('  Idx  Offset  Size  ES  AL  Type  Flags  Link  Info  Name')
        i = 0
        for s in sdict:
            sec = sdict[s]
            ss = '%4x' % i
            ss += ' {offset:6x} {size:6x} {entsize:4x}'.format(**sec.header)
            ss += ' {:3x}'.format(sec.addralign)
            # type may be a known string (e.g. 'SHT_NOBITS') or a raw int
            if isinstance(sec.header['type'], str):
                ss += ' {type:12s}'.format(**sec.header)
            else:
                ss += ' {type:<12x}'.format(**sec.header)

            ss += ' {flags:6x}'.format(**sec.header)
            ss += ' {link:6x} {info:8x}'.format(**sec.header)
            ss += ' ' + sec.name
            print(ss)

            i += 1

        print()

    def dispSymbolDict(self):
        ''' Print all parsed symbols. '''
        print('\nSymbols:')
        for i,s in enumerate(self.__mSymbolDict):
            symbol = self.__mSymbolDict[s]
            print('Symbol %3d: %s'%(i,symbol))
        print()

    def dispSymtabDict(self):
        ''' Print the symtab entries, with the parsed symbol (if any) below each. '''
        print('\nSymtab:')
        for s in self.__mSymtabDict:
            symid, syment = self.__mSymtabDict[s]
            print('Symbol %3d (%s): %s'%(symid, s, syment))
            if s in self.__mSymbolDict:
                print('    %s'%self.__mSymbolDict[s])
        print()

    def dispLabelDict(self):
        ''' Print all labels. '''
        print('\nLabels: ')
        for i,l in enumerate(self.__mLabelDict):
            v = self.__mLabelDict[l]
            print('Label %3d: %s'%(i, str(v)))
        print()

    def dispSegmentHeader(self):
        ''' Print all segment (program) headers. '''
        print('Segment headers:')
        for seg in self.__mSegmentList:
            print(seg.header)

    def dispFileHeader(self):
        ''' Print the ELF file header. '''
        print('File header:')
        print(self.__mCuAsmFile.fileHeader)

    def dispTables(self):
        ''' Print the string/symbol tables (.shstrtab/.strtab/.symtab). '''
        # self.buildInternalTables()
        print('.shstrtab:')
        for i, idx in enumerate(self.__mShstrtabDict):
            print('%3d \t0x%x \t%s'%(i, idx, self.__mShstrtabDict[idx]))

        print('.strtab:')
        for i, idx in enumerate(self.__mStrtabDict):
            print('%3d \t0x%x \t%s'%(i, idx, self.__mStrtabDict[idx]))

        print('.symtab')
        for i, s in enumerate(self.__mSymtabDict):
            print('%3d \t%s'%(i, s))

    @CuAsmLogger.logTimeIt
    def saveCubinCmp(self, cubinname, sav_prefix):
        ''' A simple helper function to display current contents vs cubin in bytes. '''
        # Writes two text dumps: {sav_prefix}_asm.txt (this parser's state) and
        # {sav_prefix}_bin.txt (the reference cubin), for side-by-side diffing.

        fasm = open(sav_prefix+'_asm.txt', 'w')
        fbin = open(sav_prefix+'_bin.txt', 'w')

        felf = open(cubinname, 'rb')
        ef = ELFFile(felf)

        fasm.write('FileHeader:\n' + str(self.__mCuAsmFile.getFileHeaderStruct()) + '\n')
        fbin.write('FileHeader:\n' + str(ef.header) + '\n' )

        # write section headers+data
        for sname,sec in self.__mSectionDict.items():
            fasm.write('# Section %s\n'%sname)
            fasm.write(str(sec.getHeaderStruct()) + '\n')
            if sec.getHeaderStruct()['sh_type'] != 'SHT_NOBITS':
                fasm.write(bytes2Asm(sec.getData()) +'\n\n')
            else:
                fasm.write('\n')

        # write segment headers
        for seg in self.__mSegmentList:
            fasm.write(str(seg.getHeaderStruct())+'\n')

        # write section headers+data
        for sec in ef.iter_sections():
            fbin.write('# Section %s\n'%sec.name)
            fbin.write(str(sec.header) + '\n')
            if sec.header['sh_type'] != 'SHT_NOBITS':
                fbin.write(bytes2Asm(sec.data()) + '\n\n')
            else:
                fbin.write('\n')

        # write segment headers
        for seg in ef.iter_segments():
            fbin.write(str(seg.header) + '\n')

        fasm.close()
        fbin.close()
        felf.close()

    @staticmethod
    def stripComments(s):
        ''' Strip comments of a line.

        NOTE: cross line comments are not supported yet.
        '''

        s = CuAsmParser.m_cppcomment.subn(' ', s)[0] # replace comments as a single space, avoid unwanted concatination
        s = CuAsmParser.m_ccomment.subn(' ', s)[0]
        s = CuAsmParser.m_bracomment.subn(' ', s)[0]
        s = re.subn(r'\s+', ' ', s)[0] # replace one or more spaces/tabs into one single space

        return s.strip()

if __name__ == '__main__':
    pass
# -*- coding: utf-8 -*-
import re

# Pattern for control codes string, e.g. "B--2---:R0:W1:-:S07".
# ChangeLog 20220915: remove reuse field
c_ControlCodesPattern = re.compile(r'B(0|-)(1|-)(2|-)(3|-)(4|-)(5|-):R[0-5\-]:W[0-5\-]:(Y|-):S\d{2}')
c_ControlStringLen = 19

class CuControlCode:
    ''' Decoded view of one instruction's scheduling control code.

    Fields (see splitCode for the bit layout):
        Barrier : 6-bit wait-barrier mask (bit i set => wait on scoreboard i)
        Read    : read dependency scoreboard id, 7 means unset
        Write   : write dependency scoreboard id, 7 means unset
        Yield   : raw yield bit; 0 means the yield flag IS set
        Stall   : stall count (0..15)
    '''

    def __init__(self, code):
        c_waitbar, c_readbar, c_writebar, c_yield, c_stall = CuControlCode.splitCode(code)

        self.Barrier = c_waitbar
        self.Read = c_readbar
        self.Write = c_writebar
        self.Yield = c_yield
        self.Stall = c_stall

    def isYield(self):
        ''' If yield flag is set(code=0).'''
        return self.Yield == 0

    def getStallCount(self):
        ''' Get stall count.'''
        return self.Stall

    def getReadSB(self):
        ''' Get read scoreboard id, return None if not set.'''
        return None if self.Read == 7 else self.Read

    def getWriteSB(self):
        ''' Get write scoreboard id, return None if not set.'''
        return None if self.Write == 7 else self.Write

    def getBarrierSet(self):
        ''' Get a set of waiting scoreboards, return empty set if waiting on none.'''
        # FIX: the bit test was garbled ("(1<0"); test each of the six barrier
        # bits, consistent with the (c_waitbar & (2**i)) test in decode().
        return {i for i in range(6) if (self.Barrier & (1 << i)) > 0}

    @staticmethod
    def splitCode(code):
        ''' Split control codes into parts.

        # c.f. : https://github.com/NervanaSystems/maxas/wiki/Control-Codes
        #      : https://arxiv.org/abs/1903.07486
        #   reuse  waitbar  rbar  wbar  yield  stall
        #    0000   000000   000   000      0   0000
        #
        # NOTE : It is known that for some instructions(HMMA.SP), reuse may use some other bits.
        #        And for some other instructions(TEXS/BRA/...), the reuse bits may be used for other encoding fields.
        #        Since reuses are displayed as explicit modifiers, we will not split the reuse field any more.
        #        Other fields will be extracted and encoded as control codes.
        # TODO : Maybe we can treat those control fields as normal modifier?
        '''

        c_stall    = (code & 0x0000f) >> 0
        c_yield    = (code & 0x00010) >> 4
        c_writebar = (code & 0x000e0) >> 5   # write dependency barrier
        c_readbar  = (code & 0x00700) >> 8   # read dependency barrier
        c_waitbar  = (code & 0x1f800) >> 11  # wait on dependency barrier

        return c_waitbar, c_readbar, c_writebar, c_yield, c_stall

    @staticmethod
    def splitCode2(code):
        ''' Split control codes into parts.

        Mostly same as splitCode, but with yield/stall combined.
        '''

        c_ystall   = (code & 0x0001f) >> 0
        c_writebar = (code & 0x000e0) >> 5   # write dependency barrier
        c_readbar  = (code & 0x00700) >> 8   # read dependency barrier
        c_waitbar  = (code & 0x1f800) >> 11  # wait on dependency barrier

        return c_waitbar, c_readbar, c_writebar, c_ystall

    @staticmethod
    def mergeCode(c_waitbar, c_readbar, c_writebar, c_yield, c_stall):
        ''' Inverse of splitCode: pack the five fields back into an int. '''
        code = c_waitbar<<11
        code += c_readbar<<8
        code += c_writebar<<5
        code += c_yield<<4
        code += c_stall
        return code

    @staticmethod
    def decode(code):
        ''' Render an integer control code as its canonical string form. '''
        c_waitbar, c_readbar, c_writebar, c_yield, c_stall = CuControlCode.splitCode(code)

        s_yield = '-' if c_yield !=0 else 'Y'
        s_writebar = '-' if c_writebar == 7 else '%d'%c_writebar
        s_readbar = '-' if c_readbar == 7 else '%d'%c_readbar
        s_waitbar = ''.join(['-' if (c_waitbar & (2**i)) == 0 else '%d'%i for i in range(6)])
        s_stall = '%02d' % c_stall

        return 'B%s:R%s:W%s:%s:S%s' % (s_waitbar, s_readbar, s_writebar, s_yield, s_stall)

    @staticmethod
    def encode(s):
        ''' Parse a control-code string back to its integer form.

        Raises ValueError when s does not match c_ControlCodesPattern.
        '''
        if not c_ControlCodesPattern.match(s):
            raise ValueError('Invalid control code strings: %s !!!'%s)

        s_waitbar, s_readbar, s_writebar, s_yield, s_stall = tuple(s.split(':'))

        # digits -> bit set ('1'), '-' -> bit clear ('0')
        waitbar_tr = str.maketrans('012345-','1111110')

        # s_waitbar[:0:-1] reverses the six flag chars (skipping the leading 'B')
        # so that bit 0 ends up least significant.
        c_waitbar = int(s_waitbar[:0:-1].translate(waitbar_tr), 2)
        c_readbar = int(s_readbar[1].replace('-', '7'))
        c_writebar = int(s_writebar[1].replace('-','7'))
        c_yield = int(s_yield!='Y')
        c_stall = int(s_stall[1:])

        code = CuControlCode.mergeCode(c_waitbar, c_readbar, c_writebar, c_yield, c_stall)

        return code

if __name__ == '__main__':
    # Round-trip smoke test over a set of representative control-code strings.
    cs = ['B--2---:R0:W1:-:S07',
          'B01--4-:R-:W-:-:S05',
          'B------:R-:W0:-:S01',
          'B------:R-:W-:-:S01',
          'B------:R2:W1:-:S01',
          'B0-----:R-:W-:Y:S04',
          'B------:R-:W-:-:S01',
          'B0----5:R0:W5:Y:S05',
          'B------:R-:W0:-:S02',
          'B0-----:R-:W0:-:S02',
          'B0-----:R-:W-:Y:S04']

    passed = True
    for s in cs:
        c = CuControlCode.encode(s)
        s2 = CuControlCode.decode(c)
            print('0x%06x:'%c)
            print(' %s'%s)
            print(' %s'%s2)
            if s != s2:
                print('!!! Unmatched !')
                passed = False

    if passed:
        print("Test passed!!!")
    else:
        print("Test failed!!!")
diff --git a/tinygrad/runtime/support/assembler/CuInsAssembler.py b/tinygrad/runtime/support/assembler/CuInsAssembler.py
new file mode 100644
index 0000000000000..cadcdf32f99bf
--- /dev/null
+++ b/tinygrad/runtime/support/assembler/CuInsAssembler.py
# -*- coding: utf-8 -*-

import sympy
from sympy import Matrix # Needed by repr
from sympy.core.numbers import Rational
from io import StringIO
from tinygrad.runtime.support.assembler.CuSMVersion import CuSMVersion
from tinygrad.runtime.support.assembler.common import reprList, reprHexMat
from tinygrad.runtime.support.assembler.CuAsmLogger import CuAsmLogger

class CuInsAssembler():
    '''CuInsAssembler is the assembler handles the values and weights of one type of instruction.'''

    def __init__(self, inskey, d=None, arch='sm_75'):
        ''' Initializer.

        inskey is mandatory, d is for initialization from saved repr.
        '''

        self.m_InsKey = inskey
        if d is not None:
            self.initFromDict(d)
        else:
            # Fresh (empty) assembler state: no known encodings yet.
            self.m_InsRepos = []
            self.m_InsModiSet = {}

            self.m_ValMatrix = None
            self.m_PSol = None
            self.m_PSolFac = None
            self.m_ValNullMat = []
            self.m_Rhs = None
            self.m_InsRecords = []
            self.m_ErrRecords = {}

            self.m_Arch = CuSMVersion(arch)

    def iterRecords(self):
        ''' Iterate over all records, including normal records and error records.'''

        # m_InsRecords is a list of ins_info => (addr, code, s)
        for r in self.m_InsRecords:
            yield r

        # m_ErrRecords is a dict of {code_diff : ins_info => (addr, code, s) }
        for _, r in self.m_ErrRecords.items():
            yield r

    def recordsFeeder(self):
        ''' Yield records as (addr, code, asm, ctrl=0) — the feeder tuple shape. '''
        for r in self.iterRecords():
            yield r[0], r[1], r[2], 0

    def initFromDict(self, d):
        ''' Restore assembler state from a dict produced by __repr__. '''
        self.m_InsKey = d['InsKey']
        self.m_InsRepos = d['InsRepos']
        self.m_InsModiSet = d['InsModiSet']

        self.m_ValMatrix = d['ValMatrix']
        self.m_PSol = d['PSol']
        self.m_PSolFac = d['PSolFac']
        self.m_ValNullMat = d['ValNullMat']
        self.m_Rhs = d['Rhs']

        self.m_InsRecords = d['InsRecords']
        # ErrRecords may be absent in older saved repos.
        self.m_ErrRecords = d['ErrRecords'] if 'ErrRecords' in d else {}

        self.m_Arch = d['Arch']

    def initFromJsonDict(self, d):
        # TODO: not implemented yet.
        pass

    def expandModiSet(self, modi):
        ''' Push in new modifiers.

        NOTE: the order matters, since every modi has its own value.
        '''

        updated = False
        for m in modi:
            if m not in self.m_InsModiSet:
                self.m_InsModiSet[m] = len(self.m_InsModiSet)
                updated = True

        return updated

    def buildInsValVec(self, vals, modi, outRawList=False):
        ''' Convert instruction value vector from vals and modifiers.

        NOTE: Due to performance reason of Matrix.nullspace(), vals are placed after modis.
              Usually vals are dense, but modifiers are sparse.
              This arrangement will make the valMatrix more like upper trangular,
              and this will usually make carrying out the nullspace much much faster.

        TODO: currently opcode is also a modifer, maybe placed after modis?
        '''

        insval = [1 if m in modi else 0 for m in self.m_InsModiSet] # first comes modi
        insval.extend(vals) # then follows vals
        if outRawList:
            return insval
        else:
            insvec = sympy.Matrix(insval)
            return insvec

    def canAssemble(self, vals, modi):
        """ Check whether the input code can be assembled with current info.

        Args:
            vals ([int list]): value list in integers
            modi ([str list]): modifier list in strings

        Returns:
            (None , None) : input can be assembled
            (brief, info) : brief in ['NewModi', 'NewVals'], info gives the detailed info
        """

        if not all([m in self.m_InsModiSet for m in modi]):
            brief = 'NewModi'
            info = 'Unknown modifiers: (%s)' % (set(modi) - set(self.m_InsModiSet.keys()))
            return brief, info
        else:
            insvec = self.buildInsValVec(vals, modi)

            # A value vector inside the null space of ValMatrix adds no new basis.
            if self.m_ValNullMat is not None:
                insrhs = self.m_ValNullMat * insvec
                if not all([v==0 for v in insrhs]):
                    return 'NewVals', 'Insufficient basis, try CuAsming more instructions!'

            return None, None

    def push(self, vals, modi, code, ins_info):
        ''' Push in a new instruction.

        When its code can be assembled, verify the result,
        otherwise add new information to current assembler.
        @return (flag, info):
            flag = True (Expected result)
                "NewModi" / "NewVals" for new information
                "Verified" for no new information, but the results is consistent
            flag = False (Unexpected result)
                "NewConflict" for new conflict information
                "KnownConflict" for known inconsistent assembling result
        '''

        if not all([m in self.m_InsModiSet for m in modi]):
            # If new instruction contains unknown modifier,
            # it's never possible to be assembled by current assembler.
            CuAsmLogger.logProcedure('Pushing new modi (%s, %-20s): %s' % (self.m_Arch.formatCode(code), self.m_InsKey, ins_info))
            updated = self.expandModiSet(modi)
            self.m_InsRepos.append((vals, modi, code))
            self.buildMatrix()
            self.m_InsRecords.append(ins_info)
            return True, 'NewModi'
        else:
            # If the vals of new instruction lies in the null space of
            # current ValMatrix, it does not contain new information.
            insvec = self.buildInsValVec(vals, modi)

            if self.m_ValNullMat is None:
                doVerify = True
            else:
                insrhs = self.m_ValNullMat * insvec
                doVerify = all([v==0 for v in insrhs])

            if doVerify:
                # return 'Verified'
                # No new basis: re-assemble with the current solution and compare.
                inscode = self.m_PSol.dot(insvec) / self.m_PSolFac

                if inscode != code:
                    if inscode.is_integer:
                        code_diff = inscode - code
                        if code_diff not in self.m_ErrRecords:
                            self.m_ErrRecords[code_diff] = ins_info
                            CuAsmLogger.logError("Error when verifying for %s" % self.m_InsKey)
                            CuAsmLogger.logError(" Asm : %s" % ins_info[-1])
                            CuAsmLogger.logError(" InputCode : %s" % self.m_Arch.formatCode(code))
                            CuAsmLogger.logError(" AsmCode : %s" % self.m_Arch.formatCode(inscode))
                            return False, 'NewConflict'
                        else:
                            CuAsmLogger.logDebug("Known code conflict for %s!" % self.m_InsKey)
                            return False, 'KnownConflict'
                    else:
                        CuAsmLogger.logCritical("FATAL! Non-integral code assembled for %s" % self.m_InsKey)
                        CuAsmLogger.logCritical(" Asm : %s" % ins_info[-1])
                        CuAsmLogger.logCritical(" InputCode : %s" % self.m_Arch.formatCode(code))
                        CuAsmLogger.logCritical(" AsmCode : (%s)!" % str(inscode))

                        # It's very unlikely the diff is just the code it self. (usually opcode will match)
                        code_diff = code
                        self.m_ErrRecords[code_diff] = ins_info

                        return False, 'NewConflict'

                    # print(self.__repr__())
                    # raise Exception("Inconsistent instruction code!")
                    # return False
                else:
                    # print("Verified: 0x%032x" % code)
                    return True, 'Verified'

            else:
                # New independent value vector: record it and rebuild the solution.
                CuAsmLogger.logProcedure('Pushing new vals (%s, %-20s): %s' % (self.m_Arch.formatCode(code), self.m_InsKey, ins_info))
                self.m_InsRepos.append((vals, modi, code))
                self.m_InsRecords.append(ins_info)
                self.buildMatrix()
                return True, 'NewVals'

        # Never be here
        # return True

    def buildCode(self, vals, modi):
        '''Assemble with the input vals and modi.

        NOTE: This function didn't check the sufficiency of matrix.'''

        inscode = 0
        # vals occupy the tail of the solution vector (modis come first).
        for v0, vs in zip(self.m_PSol[-len(vals):], vals):
            inscode += v0 * vs

        for m in modi:
            inscode += self.m_PSol[self.m_InsModiSet[m]]

        if self.m_PSolFac == 1:
            return int(inscode)
        else:
            return int(inscode//self.m_PSolFac)

    def buildMatrix(self, solve_method='LU'):
        ''' Rebuild ValMatrix/Rhs from the repos and solve for the weight vector PSol. '''
        if len(self.m_InsRepos) == 0:
            return None, None

        M = []
        b = []
        for vals, modis, code in self.m_InsRepos:
            l = self.buildInsValVec(vals, modis, outRawList=True)
            M.append(l)
            b.append(code)

        self.m_ValMatrix = sympy.Matrix(M)
        self.m_Rhs = sympy.Matrix(b)
        self.m_ValNullMat = self.getNullMatrix(self.m_ValMatrix)

        if self.m_ValNullMat is not None:
            # Augment with null-space rows (rhs 0) to make the system solvable.
            M2 = self.m_ValMatrix.copy()
            b2 = self.m_Rhs.copy()
            for nn in range(self.m_ValNullMat.rows):
                M2 = M2.row_insert(0, self.m_ValNullMat.row(nn))
                b2 = b2.row_insert(0, sympy.Matrix([0]))
            self.m_PSol = M2.solve(b2, method=solve_method)
        else:
            self.m_PSol = self.m_ValMatrix.solve(self.m_Rhs, method=solve_method)

        # Keep PSol integral; PSolFac is the common denominator.
        self.m_PSol, self.m_PSolFac = self.getMatrixDenomLCM(self.m_PSol)
        return self.m_ValMatrix, self.m_Rhs

    def solve(self):
        ''' Try to solve every variable.

        This is possible only when ValNullMat is none.
        '''

        if self.m_ValNullMat is None:
            x = self.m_ValMatrix.solve(self.m_Rhs)
            print('Solution: ')
            for i,v in enumerate(x):
                print('%d : 0x%+033x' % (i,v))
            return x
        else:
            print('Not solvable!')
            return None

    def printSolution(self):
        ''' Pretty-print the solved weight of every value/modifier. '''
        print("InsKey = %s" % self.m_InsKey)
        nvals = len(self.m_InsRepos[0][0])
        nmodi = len(self.m_InsModiSet)

        names = ['V%d'%v for v in range(nvals)]
        names.extend([0] * nmodi)
        for m, midx in self.m_InsModiSet.items():
            names[midx+nvals] = m

        # the order of solutions are altered for better display.
        # vals are displayed before modis

        rev_sol = []
        rev_sol.extend(self.m_PSol[nmodi:])
        rev_sol.extend(self.m_PSol[:nmodi])

        for name, val in zip(names, rev_sol):
            if val % self.m_PSolFac == 0:
                print(" %24s : %#32x" % (name, val // self.m_PSolFac))
            else:
                print(" %24s : %#32x / %#x " % (name, val, self.m_PSolFac))

    def reprPSol(self):
        ''' Build a python-evaluable repr of PSol, annotated with field names. '''
        nvals = len(self.m_InsRepos[0][0])
        nmodi = len(self.m_InsModiSet)

        names = [0 for _ in range(nmodi)]
        for m, midx in self.m_InsModiSet.items():
            names[midx] = m

        names.append('Pred')
        names.extend(['V%d'%v for v in range(1, nvals)])

        slist = []
        vlist = []
        maxvlen = 0
        for ival in range(nvals+nmodi):
            sval = '%#x' % self.m_PSol[ival, 0]
            slist.append(sval)
            vlist.append(self.m_PSol[ival, 0])
            maxvlen = max(maxvlen, len(sval))

        maxnlen = 0
        for name in names:
            maxnlen = max(maxnlen, len(name))

        fac = int(self.m_PSolFac)

        sio = StringIO()
        sio.write('Matrix([\n')
        if self.m_PSolFac == 1:
            for vname, s in zip(names, slist):
                ss = ' '*(maxvlen-len(s)) + s
                sio.write(f'[ {ss}], # {vname}\n')
        else:
            # Fractional solution: annotate each entry with value / denominator.
            for vname, s, v in zip(names, slist, vlist):
                ss = ' '*(maxvlen-len(s)) + s
                ns = ' '*(maxnlen-len(vname)) + vname
                vv = int(v)
                if vv % fac == 0:
                    vt = vv // fac
                    sio.write(f'[ {ss}], # {ns} : {vt:#32x}\n')
                else:
                    sio.write(f'[ {ss}], # {ns} : {vv:#32x} / {fac:#x}\n')
        sio.write('])')

        return sio.getvalue()

    def getNullMatrix(self, M):
        ''' Get the null space of current matrix M.

        And get the lcm for all fractional denominators.
        The null matrix is only for checking sufficiency of ValMatrix,
        thus it won't be affected by any non-zero common factor.
        Fractional seems much slower than integers.
        '''

        ns = M.nullspace(simplify=True)
        if len(ns)==0:
            return None
        else:
            # Stack null-space basis vectors as columns, then work row-wise.
            nm = ns[0]
            for n in ns[1:]:
                nm = nm.row_join(n)

            # NullSpace won't be affected by a common factor.
            nmDenom, dm = self.getMatrixDenomLCM(nm.T)
            return nmDenom

    def getMatrixDenomLCM(self, M):
        ''' Get lcm of matrix denominator.

        In sympy, operation of fractionals seems much slower than integers.
        Thus we multiply a fraction matrix with the LCM of all denominators,
        then divide the result with the LCM.
        '''

        dm = 1
        for e in M:
            if isinstance(e, Rational):
                nom, denom = e.as_numer_denom()
                dm = sympy.lcm(denom, dm)
        return (M*dm, dm)

    def __repr__(self):
        ''' A string repr of current ins assembler.

        This will be used to dump it to text file and read back by setFromDict.
        '''
        sio = StringIO()

        sio.write('CuInsAssembler("", {"InsKey" : %s, \n' % repr(self.m_InsKey) )
        # sio.write(' "InsRepos" : %s, \n' % repr(self.m_InsRepos))
        sio.write(' "InsRepos" : ')
        reprList(sio, self.m_InsRepos)
        sio.write(', \n')

        sio.write(' "InsModiSet" : %s, \n' % repr(self.m_InsModiSet))
        sio.write(' "ValMatrix" : %s, \n' % repr(self.m_ValMatrix))
        sio.write(' "PSol" : %s, \n' % self.reprPSol())
        sio.write(' "PSolFac" : %s, \n' % repr(self.m_PSolFac))
        sio.write(' "ValNullMat" : %s, \n' % repr(self.m_ValNullMat))
        #sio.write(' "InsRecords" : %s, \n' % repr(self.m_InsRecords))

        sio.write(' "InsRecords" : [')
        #reprList(sio, self.m_InsRecords)
        for addr, code, s in self.m_InsRecords:
            sio.write('(%#08x, %s, "%s"),\n'%(addr, self.m_Arch.formatCode(code), s))
        sio.write('], \n')

        sio.write(' "ErrRecords" : {')
        for code_diff, (addr, code, s) in self.m_ErrRecords.items():
            sio.write('%#x : (%#08x, %s, "%s"),\n'%(code_diff, addr, self.m_Arch.formatCode(code), s))
        sio.write('}, ')

        sio.write(' "Rhs" : %s, \n' % reprHexMat(self.m_Rhs))
        sio.write(' "Arch" : %s })' % repr(self.m_Arch))

        return sio.getvalue()

    def __str__(self):

        return 'CuInsAssembler(%s)' % self.m_InsKey
diff --git a/tinygrad/runtime/support/assembler/CuInsAssemblerRepos.py b/tinygrad/runtime/support/assembler/CuInsAssemblerRepos.py
new file mode 100644
index 0000000000000..cbdc98493ba59
--- /dev/null
+++ b/tinygrad/runtime/support/assembler/CuInsAssemblerRepos.py
# -*- coding: utf-8 -*-

import re
import time
import struct
import traceback
from tinygrad.runtime.support.assembler.common import reprDict
from tinygrad.runtime.support.assembler.CuInsParser import CuInsParser
from tinygrad.runtime.support.assembler.CuInsAssembler import CuInsAssembler
from tinygrad.runtime.support.assembler.CuSMVersion import CuSMVersion
from tinygrad.runtime.support.assembler.CuAsmLogger import CuAsmLogger
class CuInsAssemblerRepos():
    ''' A repository consists of a set of instruction assemblers.

    One CuInsAssembler is kept per instruction key; this class manages loading,
    saving, merging, verifying and querying that collection for one SM arch.

    TODO: Version control? Should work with CuInsParser/CuInsFeeder.
    '''

    # Process-wide cache used by getStaticRepos() to avoid re-loading the same
    # default repos for an arch more than once.
    StaticRepos = {}

    def __init__(self, InsAsmDict=None, arch=None):
        ''' Build a repos from a file name, a prebuilt dict, or empty.

        Args:
            InsAsmDict: None (empty repos), a filename to load, or a dict of
                        {ins_key: CuInsAssembler}.
            arch: any valid input of CuSMVersion, or None (arch deduced later).
        '''
        self.resetArch(arch)

        if InsAsmDict is None:
            self.reset(None)
        elif isinstance(InsAsmDict, str):
            self.initFromFile(InsAsmDict)
        elif isinstance(InsAsmDict, dict):
            self.reset(InsAsmDict)
        else:
            raise ValueError('Unknown input type of InsAsmDict!')

    def resetArch(self, arch):
        ''' Set (or clear, for arch=None) the arch and its instruction parser. '''
        if arch is not None:
            self.m_Arch = CuSMVersion(arch)
            self.m_InsParser = CuInsParser(arch)
        else:
            self.m_Arch = None
            self.m_InsParser = None

    def convertArch(self, arch):
        ''' Re-tag this repos (and every contained CuInsAssembler) to another arch.

        Only useful between arches with identical encodings (see the alias dict).
        '''
        dst_arch = CuSMVersion(arch)
        if dst_arch == self.m_Arch:
            return

        self.resetArch(dst_arch)
        for k, v in self.m_InsAsmDict.items():
            v.m_Arch = dst_arch

    def setToDefaultInsAsmDict(self):
        ''' Load the bundled default repos for the current arch.

        Falls back to an alias arch with identical encoding, then to an empty
        repos if nothing is available.
        '''
        vnum = self.m_Arch.getVersionNumber()
        fname = Config.getDefaultInsAsmReposFile(vnum)
        if os.path.isfile(fname):
            self.initFromFile(fname)
        else:
            # No default InsAsmRepos, but the encoding can be copied from another version
            if vnum in CuSMVersion.InsAsmReposAliasDict:
                anum = CuSMVersion.InsAsmReposAliasDict[vnum]
                aname = Config.getDefaultInsAsmReposFile(anum)
                if os.path.isfile(aname):
                    CuAsmLogger.logWarning(f'No default InsAsmRepos for SM_{vnum} found! Use SM_{anum} instead...')
                    self.initFromFile(aname)
                    # BUGFIX: convert to the *requested* arch vnum. The alias file
                    # is already arch anum, so convertArch(anum) was a no-op and
                    # left the repos tagged with the alias arch.
                    self.convertArch(vnum)
                    return

            CuAsmLogger.logError(f'No default or alias InsAsmRepos for SM_{vnum} found! Use empty repos ...')
            self.reset()

    @staticmethod
    def getDefaultRepos(arch) -> 'CuInsAssemblerRepos':
        ''' Build a fresh repos preloaded with the defaults for arch. '''
        repos = CuInsAssemblerRepos(arch=arch)
        repos.setToDefaultInsAsmDict()
        return repos

    @staticmethod
    def getStaticRepos(arch) -> 'CuInsAssemblerRepos':
        ''' Get a static repos for arch.

        NOTE: The purpose of this method is to avoid multiple instantiation.
              Usually static repos will be read-only.
              If it's read/write, be cautious for alias.
        '''
        if arch not in CuInsAssemblerRepos.StaticRepos:
            CuInsAssemblerRepos.StaticRepos[arch] = CuInsAssemblerRepos.getDefaultRepos(arch)

        return CuInsAssemblerRepos.StaticRepos[arch]

    def reset(self, InsAsmDict=None):
        ''' Replace the assembler dict (empty dict for None). '''
        if InsAsmDict is None:
            self.m_InsAsmDict = {}
        else:
            self.m_InsAsmDict = InsAsmDict

    def __getitem__(self, k):
        return self.m_InsAsmDict[k]

    def __setitem__(self, k, v):
        self.m_InsAsmDict[k] = v

    def __delitem__(self, k):
        del self.m_InsAsmDict[k]

    # BUGFIX: was misspelled "__constains__", so the membership protocol method
    # was never registered and `in` silently fell back to __iter__.
    def __contains__(self, k):
        return k in self.m_InsAsmDict

    def __len__(self):
        return len(self.m_InsAsmDict)

    def __iter__(self):
        return iter(self.m_InsAsmDict)

    def items(self):
        return self.m_InsAsmDict.items()

    def initFromFile(self, fname):
        ''' Load repos from file.

        SECURITY NOTE: the repos file is eval()-ed, which executes arbitrary
        Python. Only load trusted files.
        '''
        with open(fname, 'r') as fin:
            fconts = fin.read()
            asm_repos = eval(fconts)
            self.m_InsAsmDict = asm_repos.m_InsAsmDict

            for k, v in self.m_InsAsmDict.items():
                if self.m_Arch is None:
                    self.resetArch(v.m_Arch)
                elif v.m_Arch != self.m_Arch:
                    CuAsmLogger.logWarning(f'InsAsm arch {v.m_Arch} of {k} does not match with repos {self.m_Arch}!!! Resetting...')
                    self.resetArch(v.m_Arch)

                # only check the first insasm
                break

    def assemble(self, addr, s, precheck=True, showCandidates=True):
        ''' Try to assemble the input instruction string.

        Raise KeyError when the ins_key is not found.
        if precheck is true, a ValueError will be raised if it cannot be assembled.
        '''
        ins_key, ins_vals, ins_modi = self.m_InsParser.parse(s, addr, 0)
        if ins_key not in self.m_InsAsmDict:
            msg = 'Unknown InsKey(%s) in Repos!' % ins_key
            if showCandidates:
                ckeys = self.getInsKeyCandidates(ins_key)
                msg += '\n Available InsKeys: \n' + ckeys
            raise ValueError(msg)

        insAsm = self.m_InsAsmDict[ins_key]
        if precheck:
            brief, info = insAsm.canAssemble(ins_vals, ins_modi)
            if brief is not None:
                msg = 'Assembling failed (%s): %s' % (brief, info)
                if showCandidates:
                    msg += '\n Known Records:\n'
                    for _, _, asm in insAsm.iterRecords():
                        msg += ' ' * 8 + asm + '\n'
                raise ValueError(msg)

        code = insAsm.buildCode(ins_vals, ins_modi)
        return code

    @CuAsmLogger.logTimeIt
    def verify(self, feeder):
        ''' Verify current repos.

        The feeder should yield (addr, code, asm, ctrl), but "ctrl" is not used.
        Returns True when every instruction re-assembles without error.
        '''
        res = True
        t0 = time.time()
        cnt = 0
        for addr, code, s, ctrlcodes in feeder:
            cnt += 1
            try:
                casm = self.assemble(addr, s)
                if code != casm:
                    CuAsmLogger.logError('Error when verifying :')
                    CuAsmLogger.logError('  ' + s)
                    CuAsmLogger.logError('  CodeOrg: %s' % self.m_Arch.formatCode(code))
                    CuAsmLogger.logError('  CodeAsm: %s' % self.m_Arch.formatCode(casm))
            except Exception as e:
                CuAsmLogger.logError(str(e))
                CuAsmLogger.logError('Error when assembling :')
                CuAsmLogger.logError('  ' + s)
                CuAsmLogger.logError(traceback.format_exc())
                res = False

        t1 = time.time()

        if res:
            msg = "Verified %d ins in %8.3f secs." % (cnt, t1 - t0)
            if t0 != t1:
                msg += " ~%8.2f ins/s." % (cnt / (t1 - t0))

            CuAsmLogger.logProcedure(msg)
        else:
            CuAsmLogger.logError("Verifying failed in %8.3f secs!!!" % (t1 - t0))

        return res

    @CuAsmLogger.logTimeIt
    def update(self, feeder, ins_asm_dict=None):
        ''' Update the input instruction assembler dict with input from feeder.

        Args:
            feeder       : yield (addr, code, asm, ctrl), but "ctrl" is not used.
            ins_asm_dict : destination dict
                           For ins_asm_dict=None(default), use the internal self.m_InsAsmDict as dst.
        Return:
            ncnt : number of new records, 0 for unchanged
        '''
        if ins_asm_dict is None:
            ins_asm_dict = self.m_InsAsmDict

        t0 = time.time()
        cnt = 0
        ncnt = 0

        for addr, code, s, ctrlcodes in feeder:
            cnt += 1
            ins_key, ins_vals, ins_modi = self.m_InsParser.parse(s, addr, code)

            if ins_key not in ins_asm_dict:
                ins_asm_dict[ins_key] = CuInsAssembler(ins_key, arch=self.m_Arch)

            ins_info = (addr, code, s)
            res_flag, res_info = ins_asm_dict[ins_key].push(ins_vals, ins_modi, code, ins_info)

            # A record counts as "new" when it extended the assembler's knowledge.
            if res_info in {'NewModi', 'NewVals', 'NewConflict'}:
                ncnt += 1

        t1 = time.time()
        msg = "Updated %d ins (%d new) in %8.3f secs ." % (cnt, ncnt, t1 - t0)
        if (t0 != t1):
            msg += " ~%8.2f ins/s." % (cnt / (t1 - t0))

        CuAsmLogger.logProcedure(msg)

        return ncnt

    @CuAsmLogger.logTimeIt
    def save2file(self, fname):
        ''' Persist this repos as its repr() (reloadable via initFromFile). '''
        CuAsmLogger.logEntry('Saving to %s...' % fname)
        with open(fname, 'w') as fout:
            fout.write(self.__repr__())

    @CuAsmLogger.logTimeIt
    def rebuild(self):
        ''' When the CuInsParser is updated, the meaning of ins value/modifier may have changed.

        Thus CuInsAsmRepos should be rebuilt from original input (saved in ins records)

        TODO: We may store some redundant records?
        '''

        tmp_ins_asm_dict = {}
        feeder = self.recordsFeeder()

        self.update(feeder, tmp_ins_asm_dict)
        self.m_InsAsmDict = tmp_ins_asm_dict

    @CuAsmLogger.logTimeIt
    def merge(self, merge_source):
        ''' Merge instruction assembler from another source.

        TODO: Check version?
        '''
        if isinstance(merge_source, (str, dict)):
            repos = CuInsAssemblerRepos(merge_source)
        elif isinstance(merge_source, CuInsAssemblerRepos):
            repos = merge_source
        else:
            raise TypeError('Unknown merge source type!')

        feeder = repos.recordsFeeder()
        self.update(feeder)

    def iterRecords(self, key_filter=None):
        ''' Iterate over all records from CuInsAssembler with keys filtered by key_filter.

        Return (addr, code, asm)
        key_filter can be:
            1. list/tuple/dict of keys
            2. string of re pattern
            3. re pattern
            4. None for all keys
        '''

        if key_filter is None:
            keys = self.m_InsAsmDict.keys()
        elif isinstance(key_filter, str):
            keys = filter(lambda x: re.match(key_filter, x), self.m_InsAsmDict.keys())
        elif isinstance(key_filter, re.Pattern):
            keys = filter(lambda x: key_filter.match(x), self.m_InsAsmDict.keys())
        elif isinstance(key_filter, (list, tuple, dict)):  # any iteratable
            keys = key_filter
        elif callable(key_filter):
            keys = filter(key_filter, self.m_InsAsmDict.keys())
        else:
            raise TypeError(f'Unknown key_filter: {key_filter} with type {type(key_filter)}!!!')

        for ins_key in keys:
            for r in self.m_InsAsmDict[ins_key].iterRecords():
                yield r

    def recordsFeeder(self, key_filter=None):
        ''' A generator as internal instruction feeder, pulling from instruction records.

        Return (addr, code, asm, ctrl=0), same as CuInsFeeder.
        key_filter accepts the same inputs as iterRecords.
        '''

        # Records is a list of ins_info => (addr, code, s); control codes are not
        # stored, thus 0 is yielded for ctrl.
        for r in self.iterRecords(key_filter):
            yield r[0], r[1], r[2], 0

    def showErrRecords(self):
        ''' Print each assembler's conflicting records with an original/assembled code diff. '''
        for ins_key, ins_asm in self.items():
            if len(ins_asm.m_ErrRecords) > 0:
                print(f'#### ErrRecords for {ins_key}:')

                for k, (addr, code, s) in ins_asm.m_ErrRecords.items():
                    print(f'  {addr:#08x} : {s}')
                    print('     Org: %s' % self.m_Arch.formatCode(code))
                    acode = self.assemble(addr, s)
                    print('     Asm: %s' % self.m_Arch.formatCode(acode))
                    # XOR-like visual diff: zero nibbles are blanked so only the
                    # differing bit positions remain visible.
                    diff = abs(acode - code)
                    diffs = self.m_Arch.formatCode(diff)[2:].replace('0', ' ')

                    print('    Diff: %s' % diffs)

    def completePredCodes(self):
        ''' Some instructions seem very rarely appear with guard predicates.

        Thus when the instruction assemblers are gathered from ptxas output,
        many of them will not be able to encode predicates.

        This may give some useful information as performance guidelines.
        However, there will be certainly some occasions predicates will be needed.
        '''

        feeder = self.genPredRecords()
        self.update(feeder)

    def clearErrRecords(self):
        ''' Drop the error records of every contained assembler. '''
        for k in self.m_InsAsmDict:
            self.m_InsAsmDict[k].m_ErrRecords = {}

    def genPredRecords(self):
        ''' A generator that yields modified instruction info with predicates. '''
        for ins_key, ins_asm in self.m_InsAsmDict.items():
            ins_info = ins_asm.m_InsRecords[0]
            pred_ins_info = self.m_Arch.genPredCode(ins_info)
            if pred_ins_info is not None:
                yield pred_ins_info[0], pred_ins_info[1], pred_ins_info[2], 0

    def genUndefRecords(self):
        ''' Yield a few synthetic UNDEF instructions (for encoding coverage). '''
        key = 'UNDEF'
        for v in [0x1, 0x2, 0x3]:
            yield 0x0, v, f'{key} {v:#x};', 0

    def getArchString(self):
        return self.m_Arch.getVersionString().lower()

    def getSMVersion(self):
        return self.m_Arch

    def __repr__(self):
        # Reloadable representation: eval()-ing it reconstructs the repos.
        sio = StringIO()

        sio.write('CuInsAssemblerRepos(')
        reprDict(sio, self.m_InsAsmDict)
        sio.write(', arch=%s)' % self.m_Arch)

        return sio.getvalue()

    def __str__(self):
        return "CuInsAssemblerRepos(%d keys)" % len(self.m_InsAsmDict)

    def getInsKeyCandidates(self, key, n=5):
        ''' Suggest up to n known keys close to an unknown ins_key (for error messages). '''
        from difflib import get_close_matches
        keys = self.m_InsAsmDict.keys()
        cs = get_close_matches(key, keys, n=n)
        if len(cs) == 0:
            return 'None'
        else:
            return '\n'.join([' ' * 8 + s for s in cs])
class SassLineType(Enum):
    ''' Six types of lines in a dumped sass file:

    1. function name line:
              Function : _Z5ktestPmPi
    2. headerflags line:
            .headerflags    @"EF_CUDA_SM86 EF_CUDA_PTX_SM(EF_CUDA_SM86)"
    3. ins with code (first code line of SM7x/8x, normal code line of SM5x/6x)
            /*0030*/  UMOV UR11, 0x14b00000 ;  /* 0x14b00000000b7882 */
    4. ins without code (second ins line of dual-issue)
            /*5348*/  LDG.E.U16 R28, [R28] }
    5. code without ins (control code of SM5x/6x, second ins code of dual-issue,
       second code line of SM7x/8x)
            /* 0xeed2200000071c1c */
    6. others (fatbin headers, blank lines, ...)
    '''
    # BUGFIX: the regex named groups (?P<...>) had been stripped by markup
    # mangling, leaving invalid patterns such as "(?P\w+)". The names are
    # restored from the group() lookups in getCallbackArgs below.

    # Pattern that contains an instruction string (including address and code)
    # NOTE: For maxwell/pascal, there may be braces "{}" for dual-issued instructions.
    InsCode = 0, re.compile(r'^\s*\/\*(?P<addr>\w+)\*\/\s*\{?\s*(?P<asm>.*;)\s*\/\* (?P<code>.*) \*\/')

    # Pattern for the dual-issued instruction, the code is in next line.
    InsOnly = 1, re.compile(r'^\s*\/\*(?P<addr>\w+)\*\/\s*(?P<asm>.*\})')

    # A bare 64bit code word in comment form.
    CodeOnly = 2, re.compile(r'^\s*\/\* (?P<code>0x[0-9a-f]{16}) \*\/')

    #   Function : _Z5ktestPmPi
    FuncName = 3, re.compile(r'^\s*Function\s*:\s*(?P<func>.*)')

    #   .section .text._Z..., "ax",@progbits
    SectionName = 4, re.compile(r'^\s*\.section\s*\.text\.(?P<sec>[^,\s]*)')

    #   .headerflags @"EF_CUDA_SM86 EF_CUDA_PTX_SM(EF_CUDA_SM86)"
    HeaderFlag = 5, re.compile(r'^\s*\.headerflags.*EF_CUDA_SM(?P<arch>\d+)')

    # all other lines (matches anything, must stay last)
    Others = 6, re.compile('.*')

    def __init__(self, idx, pattern):
        # Enum members carry (idx, compiled pattern) tuple values.
        self.idx = idx
        self.pattern = pattern

    def match(self, s):
        return self.pattern.match(s)

    def search(self, s):
        return self.pattern.search(s)

    @classmethod
    def getLineType(cls, line):
        ''' Classify a line; returns (linetype, match). First member wins. '''
        for t in cls:
            r = t.match(line)
            if r:
                return t, r

        # Unreachable in practice: Others matches every line.
        raise Exception(f'Unknown line type for line:\n    {line}')

    @classmethod
    def getCallbackArgs(cls, op, res):
        ''' Get call back args for current linetype.

        NOTE 1: return tuples
        NOTE 2: all ints are converted
        '''
        if op == cls.InsCode:
            return int(res.group('addr'), 16), res.group('asm'), int(res.group('code'), 16)
        elif op == cls.InsOnly:
            return int(res.group('addr'), 16), res.group('asm')
        elif op == cls.CodeOnly:
            return int(res.group('code'), 16),
        elif op == cls.FuncName:
            return res.group('func'),
        elif op == cls.SectionName:
            return res.group('sec'),
        elif op == cls.HeaderFlag:
            return res.group('arch'),
        elif op == cls.Others:
            return res.group(),
        else:
            raise Exception(f'Unknown line type {op}!')

class ParserState(Enum):
    ''' State of the Parser. '''

    Ready = auto()          # may accept others, instructions, codes, funcnames
    WaitForFunc = auto()    # ??? not needed?
    WaitForArch = auto()    # Deprecated, the arch is switched inline now

    WaitForCode4 = auto()
    WaitForCode3 = auto()
    WaitForCode2 = auto()
    WaitForCode1 = auto()

    WaitForIns7 = auto()
    WaitForIns6 = auto()
    WaitForIns5 = auto()
    WaitForIns4 = auto()
    WaitForIns3 = auto()
    WaitForIns2 = auto()
    WaitForIns1 = auto()
    Invalid = auto()

# Short aliases used throughout the state machine tables.
SLT = SassLineType
PS = ParserState

def IterNone():
    ''' Infinite generator of None (a default "no callback" stream). '''
    while True:
        yield None
+ ''' + def __init__(self): + self.TM = {} + + def __contains__(self, key): + return key in self.TM + + def __getitem__(self, key): + return self.TM[key] + + def __setitem__(self, key, value): + self.TM[key] = value + + def addop(self, s, op, ts, callback=None): + ''' Add a single link op(s) -> ts ''' + if s not in self.TM: + self.TM[s] = {} + + self.TM[s][op] = ts, callback + + def addops(self, s, ops, tss, callbacks=IterNone): + ''' Add a list/tuple of ops and transferred states.''' + if s not in self.TM: + self.TM[s] = {} + + for op, ts in zip(ops, tss): + self.TM[s][op] = ts + + def addop_dict(self, s, opd): + ''' Add ops and transferred states as a dict''' + if s not in self.TM: + self.TM[s] = {} + + # NOTE: keep current entries!!! Direct assign may overwrite them. + for op, vs in opd.items(): + if isinstance(vs, tuple): + self.TM[s][op] = vs # vs = (ts, callback) + else: + self.TM[s][op] = vs, None # with default callback to None + + def __str__(self): + return 'StateTransferMatrix: ' + repr(self.TM) + +class ParserStateMachine: + def __init__(self, init_state, tr_matrix): + self.state = init_state + self.tr_matrix = tr_matrix + + def reset(self, init_state, tr_matrix): + self.state = init_state + self.tr_matrix = tr_matrix + + def feed(self, op): + ''' Feed in an op and transfer the state.''' + + # current state don't accept the input op + if op not in self.tr_matrix[self.state]: + raise Exception(f'Invalid input {op} for state {self.state} (Unacceptable op)!!!') + + self.state = self.tr_matrix[self.state][op] + + # current state transferred to an invalid state + if self.state not in self.tr_matrix or self.state == PS.Invalid: + raise Exception(f'Invalid input {op} for state {self.state} (Invalid transferred state)!!!') + + return self.state + +class CuInsFeeder(): + def __init__(self, fstream, archfilter=None, insfilter=None): + """ Construct a instruction feeder, yield (addr, code, asm, ctrl). 
+ + Args: + fstream (str or file stream): file name or the file object + arch (optional): should be a valid input for CuSMVersion. + Defaults to None, means all arches should be processed. + insfilter (optional): filter for lines. Usually used for feeding a perticular instruction. + insfilter can be: + 1. regex pattern string + 2. regex pattern object + 3. a callable function that accepts a line string input + Defaults to None (empty string also means None). + """ + + if isinstance(fstream, str): + self.__mFileName = fstream + self.__mFStream = open(self.__mFileName, 'r') + + else: + self.__mFileName = None + self.__mFStream = fstream + + # compile ins filter + if insfilter is None or insfilter=='': + self.__mInsFilterFun = lambda x: True + elif isinstance(insfilter, str): + p = re.compile(insfilter) + self.__mInsFilterFun = lambda x: p.search(x) + elif isinstance(insfilter, re.Pattern): + self.__mInsFilterFun = lambda x: insfilter.search(x) + elif callable(insfilter): + self.__mInsFilterFun = insfilter + else: + raise TypeError(f'Unknown type of insfilter {insfilter}!') + + if archfilter is None or archfilter == '': + self.__mArchFilterFun = lambda x: True + else: + arch = CuSMVersion(archfilter) + self.__mArchFilterFun = lambda x: x==arch + + self.__mLineNo = 0 + + self.CurrFuncName = '' + self.CurrArch = '' + + self.__SplitCodeList = lambda x: (x, x) + self.__mAddrList = [] + self.__mAsmList = [] + self.__mCodeList = [] + + self.__TMs = {'default': self.__getTrMatrixDefault(), + '3x' : self.__getTrMatrixForSM_3x(), + '5x6x' : self.__getTrMatrixForSM_5x6x(), + '7x8x' : self.__getTrMatrixForSM_7x8x()} + + self.__CurrTM = self.__TMs['default'] + self.__mPState = PS.Ready + + @staticmethod + def parseInsFilter(insfilter): + if insfilter is None or insfilter=='': + InsFilterFun = lambda x: True + elif isinstance(insfilter, str): + p = re.compile(insfilter) + InsFilterFun = lambda x: p.search(x) + elif isinstance(insfilter, re.Pattern): + InsFilterFun = lambda x: 
insfilter.search(x) + elif callable(insfilter): + InsFilterFun = insfilter + else: + raise TypeError(f'Unknown type of insfilter {insfilter}!') + + return InsFilterFun + + def nextParseLine(self): + ''' Parse next line. + + Return tuple(linetype, line, res) + linetype : type of next line; + line : contents of next line; + res : re match object for linetype. + + Return (None, None, None) if lines are exhausted. + ''' + line = self.readline() + if len(line) == 0: + return None, None, None + + linetype, res = SassLineType.getLineType(line) + return linetype, line, res + + def __feedLineOp(self, op, *args): + ''' Feed in an linetype op and transfer the state.''' + + # CuAsmLogger.logDebug(f'Line{self.__mLineNo:04d}: feedOp op:{op} with args:{args}') + # current state don't accept the input op + if op not in self.__CurrTM[self.__mPState]: + raise Exception(f'Invalid input {op} for state {self.__mPState} (Unacceptable op) @Line{self.__mLineNo:04d}!!!') + + _, callback = self.__CurrTM[self.__mPState][op] + if callback is not None: + callback(*args) + + # callback may change __CurrTM (such as switchArch) !!! + # thus the state should base on the new __CurrTM + self.__mPState, _ = self.__CurrTM[self.__mPState][op] + + # current state transferred to an invalid state + if self.__mPState not in self.__CurrTM or self.__mPState == PS.Invalid: + raise Exception(f'Invalid input {op} for state {self.__mPState} (Invalid transferred state) @Line{self.__mLineNo:04d}!!!') + + return self.__mPState + + def __iter__(self): + ''' yield (addr, code, asm, ctrl). + + NOTE: feeder will be re-initialized when the iterator is used again. 
+ ''' + self.restart() + + doFeed = False + while True: + linetype, line, res = self.nextParseLine() + if linetype is None: # ended + break + elif linetype is SassLineType.HeaderFlag: + # Skip filtered arches + args = SassLineType.getCallbackArgs(linetype, res) + ns = self.__feedLineOp(linetype, *args) + doFeed = self.__mArchFilterFun(self.CurrArch) + continue + elif linetype in {SassLineType.FuncName, SassLineType.SectionName}: + # func name line appears before arch line, thus need to be processed + args = SassLineType.getCallbackArgs(linetype, res) + ns = self.__feedLineOp(linetype, *args) + continue + + if not doFeed: + continue + + args = SassLineType.getCallbackArgs(linetype, res) + ns = self.__feedLineOp(linetype, *args) + + if ns == ParserState.Ready: + for addr, code, asm, ctrl in self.__iterPopIns(): + if self.__filterIns(asm): + yield addr, code, asm, ctrl + + @CuAsmLogger.logTimeIt + def trans(self, fout, codeonly_line_mode='none'): + ''' Translate an input sass to sass with control codes. + The sass input is usually obtained by `cuobjdump -sass fname > a.sass`. + + fout : output filename or stream object + codeonly_line_mode : + whether to keep lines with only codes + such as the control code line for sm5x/6x, and 2nd line for sm7x/8x + + 'keep' : keep unchanged + 'none' : skipped (default) + + NOTE: the filter does not work for this function. + NOTE 2: this function is not quite robust, not recommended for any hand-written sass. 
+ ''' + + if isinstance(fout, str): + fout_stream = open(fout, 'w+') + need_close = True + else: + fout_stream = fout + + if codeonly_line_mode == 'keep': + pCodeLine = lambda ctrl_str, l: f'{ctrl_str} {l}\n' + elif codeonly_line_mode == 'none': + pCodeLine = lambda ctrl_str, l: None + else: + pass + + self.restart() + + line_buffers = [] + out_buffers = [] + while True: + linetype, line, res = self.nextParseLine() + if linetype is None: + break + + line = line.rstrip() + line_buffers.append( (line, linetype)) + + args = SassLineType.getCallbackArgs(linetype, res) + ns = self.__feedLineOp(linetype, *args) + + pre_lt = None + line_len = -1 + if ns == ParserState.Ready: + inslist = [ins for ins in self.__iterPopIns()] + + for l, lt in line_buffers: + if lt == SLT.InsCode: + __, __, __, ctrl = inslist.pop(0) + ctrl_str = self.formatCtrlCodeString(ctrl) + line = f'{ctrl_str} {l}' + out_buffers.append(line + '\n') + if line_len == -1: + line_len = len(line) + elif lt == SLT.InsOnly: + __, __, __, ctrl = inslist.pop(0) + ctrl_str = self.formatCtrlCodeString(ctrl) + out_buffers.append(f'{ctrl_str} {l.rstrip()}') + elif lt == SLT.CodeOnly: + if pre_lt == SLT.InsOnly: + lcode = l.strip() + out_buffers[-1] = out_buffers[-1], lcode + else: + ctrl_str = self.formatCtrlCodeString(0, phantom_mode=True) + oline = pCodeLine(ctrl_str, l) + if oline is not None: + out_buffers.append(oline) + else: + out_buffers.append(l+'\n') + + pre_lt = lt + + for oline in out_buffers: + if isinstance(oline, tuple): + slen = line_len - len(oline[0]) - len(oline[1]) + nline = oline[0] + (slen*' ') + oline[1] + '\n' + fout_stream.write(nline) + else: + fout_stream.write(oline) + + line_buffers.clear() + out_buffers.clear() + + if need_close: + fout_stream.close() + + def extract(self, fout, *, func_filter=None, ins_filter=None): + ''' Extracting kernel matching the filter to fout. 
+ + Sometimes whole kernel sass is needed to check the context of an instruction, + this will help to identify some rules of instruction correlations. + + fout: output filename + func_filter: filter for the function name, may be string/re.Pattern/callable + ins_filter: filter for the instruction + + Match rules: + 1. when func_filter matched the name, output first matched kernel; + 2. when ins_filter matched an instruction, output the first kernel containing the instruction; + ''' + buf = StringIO() + do_dump = False + + InsFilterFun = CuInsFeeder.parseInsFilter(ins_filter) + FuncFilterFun = CuInsFeeder.parseInsFilter(func_filter) + + def tryDump(): + if do_dump: + if buf.tell() == 0: + print('Empty buffer! Nothing to dump...') + return False + + print('================================') + print(buf.getvalue()) + with open(fout, 'w') as fout_stream: + print(f'Dump to file {fout}...') + fout_stream.write(buf.getvalue()) + return True + else: + return False + + while True: + linetype, line, res = self.nextParseLine() + + if linetype is None: + tryDump() + break + + if linetype == SLT.FuncName: + if tryDump(): + break + else: + if func_filter is not None: + if FuncFilterFun(res.group('func')): + do_dump = True + buf = StringIO() + elif linetype in {SLT.InsCode, SLT.InsOnly, SLT.CodeOnly}: + if InsFilterFun(line): + do_dump = True + else: + pass + + buf.write(line.rstrip() + '\n') + + if not do_dump: + print('Nothing to dump...') + + def formatCtrlCodeString(self, ccode, phantom_mode=False): + if self.CurrArch.getMajor()<5: + return '' + else: + if phantom_mode: + return ' '*(c_ControlStringLen+2) # +2 for "[]" + else: + return '[' + CuControlCode.decode(ccode) + ']' + + def __del__(self): + '''Close the stream if the handler is owned by this feeder.''' + + if self.__mFileName is not None and not self.__mFStream.closed: + self.__mFStream.close() + + def close(self): + if not self.__mFStream.closed: + self.__mFStream.close() + self.__mLineNo = 0 + return True + else: + 
return False + + def restart(self): + if self.__mFStream.seekable: + self.__mFStream.seek(0) + self.__mLineNo = 0 + else: + raise Exception("This feeder cannot be restarted!") + + def readline(self): + ''' A helper function for reading lines, with line number recorded.''' + self.__mLineNo += 1 + return self.__mFStream.readline() + + def lines(self): + ''' Iterator for reading the stream line by line. ''' + while True: + line = self.readline() + if len(line)>0: + yield line + else: + break + + def tell(self): + '''Report the progress of file or stream.''' + + return self.__mFStream.tell() + + def tellLine(self): + '''Report current line number.''' + + return self.__mLineNo + +#### subroutines for operation ins queue + def __pushAddr(self, addr): + self.__mAddrList.append(addr) + + def __pushAsm(self, asm): + self.__mAsmList.append(asm) + + def __pushCode(self, code): + self.__mCodeList.append(code) + + def __pushInsCode_3x(self, addr, asm, code): + self.__pushAddr(addr) + self.__pushAsm(asm) + self.__pushCode(code) + + def __pushInsCode(self, addr, asm, code): + self.__pushAddr(addr) + self.__pushAsm(asm) + self.__pushCode(code) + + def __pushInsOnly_5x6x(self, addr, asm): + self.__pushAddr(addr) + self.__pushAsm(asm) + + def __filterIns(self, asm): + ''' Check whether current instruction can pass the filter. + + True for pass, False for filterred. 
+ ''' + return self.__mInsFilterFun(asm) + + def __SplitCodeList_3x(self, int_list): + ''' Split code list to (ctrl_list, code_list).''' + return [0 for _ in int_list], int_list + + def __SplitCodeList_5x6x(self, int_list): + ''' Split code list to (ctrl_list, code_list).''' + return CuSMVersion.splitCtrlCodeFromIntList_5x_6x(int_list) + + def __SplitCodeList_7x8x(self, int_list): + ''' Split code list to (ctrl_list, code_list).''' + + cs = [int_list[i] + (int_list[i+1]<<64) for i in range(0, len(int_list), 2)] + return CuSMVersion.splitCtrlCodeFromIntList_7x_8x(cs) + + def __iterPopIns(self): + ''' Pop (addr, code, asm, ctrl) iteratively.''' + clist, ilist = self.__SplitCodeList(self.__mCodeList) + for addr, code, asm, ctrl in zip(self.__mAddrList, + ilist, + self.__mAsmList, + clist): + yield addr, code, asm, ctrl + + # clear current buffer + self.__mAddrList = [] + self.__mCodeList = [] + self.__mAsmList = [] + + def __setFuncName(self, func): + self.CurrFuncName = func + + def __setSectionName(self, sec): + ''' section name is .text.funcname. ''' + self.CurrFuncName = sec + + def __switchArch(self, arch): + smversion = CuSMVersion(arch) + self.CurrArch = smversion + + if smversion.getMajor() == 3: + self.__CurrTM = self.__TMs['3x'] + self.__SplitCodeList = self.__SplitCodeList_3x + elif smversion.getMajor() in {5,6}: + self.__CurrTM = self.__TMs['5x6x'] + self.__SplitCodeList = self.__SplitCodeList_5x6x + elif smversion.getMajor() in {7,8}: + self.__CurrTM = self.__TMs['7x8x'] + self.__SplitCodeList = self.__SplitCodeList_7x8x + else: + raise NotImplementedError(f'ERROR! 
No implemented state machine for arch {smversion}!!!') + + def __emitMessage(self, msg): + CuAsmLogger.logWarning(f'CuInsFeeder Message: {msg} @Line{self.__mLineNo-1:04d}') + +#### Subroutines for constructing StateTransferMatrix + def __getTrMatrixDefault(self): + stm = StateTransferMatrix() + stm.addop_dict(PS.Ready, + {SLT.FuncName : (PS.Ready, self.__setFuncName), + SLT.SectionName : (PS.Ready, self.__setSectionName), + SLT.HeaderFlag : (PS.Ready, self.__switchArch), + SLT.Others : (PS.Ready, None), + }) + return stm + + def __getTrMatrixForSM_3x(self): + stm = StateTransferMatrix() + + stm.addop_dict(PS.Ready, + { SLT.FuncName : (PS.Ready, self.__setFuncName), + SLT.SectionName : (PS.Ready, self.__setSectionName), + SLT.HeaderFlag : (PS.Ready, self.__switchArch), + SLT.CodeOnly : (PS.Ready, None), + SLT.InsCode : (PS.Ready, self.__pushInsCode_3x), # sometimes the padded ins may just follow normal ins + # no extra ctrl line. + SLT.Others : (PS.Ready, None), + }) + + return stm + + def __getTrMatrixForSM_3x_dep(self): + stm = StateTransferMatrix() + + stm.addop_dict(PS.Ready, + { SLT.FuncName : (PS.Ready, self.__setFuncName), + SLT.CodeOnly : (PS.WaitForIns7, None), + SLT.InsCode : (PS.Ready, None), # sometimes the padded ins may just follow normal ins, no ctrl line. 
+ SLT.Others : (PS.Ready, None), + }) + + # Arch always follows FuncName, then wait for control code + # stm.addop(PS.WaitForArch, SLT.HeaderFlag, PS.WaitForCode1, self.__switchArch) + + # 1 CodeOnly(control code) + 7 InsCode + stm.addop(PS.WaitForCode1, SLT.CodeOnly, PS.WaitForIns7, None) + + # 7 InsCode as a chain, no explicit dual issue + stm.addop(PS.WaitForIns7, SLT.InsCode, PS.WaitForIns6, self.__pushInsCode_3x) + stm.addop(PS.WaitForIns6, SLT.InsCode, PS.WaitForIns5, self.__pushInsCode_3x) + stm.addop(PS.WaitForIns5, SLT.InsCode, PS.WaitForIns4, self.__pushInsCode_3x) + stm.addop(PS.WaitForIns4, SLT.InsCode, PS.WaitForIns3, self.__pushInsCode_3x) + stm.addop(PS.WaitForIns3, SLT.InsCode, PS.WaitForIns2, self.__pushInsCode_3x) + stm.addop(PS.WaitForIns2, SLT.InsCode, PS.WaitForIns1, self.__pushInsCode_3x) + stm.addop(PS.WaitForIns1, SLT.InsCode, PS.Ready , self.__pushInsCode_3x) + + return stm + + def __getTrMatrixForSM_5x6x(self): + stm = StateTransferMatrix() + + stm.addop_dict(PS.Ready, + { SLT.FuncName : (PS.Ready, self.__setFuncName), + SLT.SectionName : (PS.Ready, self.__setSectionName), + SLT.HeaderFlag : (PS.Ready, self.__switchArch), + SLT.CodeOnly : (PS.WaitForIns3, self.__pushCode), + SLT.Others : (PS.Ready, None), + }) + + + # 1 CodeOnly(Control code) + 3 InsCode + # stm.addop(PS.WaitForCode4, SLT.CodeOnly, PS.WaitForIns3, self.__pushCtrl_5x6x) + + # 3->2 + stm.addop(PS.WaitForIns3, SLT.InsCode , PS.WaitForIns2, self.__pushInsCode) + # wait for code of the ins (dual-issued: 3/3 of last pack4 + 1/3 of current pack4) + stm.addop(PS.WaitForIns3, SLT.InsOnly, PS.WaitForCode3, self.__pushInsOnly_5x6x) + stm.addop(PS.WaitForCode3, SLT.CodeOnly, PS.WaitForIns2, self.__pushCode) + + # 2->1 + stm.addop(PS.WaitForIns2, SLT.InsCode, PS.WaitForIns1, self.__pushInsCode) + # wait for code of the ins (dual-issued: 1/3 + 2/3 of current pack4) + stm.addop(PS.WaitForIns2, SLT.InsOnly, PS.WaitForCode2, self.__pushInsOnly_5x6x) + stm.addop(PS.WaitForCode2, 
SLT.CodeOnly, PS.WaitForIns1, self.__pushCode) + + # 1-> ready + stm.addop(PS.WaitForIns1, SLT.InsCode, PS.Ready, self.__pushInsCode) + # wait for code of the ins (dual-issued: 2/3 + 3/3 of current pack4) + stm.addop(PS.WaitForIns1, SLT.InsOnly, PS.WaitForCode1, self.__pushInsOnly_5x6x) + stm.addop(PS.WaitForCode1, SLT.CodeOnly, PS.Ready, self.__pushCode) + + return stm + + def __getTrMatrixForSM_7x8x(self): + stm = StateTransferMatrix() + + stm.addop_dict(PS.Ready, + { SLT.FuncName : (PS.Ready, self.__setFuncName), + SLT.SectionName : (PS.Ready, self.__setSectionName), + SLT.HeaderFlag : (PS.Ready, self.__switchArch), + SLT.InsCode : (PS.WaitForCode1, self.__pushInsCode), + SLT.CodeOnly : (PS.Ready, lambda x: self.__emitMessage('Missing Instruction')), + SLT.Others : (PS.Ready, None), + }) + + # 1 InsCode + 1 CodeOnly + stm.addop(PS.WaitForIns1, SLT.InsCode, PS.WaitForCode1, self.__pushInsCode) + stm.addop(PS.WaitForCode1, SLT.CodeOnly, PS.Ready, self.__pushCode) + + return stm + +if __name__ == '__main__': + pass diff --git a/tinygrad/runtime/support/assembler/CuInsParser.py b/tinygrad/runtime/support/assembler/CuInsParser.py new file mode 100644 index 0000000000000..dd891b634f891 --- /dev/null +++ b/tinygrad/runtime/support/assembler/CuInsParser.py @@ -0,0 +1,675 @@ +# -*- coding: utf-8 -*- + +import re +import struct +from tinygrad.runtime.support.assembler.common import * +from tinygrad.runtime.support.assembler.CuSMVersion import CuSMVersion + +# Pattern that matches an instruction string +p_InsPattern = re.compile(r'(?P@!?U?P\w\s+)?\s*(?P[\w\.\?]+)(?P.*)') + +# Pattern that matches scoreboard sets, such as {1}, {4,2} +# Seems only appear after opcode DEPBAR +p_SBSet = re.compile(r'\{(\d,)*\d\}') + +# NOTE: about constants translate dict +# 1) +/-QNAN is not recognized by python float(), use +/-NAN +# +/-INF seems OK, +# QNAN for FSEL may not work properly, needs special treatment +# 2) (.reuse will be treated seperately for control codes, hence ignored 
here.) +# Bugfix: reuse will be treated as normal modifier +# 3) RZ may also appear in FADD/FMUL/FFMA.RZ ... +# 4) UPT is not found, may be just PT? +p_ConstTrDict = {r'(?[~\-\|!]*)(?P
.*?)\|?(?P(\.\w+)*)\|?$') + +# Match Label+Index (including translated RZ/URZ/PT) +# SBSet is the score board set for DEPBAR, translated before parsing +p_IndexedPattern = re.compile(r'\b(?P