Skip to content

Commit cf38541

Browse files
committed
Enhance rewrite pattern and tests
1 parent 3ca2eda commit cf38541

2 files changed

Lines changed: 66 additions & 38 deletions

File tree

fastseq/optimizer/jit/einsum_rewriter.py

Lines changed: 21 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,61 +10,54 @@
1010

1111
from fastseq.optimizer.jit.utils import graph_pattern, rewrite_graph
1212

13-
@graph_pattern
1413
def einsum_pattern_0(t0: str, t1: List[Tensor]):
    """Match-side graph pattern for a `torch.einsum(equation, operands)` call.

    Scripted via `graph_pattern` into a TorchScript graph string
    (EINSUM_PATTERN_STR) and used by `rewrite_einsum` to locate einsum
    nodes; the statement shape is deliberately minimal so it matches
    exactly one einsum node.
    """
    r = torch.einsum(t0, t1)
    return r
1716

18-
@graph_pattern
1917
def einsum_rewrite_pattern_0(eqn: str, operands: List[Tensor]):
    """Rewrite two common 5-D attention einsums as a single `torch.bmm`.

    Fast paths exist for equations shaped like
    "bmhtd,bnhsd->bmhts" (query/key scores) and
    "bmhts,bnhsd->bmhtd" (score/value products), where the second
    operand's dim 1 (subscript `n`) does not appear in the output and is
    therefore summed away.  Any equation that does not match one of the
    two templates falls through to plain `torch.einsum`.

    Args:
        eqn: the einsum equation string.
        operands: the einsum operands; the fast paths expect exactly two
            5-D tensors whose subscripts follow the templates above.

    Returns:
        The same tensor `torch.einsum(eqn, operands)` would produce.
    """
    # eqn = eqn.replace(' ', '') # TODO: fix the issue: ValueError: stoll

    # for cases like "bmhtd,bnhsd->bmhts"
    # Besides the positional equalities, the guard checks the ','/'->'
    # separators and that the reduced subscript `n` (eqn[7]) is distinct
    # from every other subscript, so a different-but-valid einsum such as
    # "bmhtd,bthsd->bmhts" is not silently mis-rewritten.
    # NOTE(review): equations with a subscript repeated inside one operand
    # (einsum diagonals) are assumed not to reach this rewriter.
    if (len(eqn) == 18 and eqn[5] == ',' and eqn[11:13] == '->' and
            eqn[0:4] == eqn[13:17] and eqn[0] == eqn[6] and
            eqn[2] == eqn[8] and eqn[4] == eqn[10] and eqn[9] == eqn[17] and
            eqn[7] != eqn[0] and eqn[7] != eqn[1] and eqn[7] != eqn[2] and
            eqn[7] != eqn[3] and eqn[7] != eqn[4] and eqn[7] != eqn[9]):
        t0 = operands[0]
        t1 = operands[1]
        b, m, h, t, d = t0.shape
        s = t1.size(3)
        n = t1.size(1)
        t1 = t1.permute(0, 2, 3, 4, 1)  # (b, h, s, d, n)
        if n > 1:
            # `n` is absent from the output, so it is summed out first.
            t1 = t1.sum(dim=4, keepdim=True)  # (b, h, s, d, 1)
        t0 = t0.permute(0, 2, 1, 3, 4)  # (b, h, m, t, d)
        t1 = t1.permute(0, 1, 3, 4, 2)  # (b, h, d, 1, s)
        # Both layouts are view-compatible here: (b, h) merge cleanly and
        # the size-1 axis is free, so no copy is forced.
        t0 = t0.reshape(b*h, m*t, d)
        t1 = t1.view(b*h, d, s)
        r = torch.bmm(t0, t1).view(b, h, m, t, s).permute(0, 2, 1, 3, 4)
        return r

    # for cases like "bmhts,bnhsd->bmhtd"
    if (len(eqn) == 18 and eqn[5] == ',' and eqn[11:13] == '->' and
            eqn[0:4] == eqn[13:17] and eqn[0] == eqn[6] and
            eqn[2] == eqn[8] and eqn[4] == eqn[9] and eqn[10] == eqn[17] and
            eqn[7] != eqn[0] and eqn[7] != eqn[1] and eqn[7] != eqn[2] and
            eqn[7] != eqn[3] and eqn[7] != eqn[4] and eqn[7] != eqn[10]):
        t0 = operands[0]
        t1 = operands[1]
        b, m, h, t, s = t0.shape
        n = t1.size(1)
        d = t1.size(4)
        t1 = t1.permute(0, 2, 4, 3, 1)  # (b, h, d, s, n)
        if n > 1:
            t1 = t1.sum(dim=4, keepdim=True)  # (b, h, d, s, 1)
        t0 = t0.permute(0, 2, 1, 3, 4)  # (b, h, m, t, s)
        t1 = t1.permute(0, 1, 3, 4, 2)  # (b, h, s, 1, d)
        t0 = t0.reshape(b*h, m*t, s)
        t1 = t1.view(b*h, s, d)
        r = torch.bmm(t0, t1).view(b, h, m, t, d).permute(0, 2, 1, 3, 4)
        return r

    # No fast path matched: fall back to the generic implementation.
    return torch.einsum(eqn, operands)
