@@ -95,13 +95,15 @@ def _gather_helper_fwd(self, tensor, attn_dist):
9595 wgroup = self .w_group ,
9696 )
9797
98- def _gather_helper_bwd (self , tensor , attn_dist ):
98+ def _gather_helper_bwd (self , tensor , attn_dist , use_out_shapes = False ):
99+ hshapes = attn_dist .lat_out_shapes if use_out_shapes else attn_dist .lat_in_shapes
100+ wshapes = attn_dist .lon_out_shapes if use_out_shapes else attn_dist .lon_in_shapes
99101 return gather_tensor_hw (
100102 tensor ,
101103 hdim = - 2 ,
102104 wdim = - 1 ,
103- hshapes = attn_dist . lat_in_shapes ,
104- wshapes = attn_dist . lon_in_shapes ,
105+ hshapes = hshapes ,
106+ wshapes = wshapes ,
105107 hsize = self .grid_size_h ,
106108 wsize = self .grid_size_w ,
107109 hrank = self .hrank ,
@@ -116,7 +118,7 @@ def _gather_helper_bwd(self, tensor, attn_dist):
116118 [64 , 128 , 64 , 128 , 2 , 16 , 1 , None , None , "equiangular" , "equiangular" , 1e-5 , 1e-4 ],
117119 [64 , 128 , 64 , 128 , 2 , 16 , 2 , None , None , "equiangular" , "equiangular" , 1e-5 , 1e-4 ],
118120 [64 , 128 , 64 , 128 , 2 , 16 , 1 , 8 , 8 , "equiangular" , "equiangular" , 1e-5 , 1e-4 ],
119- # [64, 128, 32, 64, 2, 16, 1, None, None, "equiangular", "equiangular", 1e-5, 1e-4],
121+ [64 , 128 , 32 , 64 , 2 , 16 , 1 , None , None , "equiangular" , "equiangular" , 1e-5 , 1e-4 ],
120122 [65 , 128 , 65 , 128 , 2 , 16 , 1 , None , None , "equiangular" , "equiangular" , 1e-5 , 1e-4 ],
121123 ],
122124 skip_on_empty = True ,
@@ -208,7 +210,8 @@ def test_distributed_neighborhood_attention(
208210
209211 # ---- compare backward ----
210212 for inp in ["q" , "k" , "v" ]:
211- igrad_gather = self ._gather_helper_bwd (igrad_local [inp ], attn_dist )
213+ use_out = (inp == "q" )
214+ igrad_gather = self ._gather_helper_bwd (igrad_local [inp ], attn_dist , use_out_shapes = use_out )
212215 self .assertTrue (compare_tensors (f"input gradient { inp } " , igrad_full [inp ], igrad_gather , atol = atol , rtol = rtol , verbose = verbose ))
213216
214217
0 commit comments