@@ -284,22 +284,46 @@ def forward(self, hidden_states):
284284 key_proj = self .key (hidden_states )
285285 value_proj = self .value (hidden_states )
286286
287- # transpose
288- query_states = self .transpose_for_scores (query_proj )
289- key_states = self .transpose_for_scores (key_proj )
290- value_states = self .transpose_for_scores (value_proj )
287+ if self .num_heads > 1 :
288+ query_states = self .transpose_for_scores (query_proj )
289+ key_states = self .transpose_for_scores (key_proj )
290+ value_states = self .transpose_for_scores (value_proj )
291+ else :
292+ query_states , key_states , value_states = query_proj , key_proj , value_proj
291293
292294 # get scores
293- scale = 1 / math .sqrt (math .sqrt (self .channels / self .num_heads ))
294- attention_scores = torch .matmul (query_states * scale , key_states .transpose (- 1 , - 2 ) * scale ) # TODO: use baddmm
295+ scale = 1 / math .sqrt (self .channels / self .num_heads )
296+
297+ if self .num_heads > 1 :
298+ # TODO: is there a way to perform batched matmul (e.g. baddbmm) on 4D tensors?
299+ # or reformulate this into a 3D problem?
300+ # TODO: measure whether on MPS device it would be faster to do this matmul via einsum
301+ # as some matmuls can be 1.94x slower than an equivalent einsum on MPS
302+ # https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
303+ attention_scores = torch .matmul (query_states , key_states .transpose (- 1 , - 2 )) * scale
304+ else :
305+ attention_scores = torch .baddbmm (
306+ torch .empty (query_states .shape [0 ], query_states .shape [1 ], key_states .shape [1 ], dtype = query_states .dtype , device = query_states .device ),
307+ query_states ,
308+ key_states .transpose (- 1 , - 2 ),
309+ beta = 0 ,
310+ alpha = scale ,
311+ )
295312 attention_probs = torch .softmax (attention_scores .float (), dim = - 1 ).type (attention_scores .dtype )
296313
297314 # compute attention output
298- hidden_states = torch .matmul (attention_probs , value_states )
299-
300- hidden_states = hidden_states .permute (0 , 2 , 1 , 3 ).contiguous ()
301- new_hidden_states_shape = hidden_states .size ()[:- 2 ] + (self .channels ,)
302- hidden_states = hidden_states .view (new_hidden_states_shape )
315+ if self .num_heads > 1 :
316+ # TODO: is there a way to perform batched matmul (e.g. bmm) on 4D tensors?
317+ # or reformulate this into a 3D problem?
318+ # TODO: measure whether on MPS device it would be faster to do this matmul via einsum
319+ # as some matmuls can be 1.94x slower than an equivalent einsum on MPS
320+ # https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
321+ hidden_states = torch .matmul (attention_probs , value_states )
322+ hidden_states = hidden_states .permute (0 , 2 , 1 , 3 ).contiguous ()
323+ new_hidden_states_shape = hidden_states .size ()[:- 2 ] + (self .channels ,)
324+ hidden_states = hidden_states .view (new_hidden_states_shape )
325+ else :
326+ hidden_states = torch .bmm (attention_probs , value_states )
303327
304328 # compute next hidden_states
305329 hidden_states = self .proj_attn (hidden_states )
@@ -507,19 +531,17 @@ def forward(self, hidden_states, context=None, mask=None):
507531 return hidden_states
508532
509533 def _attention (self , query , key , value ):
510- # TODO: use baddbmm for better performance
511- if query .device .type == "mps" :
512- # Better performance on mps (~20-25%)
513- attention_scores = torch .einsum ("b i d, b j d -> b i j" , query , key ) * self .scale
514- else :
515- attention_scores = torch .matmul (query , key .transpose (- 1 , - 2 )) * self .scale
534+ attention_scores = torch .baddbmm (
535+ torch .empty (query .shape [0 ], query .shape [1 ], key .shape [1 ], dtype = query .dtype , device = query .device ),
536+ query ,
537+ key .transpose (- 1 , - 2 ),
538+ beta = 0 ,
539+ alpha = self .scale ,
540+ )
516541 attention_probs = attention_scores .softmax (dim = - 1 )
517542 # compute attention output
518543
519- if query .device .type == "mps" :
520- hidden_states = torch .einsum ("b i j, b j d -> b i d" , attention_probs , value )
521- else :
522- hidden_states = torch .matmul (attention_probs , value )
544+ hidden_states = torch .bmm (attention_probs , value )
523545
524546 # reshape hidden_states
525547 hidden_states = self .reshape_batch_dim_to_heads (hidden_states )
@@ -534,21 +556,15 @@ def _sliced_attention(self, query, key, value, sequence_length, dim):
534556 for i in range (hidden_states .shape [0 ] // slice_size ):
535557 start_idx = i * slice_size
536558 end_idx = (i + 1 ) * slice_size
537- if query .device .type == "mps" :
538- # Better performance on mps (~20-25%)
539- attn_slice = (
540- torch .einsum ("b i d, b j d -> b i j" , query [start_idx :end_idx ], key [start_idx :end_idx ])
541- * self .scale
542- )
543- else :
544- attn_slice = (
545- torch .matmul (query [start_idx :end_idx ], key [start_idx :end_idx ].transpose (1 , 2 )) * self .scale
546- ) # TODO: use baddbmm for better performance
559+ attn_slice = torch .baddbmm (
560+ torch .empty (slice_size , query .shape [1 ], key .shape [1 ], dtype = query .dtype , device = query .device ),
561+ query [start_idx :end_idx ],
562+ key [start_idx :end_idx ].transpose (- 1 , - 2 ),
563+ beta = 0 ,
564+ alpha = self .scale ,
565+ )
547566 attn_slice = attn_slice .softmax (dim = - 1 )
548- if query .device .type == "mps" :
549- attn_slice = torch .einsum ("b i j, b j d -> b i d" , attn_slice , value [start_idx :end_idx ])
550- else :
551- attn_slice = torch .matmul (attn_slice , value [start_idx :end_idx ])
567+ attn_slice = torch .bmm (attn_slice , value [start_idx :end_idx ])
552568
553569 hidden_states [start_idx :end_idx ] = attn_slice
554570
0 commit comments