6558

66-
EINSUM_PATTERN_STR = einsum_pattern_0()
67-
EINSUM_REWRITE_PATTERN_STR = einsum_rewrite_pattern_0()
59+
# Pre-compute the TorchScript graph strings once at import time: the match
# pattern and its bmm-based replacement, both consumed by `rewrite_einsum`.
EINSUM_PATTERN_STR = graph_pattern(einsum_pattern_0)()
EINSUM_REWRITE_PATTERN_STR = graph_pattern(einsum_rewrite_pattern_0)()
6861

6962
def rewrite_einsum(input_graph: torch._C.Graph):
    """Replace matching einsum subgraphs in `input_graph`, in place, with the
    bmm-based rewrite via the `rewrite_graph` helper."""
    rewrite_graph(EINSUM_PATTERN_STR, EINSUM_REWRITE_PATTERN_STR, input_graph)
Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,69 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
33

4-
from typing import List
4+
import functools
5+
import logging
6+
import timeit
57

6-
from absl.testing import absltest
8+
from absl import flags
9+
from absl.testing import absltest, parameterized
710
import torch
811
from torch import Tensor
912

13+
from fastseq.logging import get_logger
1014
from fastseq.optimizer.jit.einsum_rewriter import rewrite_einsum
1115
from fastseq.utils.test_utils import TestCaseBase
1216

17+
logger = get_logger(__name__, logging.INFO)
18+
19+
20+
FLAGS = flags.FLAGS
21+
1322
class EinsumRewriterTest(TestCaseBase):
    """Tests the einsum-to-bmm TorchScript graph rewrite end to end."""

    @parameterized.parameters(
        {'eqn': "bmhtd,bnhsd->bmhts",
         'shape0': [128, 4, 16, 5, 64],
         'shape1': [128, 2, 16, 1024, 64]},
        {'eqn': "kmijd,knisd->kmijs",
         'shape0': [128, 4, 16, 1, 64],
         'shape1': [128, 2, 16, 1024, 64]},
        {'eqn': "bmhts,bnhsd->bmhtd",
         'shape0': [128, 4, 16, 3, 64],
         'shape1': [128, 2, 16, 64, 7]},
        {'eqn': "impts,inpsw->imptw",
         'shape0': [128, 4, 16, 3, 64],
         'shape1': [128, 2, 16, 64, 7]},
    )
    def test_einsum_rewriter(self, eqn, shape0, shape1):
        """Checks the rewritten graph uses bmm, matches torch.einsum, and
        logs a timing comparison.

        Args:
            eqn: einsum equation covered by the rewrite fast paths.
            shape0: shape of the first operand.
            shape1: shape of the second operand.
        """
        # The benchmark tensors are CUDA-resident; skip (not crash) on
        # CPU-only machines.
        if not torch.cuda.is_available():
            self.skipTest("einsum rewrite test requires a CUDA device")

        def run_einsum(eqn: str, t0: Tensor, t1: Tensor):
            r = torch.einsum(eqn, t0, t1)
            return r

        t0 = torch.randn(shape0, dtype=torch.float32).cuda()
        t1 = torch.randn(shape1, dtype=torch.float32).cuda()
        repeat_times = 1000

        # Baseline: eager einsum result and timing.
        r0 = run_einsum(eqn, t0, t1)
        time0 = timeit.Timer(functools.partial(run_einsum, eqn, t0, t1))
        s0 = time0.timeit(repeat_times)

        script_run_einsum = torch.jit.script(run_einsum)
        logger.debug(f"Original graph: \n{script_run_einsum.graph.str()}")
        rewrite_einsum(script_run_einsum.graph)
        logger.debug(f"Optimized graph: \n{script_run_einsum.graph.str()}")
        # The rewrite must have replaced einsum with a bmm-based subgraph.
        self.assertTrue('bmm' in script_run_einsum.graph.str())

        r1 = script_run_einsum(eqn, t0, t1)
        time1 = timeit.Timer(
            functools.partial(script_run_einsum, eqn, t0, t1))
        s1 = time1.timeit(repeat_times)

        # The bmm rewrite reorders the floating-point reduction, so exact
        # bitwise equality (torch.equal) is too strict; allclose is the
        # robust check and accepts everything torch.equal would.
        self.assertTrue(torch.allclose(r0, r1, atol=1e-4, rtol=1e-4))
        logger.info(f"einsum took: {s0}; optimized einsum torchscript took: "
                    f"{s1};")
3368
# Delegate to absl's test runner when executed as a script.
if __name__ == "__main__":
    absltest.main()

0 commit comments

Comments
 (0)