Commit 220c0f0

perf: prefer batched matmuls for attention. added fast-path to Decoder when num_heads=1
1 parent 249d9bc

1 file changed: src/diffusers/models/attention.py (51 additions, 35 deletions)
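
The core change: instead of computing attention scores as `torch.matmul(q, k.transpose(-1, -2)) * scale`, the new code calls `torch.baddbmm` with `beta=0` and `alpha=scale`, so the scaling is folded into a single batched GEMM rather than a matmul followed by an elementwise multiply. A minimal sketch of the equivalence this relies on, with illustrative shapes (not part of the commit):

```python
# Minimal sketch (not from the commit): the fused-scale batched matmul the diff
# switches to. With beta=0 the first tensor passed to baddbmm is never read, so
# torch.empty(...) is just a cheap placeholder of the right shape/dtype/device.
import torch

batch, q_len, k_len, head_dim = 2, 64, 64, 40  # illustrative shapes
scale = head_dim ** -0.5

q = torch.randn(batch, q_len, head_dim)
k = torch.randn(batch, k_len, head_dim)

# old path: matmul, then a separate elementwise multiply by the scale
scores_matmul = torch.matmul(q, k.transpose(-1, -2)) * scale

# new path: one batched GEMM with the scale folded in via alpha
scores_baddbmm = torch.baddbmm(
    torch.empty(batch, q_len, k_len, dtype=q.dtype, device=q.device),
    q,
    k.transpose(-1, -2),
    beta=0,
    alpha=scale,
)

assert torch.allclose(scores_matmul, scores_baddbmm, atol=1e-5)
```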
```diff
@@ -284,22 +284,46 @@ def forward(self, hidden_states):
         key_proj = self.key(hidden_states)
         value_proj = self.value(hidden_states)
 
-        # transpose
-        query_states = self.transpose_for_scores(query_proj)
-        key_states = self.transpose_for_scores(key_proj)
-        value_states = self.transpose_for_scores(value_proj)
+        if self.num_heads > 1:
+            query_states = self.transpose_for_scores(query_proj)
+            key_states = self.transpose_for_scores(key_proj)
+            value_states = self.transpose_for_scores(value_proj)
+        else:
+            query_states, key_states, value_states = query_proj, key_proj, value_proj
 
         # get scores
-        scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
-        attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)  # TODO: use baddmm
+        scale = 1 / math.sqrt(self.channels / self.num_heads)
+
+        if self.num_heads > 1:
+            # TODO: is there a way to perform batched matmul (e.g. baddbmm) on 4D tensors?
+            # or reformulate this into a 3D problem?
+            # TODO: measure whether on MPS device it would be faster to do this matmul via einsum
+            # as some matmuls can be 1.94x slower than an equivalent einsum on MPS
+            # https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
+            attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * scale
+        else:
+            attention_scores = torch.baddbmm(
+                torch.empty(query_states.shape[0], query_states.shape[1], key_states.shape[1], dtype=query_states.dtype, device=query_states.device),
+                query_states,
+                key_states.transpose(-1, -2),
+                beta=0,
+                alpha=scale,
+            )
         attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
 
         # compute attention output
-        hidden_states = torch.matmul(attention_probs, value_states)
-
-        hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
-        new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
-        hidden_states = hidden_states.view(new_hidden_states_shape)
+        if self.num_heads > 1:
+            # TODO: is there a way to perform batched matmul (e.g. bmm) on 4D tensors?
+            # or reformulate this into a 3D problem?
+            # TODO: measure whether on MPS device it would be faster to do this matmul via einsum
+            # as some matmuls can be 1.94x slower than an equivalent einsum on MPS
+            # https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
+            hidden_states = torch.matmul(attention_probs, value_states)
+            hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
+            new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
+            hidden_states = hidden_states.view(new_hidden_states_shape)
+        else:
+            hidden_states = torch.bmm(attention_probs, value_states)
 
         # compute next hidden_states
         hidden_states = self.proj_attn(hidden_states)
```
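
This first hunk adds the single-head fast path to the Decoder attention: when `num_heads == 1` the projections stay 3D, so `transpose_for_scores` and the trailing permute/reshape are skipped and the scores and output go through `baddbmm`/`bmm` directly. For the multi-head branch, the TODO asks whether the 4D matmul can be reformulated as a 3D problem; one possibility (a rough sketch under that assumption, not part of the commit) is to fold the heads into the batch dimension:

```python
# Rough sketch (not from the commit) of the TODO above: folding heads into the
# batch dimension turns the 4D multi-head matmul into a 3D one that baddbmm accepts.
import torch

batch, heads, seq, head_dim = 2, 8, 64, 40  # illustrative shapes
scale = head_dim ** -0.5

q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)

# 4D path, as in the num_heads > 1 branch above
scores_4d = torch.matmul(q, k.transpose(-1, -2)) * scale

# 3D reformulation: (batch, heads, seq, dim) -> (batch * heads, seq, dim)
q3 = q.reshape(batch * heads, seq, head_dim)
k3 = k.reshape(batch * heads, seq, head_dim)
scores_3d = torch.baddbmm(
    torch.empty(batch * heads, seq, seq, dtype=q.dtype, device=q.device),
    q3,
    k3.transpose(-1, -2),
    beta=0,
    alpha=scale,
)

assert torch.allclose(scores_4d, scores_3d.view(batch, heads, seq, seq), atol=1e-5)
```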
```diff
@@ -507,19 +531,17 @@ def forward(self, hidden_states, context=None, mask=None):
         return hidden_states
 
     def _attention(self, query, key, value):
-        # TODO: use baddbmm for better performance
-        if query.device.type == "mps":
-            # Better performance on mps (~20-25%)
-            attention_scores = torch.einsum("b i d, b j d -> b i j", query, key) * self.scale
-        else:
-            attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
+        attention_scores = torch.baddbmm(
+            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+            query,
+            key.transpose(-1, -2),
+            beta=0,
+            alpha=self.scale,
+        )
         attention_probs = attention_scores.softmax(dim=-1)
         # compute attention output
 
-        if query.device.type == "mps":
-            hidden_states = torch.einsum("b i j, b j d -> b i d", attention_probs, value)
-        else:
-            hidden_states = torch.matmul(attention_probs, value)
+        hidden_states = torch.bmm(attention_probs, value)
 
         # reshape hidden_states
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
```
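
The `_attention` hunk drops the MPS-specific einsum branches: with `baddbmm`/`bmm` there is a single code path for every device. The replacement relies on the fact that the einsum `"b i j, b j d -> b i d"` is just a batched matmul over 3D tensors; a quick illustrative check (not from the commit):

```python
# Quick check (illustrative, not from the commit) that the bmm replacing the MPS
# einsum branch computes the same thing.
import torch

probs = torch.softmax(torch.randn(2, 64, 64), dim=-1)  # attention probabilities
value = torch.randn(2, 64, 40)

out_einsum = torch.einsum("b i j, b j d -> b i d", probs, value)
out_bmm = torch.bmm(probs, value)

assert torch.allclose(out_einsum, out_bmm, atol=1e-6)
```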
```diff
@@ -534,21 +556,15 @@ def _sliced_attention(self, query, key, value, sequence_length, dim):
         for i in range(hidden_states.shape[0] // slice_size):
             start_idx = i * slice_size
             end_idx = (i + 1) * slice_size
-            if query.device.type == "mps":
-                # Better performance on mps (~20-25%)
-                attn_slice = (
-                    torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx])
-                    * self.scale
-                )
-            else:
-                attn_slice = (
-                    torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
-                )  # TODO: use baddbmm for better performance
+            attn_slice = torch.baddbmm(
+                torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+                query[start_idx:end_idx],
+                key[start_idx:end_idx].transpose(-1, -2),
+                beta=0,
+                alpha=self.scale,
+            )
             attn_slice = attn_slice.softmax(dim=-1)
-            if query.device.type == "mps":
-                attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
-            else:
-                attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
+            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
 
             hidden_states[start_idx:end_idx] = attn_slice
```
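
The `_sliced_attention` hunk applies the same rewrite per slice, so each iteration only materializes a `slice_size`-sized score tensor. A small sketch (illustrative shapes and `slice_size`, not from the commit) showing that the sliced loop agrees with one unsliced `baddbmm` over the full batch:

```python
# Illustrative sketch (not from the commit): the sliced loop should agree with a
# single unsliced baddbmm over the whole batch, just with a smaller peak memory
# footprint per iteration. Shapes and slice_size are assumptions for the demo.
import torch

batch, q_len, k_len, head_dim = 8, 64, 64, 40
slice_size = 2
scale = head_dim ** -0.5

q = torch.randn(batch, q_len, head_dim)
k = torch.randn(batch, k_len, head_dim)

# unsliced reference
full = torch.baddbmm(
    torch.empty(batch, q_len, k_len, dtype=q.dtype, device=q.device),
    q, k.transpose(-1, -2), beta=0, alpha=scale,
)

# sliced computation, mirroring the loop in the hunk above
sliced = torch.empty_like(full)
for i in range(batch // slice_size):
    s, e = i * slice_size, (i + 1) * slice_size
    sliced[s:e] = torch.baddbmm(
        torch.empty(slice_size, q_len, k_len, dtype=q.dtype, device=q.device),
        q[s:e], k[s:e].transpose(-1, -2), beta=0, alpha=scale,
    )

assert torch.allclose(full, sliced, atol=1e-6)
```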
