Commit a8caabd

Support overriding attention scale
1 parent 0987a0c commit a8caabd

5 files changed

Lines changed: 29 additions & 25 deletions

File tree

lib/bumblebee/layers.ex
lib/bumblebee/layers/transformer.ex
lib/bumblebee/text/gpt2.ex
lib/bumblebee/text/gpt_big_code.ex
lib/bumblebee/text/t5.ex

lib/bumblebee/layers.ex

Lines changed: 13 additions & 9 deletions
@@ -219,8 +219,8 @@ defmodule Bumblebee.Layers do
   * `:window_size` - when set, enables sliding window attention.
     Should be a `{left, right}` tuple with window size on each side
 
-  * `:scale` - whether to scale attention weights by $\frac{1}{\sqrt{d}}$.
-    Defaults to `true`
+  * `:scale` - the scaling factor applied to the attention weights.
+    Defaults to $\frac{1}{\sqrt{d}}$
 
   * `:dropout_rate` - the dropout rate for attention weights dropout.
     Defaults to `0.0`
@@ -231,7 +231,7 @@ defmodule Bumblebee.Layers do
 
  """
  def attention(query, key, value, key_mask, head_mask, bias, offset, opts \\ []) do
-    opts = Keyword.validate!(opts, [:window_size, causal: false, scale: true, dropout_rate: 0.0])
+    opts = Keyword.validate!(opts, [:window_size, :scale, causal: false, dropout_rate: 0.0])
 
    weights =
      Axon.layer(
@@ -263,14 +263,18 @@ defmodule Bumblebee.Layers do
 
        weights = Nx.dot(query, [3], [0, 1], key, [3], [0, 1])
 
-        weights =
-          if opts[:scale] do
-            depth = Nx.axis_size(query, -1)
-            weights / Nx.as_type(Nx.sqrt(depth), Nx.type(query))
-          else
-            weights
+        scale =
+          case opts[:scale] do
+            nil ->
+              depth = Nx.axis_size(query, -1)
+              1 / Nx.as_type(Nx.sqrt(depth), Nx.type(query))
+
+            scale ->
+              scale
          end
 
+        weights = weights * scale
+
        key_mask =
          case key_mask do
            %Axon.None{} ->
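
The new `:scale` resolution can be exercised outside of the Axon graph with plain `Nx` calls. A minimal sketch follows; the `resolve_scale` helper and the tensor shapes are illustrative, not part of this commit, and it uses `Nx.divide/2` and `Nx.multiply/2` because it runs as a plain script rather than inside the layer's compiled function:

resolve_scale = fn scale, query ->
  case scale do
    nil ->
      # Default: scale by 1/sqrt(depth), matching the previous behaviour
      depth = Nx.axis_size(query, -1)
      Nx.divide(1, Nx.as_type(Nx.sqrt(depth), Nx.type(query)))

    scale ->
      # Any explicit number overrides the default
      scale
  end
end

query = Nx.iota({1, 2, 4}, type: :f32)
key = Nx.iota({1, 3, 4}, type: :f32)

# Raw attention scores: batched dot product over the depth axis
scores = Nx.dot(query, [2], [0], key, [2], [0])

Nx.multiply(scores, resolve_scale.(nil, query)) # default, 1/sqrt(4) = 0.5
Nx.multiply(scores, resolve_scale.(1, query))   # override, scores left unscaled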

lib/bumblebee/layers/transformer.ex

Lines changed: 12 additions & 12 deletions
@@ -53,7 +53,7 @@ defmodule Bumblebee.Layers.Transformer do
    :layer_norm,
    :block_type,
    :attention_window_size,
-    :scale_attention_weights,
+    :attention_scale,
    :query_norm,
    :key_norm
  ]
@@ -276,8 +276,8 @@ defmodule Bumblebee.Layers.Transformer do
   * `:attention_window_size` - when set, enables sliding window attention.
     Should be a `{left, right}` tuple with window size on each side
 
-  * `:scale_attention_weights` - whether to scale query in the traditional style of
-    multi-headed attention. Defaults to `true`
+  * `:attention_scale` - the scaling factor applied to the attention weights.
+    Defaults to $\frac{1}{\sqrt{d}}$.
 
   * `:rotary_embedding` - configuration of rotary embedding. If set,
     will apply rotary position embedding with the given options. Valid
@@ -331,7 +331,7 @@ defmodule Bumblebee.Layers.Transformer do
        block_type: :standard,
        layer_norm: [],
        attention_window_size: nil,
-        scale_attention_weights: true,
+        attention_scale: nil,
        rotary_embedding: nil,
        query_norm: nil,
        key_norm: nil
@@ -362,7 +362,7 @@ defmodule Bumblebee.Layers.Transformer do
    layer_norm = opts[:layer_norm]
    block_type = opts[:block_type]
    attention_window_size = opts[:attention_window_size]
-    scale_attention_weights = opts[:scale_attention_weights]
+    attention_scale = opts[:attention_scale]
    rotary_embedding = opts[:rotary_embedding]
    query_norm = opts[:query_norm]
    key_norm = opts[:key_norm]
@@ -422,7 +422,7 @@ defmodule Bumblebee.Layers.Transformer do
        value_use_bias: value_use_bias,
        output_use_bias: output_use_bias,
        attention_window_size: attention_window_size,
-        scale_attention_weights: scale_attention_weights,
+        attention_scale: attention_scale,
        rotary_embedding: rotary_embedding,
        query_norm: query_norm,
        key_norm: key_norm,
@@ -469,7 +469,7 @@ defmodule Bumblebee.Layers.Transformer do
          value_use_bias: value_use_bias,
          output_use_bias: output_use_bias,
          attention_window_size: attention_window_size,
-          scale_attention_weights: scale_attention_weights,
+          attention_scale: attention_scale,
          rotary_embedding: rotary_embedding,
          name: join(name, "cross_attention")
        )
@@ -699,8 +699,8 @@ defmodule Bumblebee.Layers.Transformer do
   * `:attention_window_size` - when set, enables sliding window attention.
     Should be a `{left, right}` tuple with window size on each side
 
-  * `:scale_attention_weights` - whether to scale query in the traditional style of
-    multi-headed attention. Defaults to `true`
+  * `:attention_scale` - the scaling factor applied to the attention weights.
+    Defaults to $\frac{1}{\sqrt{d}}$
 
   * `:rotary_embedding` - configuration of rotary embedding. If set,
     will apply rotary position embedding with the given options. Valid
@@ -742,7 +742,7 @@ defmodule Bumblebee.Layers.Transformer do
        offset: Layers.none(),
        causal: false,
        attention_window_size: nil,
-        scale_attention_weights: true,
+        attention_scale: nil,
        kernel_initializer: :glorot_uniform,
        dropout_rate: 0.0,
        attention_head_size: nil,
@@ -767,7 +767,7 @@ defmodule Bumblebee.Layers.Transformer do
    kernel_initializer = opts[:kernel_initializer]
    causal = opts[:causal]
    attention_window_size = opts[:attention_window_size]
-    scale_attention_weights = opts[:scale_attention_weights]
+    attention_scale = opts[:attention_scale]
    dropout_rate = opts[:dropout_rate]
    rotary_embedding = opts[:rotary_embedding]
    query_norm = opts[:query_norm]
@@ -908,7 +908,7 @@ defmodule Bumblebee.Layers.Transformer do
        attention_head_mask,
        attention_relative_bias,
        offset,
-        scale: scale_attention_weights,
+        scale: attention_scale,
        causal: causal,
        window_size: attention_window_size,
        dropout_rate: dropout_rate
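
For callers of the block helpers, the override is now expressed as a number rather than a boolean. A hypothetical options fragment (only the keys shown in this diff are real; everything else a caller would normally pass is omitted):

[
  attention_window_size: nil,
  # nil       -> default 1/sqrt(d)
  # 1         -> no scaling (what `scale_attention_weights: false` used to mean)
  # any float -> a custom factor
  attention_scale: nil,
  rotary_embedding: nil
]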

lib/bumblebee/text/gpt2.ex

Lines changed: 1 addition & 1 deletion
@@ -412,7 +412,7 @@ defmodule Bumblebee.Text.Gpt2 do
        activation: spec.activation
      ],
      block_type: :norm_first,
-      scale_attention_weights: spec.scale_attention_weights,
+      attention_scale: if(not spec.scale_attention_weights, do: 1),
      name: join(name, "blocks")
    ] ++
      if(spec.use_cross_attention,

lib/bumblebee/text/gpt_big_code.ex

Lines changed: 1 addition & 1 deletion
@@ -417,7 +417,7 @@ defmodule Bumblebee.Text.GptBigCode do
        activation: spec.activation
      ],
      block_type: :norm_first,
-      scale_attention_weights: spec.scale_attention_weights,
+      attention_scale: if(not spec.scale_attention_weights, do: 1),
      name: join(name, "blocks")
    ] ++
      if(spec.use_cross_attention,
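
Both GPT-2 and GPT-BigCode keep their boolean `scale_attention_weights` spec attribute and translate it at the call site. A small sketch of the idiom used above: `if/2` without an `else` branch returns `nil`, so the boolean maps directly onto the new numeric option.

scale_attention_weights = false

attention_scale = if(not scale_attention_weights, do: 1)
# scale_attention_weights: true  -> attention_scale: nil (default 1/sqrt(d))
# scale_attention_weights: false -> attention_scale: 1   (weights left unscaled)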

lib/bumblebee/text/t5.ex

Lines changed: 2 additions & 2 deletions
@@ -412,7 +412,7 @@ defmodule Bumblebee.Text.T5 do
        max_distance: spec.relative_attention_max_distance
      ],
      share_attention_relative_bias: true,
-      scale_attention_weights: false,
+      attention_scale: 1,
      name: join(name, "blocks")
    )
 
@@ -469,7 +469,7 @@ defmodule Bumblebee.Text.T5 do
        max_distance: spec.relative_attention_max_distance
      ],
      share_attention_relative_bias: true,
-      scale_attention_weights: false,
+      attention_scale: 1,
      name: join(name, "blocks")
    )
 
