|
10 | 10 |
|
11 | 11 | from fastseq.optimizer.jit.utils import graph_pattern, rewrite_graph |
12 | 12 |
|
13 | | -@graph_pattern |
14 | 13 | def einsum_pattern_0(t0: str, t1: List[Tensor]): |
15 | 14 | r = torch.einsum(t0, t1) |
16 | 15 | return r |
17 | 16 |
|
18 | | -@graph_pattern |
19 | 17 | def einsum_rewrite_pattern_0(eqn: str, operands: List[Tensor]): |
| 18 | + # eqn = eqn.replace(' ', '') # TODO: fix the issue: ValueError: stoll |
20 | 19 | # for cases like "bmhtd,bnhsd->bmhts" |
21 | | - if (len(eqn) == 18 and eqn[0:3] == eqn[13:16] and eqn[0] == eqn[6] and |
22 | | - eqn[2] == eqn[8] and eqn[4] == eqn[10] and eqn[3] == eqn[16] and |
23 | | - eqn[9] == eqn[17]): |
| 20 | + if (len(eqn) == 18 and eqn[0:4] == eqn[13:17] and eqn[0] == eqn[6] and |
| 21 | + eqn[2] == eqn[8] and eqn[4] == eqn[10] and eqn[9] == eqn[17]): |
24 | 22 | t0 = operands[0] |
25 | 23 | t1 = operands[1] |
26 | | - b = t0.size(0) |
27 | | - m = t0.size(1) |
28 | | - h = t0.size(2) |
29 | | - t = t0.size(3) |
30 | | - d = t0.size(4) |
| 24 | + b, m, h, t, d = t0.shape |
| 25 | + s = t1.size(3) |
31 | 26 | n = t1.size(1) |
| 27 | + t1 = t1.permute(0, 2, 3, 4, 1) # (b, h, s, d, n) |
32 | 28 | if n > 1: |
33 | | - t1 = t1.sum(dim=1, keepdim=True) # (b, 1, h, d, s) |
34 | | - s = t1.size(3) |
| 29 | + t1 = t1.sum(dim=4, keepdim=True) # (b, h, s, d, 1) |
| 30 | + |
35 | 31 | t0 = t0.permute(0, 2, 1, 3, 4) # (b, h, m, t, d) |
36 | | - t1 = t1.permute(0, 2, 1, 4, 3) # (b, h, 1, d, s) |
| 32 | + t1 = t1.permute(0, 1, 3, 4, 2) # (b, h, d, 1, s) |
37 | 33 | t0 = t0.reshape(b*h, m*t, d) |
38 | | - t1 = t1.reshape(b*h, d, s) |
| 34 | + t1 = t1.view(b*h, d, s) |
39 | 35 | r = torch.bmm(t0, t1).view(b, h, m, t, s).permute(0, 2, 1, 3, 4) |
40 | 36 | return r |
41 | 37 |
|
42 | 38 | # for cases like "bmhts,bnhsd->bmhtd" |
43 | | - if (len(eqn) == 18 and eqn[0:3] == eqn[13:16] and eqn[0] == eqn[6] and |
44 | | - eqn[2] == eqn[8] and eqn[4] == eqn[9] and eqn[3] == eqn[16] and |
45 | | - eqn[10] == eqn[17]): |
| 39 | + if (len(eqn) == 18 and eqn[0:4] == eqn[13:17] and eqn[0] == eqn[6] and |
| 40 | + eqn[2] == eqn[8] and eqn[4] == eqn[9] and eqn[10] == eqn[17]): |
46 | 41 | t0 = operands[0] |
47 | 42 | t1 = operands[1] |
48 | | - b = t0.size(0) |
49 | | - m = t0.size(1) |
50 | | - h = t0.size(2) |
51 | | - t = t0.size(3) |
52 | | - s = t0.size(4) |
| 43 | + b, m, h, t, s = t0.shape |
53 | 44 | n = t1.size(1) |
54 | | - if n > 1: |
55 | | - t1 = t1.sum(dim=1, keepdim=True) # (b, 1, h, s, d) |
56 | 45 | d = t1.size(4) |
| 46 | + t1 = t1.permute(0, 2, 4, 3, 1) # (b, h, d, s, n) |
| 47 | + if n > 1: |
| 48 | + t1 = t1.sum(dim=4, keepdim=True) # (b, h, d, s, 1) |
| 49 | + # t1 = t1.squeeze(1) # (b, h, s, d) |
57 | 50 | t0 = t0.permute(0, 2, 1, 3, 4) # (b, h, m, t, s) |
58 | | - t1 = t1.permute(0, 2, 1, 3, 4) # (b, h, 1, s, d) |
| 51 | + t1 = t1.permute(0, 1, 3, 4, 2) # (b, h, s, 1, d) |
59 | 52 | t0 = t0.reshape(b*h, m*t, s) |
60 | | - t1 = t1.reshape(b*h, s, d) |
| 53 | + t1 = t1.view(b*h, s, d) |
61 | 54 | r = torch.bmm(t0, t1).view(b, h, m, t, d).permute(0, 2, 1, 3, 4) |
62 | 55 | return r |
63 | 56 |
|
64 | 57 | return torch.einsum(eqn, operands) |
65 | 58 |
|
66 | | -EINSUM_PATTERN_STR = einsum_pattern_0() |
67 | | -EINSUM_REWRITE_PATTERN_STR = einsum_rewrite_pattern_0() |
| 59 | +EINSUM_PATTERN_STR = graph_pattern(einsum_pattern_0)() |
| 60 | +EINSUM_REWRITE_PATTERN_STR = graph_pattern(einsum_rewrite_pattern_0)() |
68 | 61 |
|
69 | 62 | def rewrite_einsum(input_graph: torch._C.Graph): |
70 | 63 | rewrite_graph(EINSUM_PATTERN_STR, EINSUM_REWRITE_PATTERN_STR, input_graph) |
0 commit comments