Skip to content

Commit 4d434d6

Browse files
committed
working downsampling attention
1 parent 810adad commit 4d434d6

1 file changed

Lines changed: 8 additions & 5 deletions

File tree

tests/test_distributed_attention.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,13 +95,15 @@ def _gather_helper_fwd(self, tensor, attn_dist):
 95   95                wgroup=self.w_group,
 96   96            )
 97   97
 98       -    def _gather_helper_bwd(self, tensor, attn_dist):
      98  +    def _gather_helper_bwd(self, tensor, attn_dist, use_out_shapes=False):
      99  +        hshapes = attn_dist.lat_out_shapes if use_out_shapes else attn_dist.lat_in_shapes
     100  +        wshapes = attn_dist.lon_out_shapes if use_out_shapes else attn_dist.lon_in_shapes
 99  101          return gather_tensor_hw(
100  102              tensor,
101  103              hdim=-2,
102  104              wdim=-1,
103       -            hshapes=attn_dist.lat_in_shapes,
104       -            wshapes=attn_dist.lon_in_shapes,
     105  +            hshapes=hshapes,
     106  +            wshapes=wshapes,
105  107              hsize=self.grid_size_h,
106  108              wsize=self.grid_size_w,
107  109              hrank=self.hrank,
@@ -116,7 +118,7 @@ def _gather_helper_bwd(self, tensor, attn_dist):
116  118          [64, 128, 64, 128, 2, 16, 1, None, None, "equiangular", "equiangular", 1e-5, 1e-4],
117  119          [64, 128, 64, 128, 2, 16, 2, None, None, "equiangular", "equiangular", 1e-5, 1e-4],
118  120          [64, 128, 64, 128, 2, 16, 1, 8, 8, "equiangular", "equiangular", 1e-5, 1e-4],
119       -        # [64, 128, 32, 64, 2, 16, 1, None, None, "equiangular", "equiangular", 1e-5, 1e-4],
     121  +        [64, 128, 32, 64, 2, 16, 1, None, None, "equiangular", "equiangular", 1e-5, 1e-4],
120  122          [65, 128, 65, 128, 2, 16, 1, None, None, "equiangular", "equiangular", 1e-5, 1e-4],
121  123      ],
122  124      skip_on_empty=True,
@@ -208,7 +210,8 @@ def test_distributed_neighborhood_attention(
208  210
209  211          # ---- compare backward ----
210  212          for inp in ["q", "k", "v"]:
211       -            igrad_gather = self._gather_helper_bwd(igrad_local[inp], attn_dist)
     213  +            use_out = (inp == "q")
     214  +            igrad_gather = self._gather_helper_bwd(igrad_local[inp], attn_dist, use_out_shapes=use_out)
212  215              self.assertTrue(compare_tensors(f"input gradient {inp}", igrad_full[inp], igrad_gather, atol=atol, rtol=rtol, verbose=verbose))
213  216
214  217

0 commit comments

Comments (0)