@@ -53,7 +53,7 @@ defmodule Bumblebee.Layers.Transformer do
       :layer_norm,
       :block_type,
       :attention_window_size,
-      :scale_attention_weights,
+      :attention_scale,
       :query_norm,
       :key_norm
     ]
@@ -276,8 +276,8 @@ defmodule Bumblebee.Layers.Transformer do
     * `:attention_window_size` - when set, enables sliding window attention.
       Should be a `{left, right}` tuple with window size on each side
 
-    * `:scale_attention_weights` - whether to scale query in the traditional style of
-      multi-headed attention. Defaults to `true`
+    * `:attention_scale` - the scaling factor applied to the attention weights.
+      Defaults to $\frac{1}{\sqrt{d}}$.
 
     * `:rotary_embedding` - configuration of rotary embedding. If set,
       will apply rotary position embedding with the given options. Valid
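For callers, the practical change in the hunk above is that the boolean `:scale_attention_weights` flag becomes a numeric `:attention_scale`. Below is a minimal usage sketch; everything other than `:attention_scale` (`:num_blocks`, `:num_attention_heads`, `:hidden_size`, `:ffn`, the pipe into `blocks/2`) is an assumption based on the existing API and is not part of this diff.

```elixir
# Hypothetical caller sketch, not code from this PR: pass an explicit scale
# instead of the old true/false flag. Leaving :attention_scale as nil keeps
# the documented default of 1 / sqrt(d).
hidden_state
|> Bumblebee.Layers.Transformer.blocks(
  num_blocks: 12,
  num_attention_heads: 12,
  hidden_size: 768,
  ffn: [intermediate_size: 3072, activation: :gelu],
  attention_scale: 0.125,
  name: "encoder"
)
```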
@@ -331,7 +331,7 @@ defmodule Bumblebee.Layers.Transformer do
         block_type: :standard,
         layer_norm: [],
         attention_window_size: nil,
-        scale_attention_weights: true,
+        attention_scale: nil,
         rotary_embedding: nil,
         query_norm: nil,
         key_norm: nil
@@ -362,7 +362,7 @@ defmodule Bumblebee.Layers.Transformer do
     layer_norm = opts[:layer_norm]
     block_type = opts[:block_type]
     attention_window_size = opts[:attention_window_size]
-    scale_attention_weights = opts[:scale_attention_weights]
+    attention_scale = opts[:attention_scale]
     rotary_embedding = opts[:rotary_embedding]
     query_norm = opts[:query_norm]
     key_norm = opts[:key_norm]
@@ -422,7 +422,7 @@ defmodule Bumblebee.Layers.Transformer do
         value_use_bias: value_use_bias,
         output_use_bias: output_use_bias,
         attention_window_size: attention_window_size,
-        scale_attention_weights: scale_attention_weights,
+        attention_scale: attention_scale,
         rotary_embedding: rotary_embedding,
         query_norm: query_norm,
         key_norm: key_norm,
@@ -469,7 +469,7 @@ defmodule Bumblebee.Layers.Transformer do
         value_use_bias: value_use_bias,
         output_use_bias: output_use_bias,
         attention_window_size: attention_window_size,
-        scale_attention_weights: scale_attention_weights,
+        attention_scale: attention_scale,
         rotary_embedding: rotary_embedding,
         name: join(name, "cross_attention")
       )
@@ -699,8 +699,8 @@ defmodule Bumblebee.Layers.Transformer do
     * `:attention_window_size` - when set, enables sliding window attention.
       Should be a `{left, right}` tuple with window size on each side
 
-    * `:scale_attention_weights` - whether to scale query in the traditional style of
-      multi-headed attention. Defaults to `true`
+    * `:attention_scale` - the scaling factor applied to the attention weights.
+      Defaults to $\frac{1}{\sqrt{d}}$
 
     * `:rotary_embedding` - configuration of rotary embedding. If set,
       will apply rotary position embedding with the given options. Valid
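For reference, the default mentioned above, $\frac{1}{\sqrt{d}}$ with $d$ the per-head dimension, would typically be resolved from a `nil` option as in the sketch below. The variable names and the head-size computation are illustrative assumptions, not code from this diff.

```elixir
# Illustration only (assumed names, not this PR's internals): resolving a nil
# :attention_scale to the documented default of 1 / sqrt(d).
hidden_size = 768
num_attention_heads = 12
attention_scale = nil

head_size = div(hidden_size, num_attention_heads)
scale = attention_scale || 1 / :math.sqrt(head_size)
# scale == 0.125 for a 64-dimensional head
```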
@@ -742,7 +742,7 @@ defmodule Bumblebee.Layers.Transformer do
         offset: Layers.none(),
         causal: false,
         attention_window_size: nil,
-        scale_attention_weights: true,
+        attention_scale: nil,
         kernel_initializer: :glorot_uniform,
         dropout_rate: 0.0,
         attention_head_size: nil,
@@ -767,7 +767,7 @@ defmodule Bumblebee.Layers.Transformer do
     kernel_initializer = opts[:kernel_initializer]
     causal = opts[:causal]
     attention_window_size = opts[:attention_window_size]
-    scale_attention_weights = opts[:scale_attention_weights]
+    attention_scale = opts[:attention_scale]
     dropout_rate = opts[:dropout_rate]
     rotary_embedding = opts[:rotary_embedding]
     query_norm = opts[:query_norm]
@@ -908,7 +908,7 @@ defmodule Bumblebee.Layers.Transformer do
         attention_head_mask,
         attention_relative_bias,
         offset,
-        scale: scale_attention_weights,
+        scale: attention_scale,
         causal: causal,
         window_size: attention_window_size,
         dropout_rate: dropout_rate
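The final hunk forwards the new option straight into the attention computation's `scale:` argument. As a rough sketch of what such a scale does to the raw query-key scores, consider the standalone Nx snippet below; the shapes and calls are illustrative assumptions, not this module's internals.

```elixir
# Sketch: applying an explicit or default scale to attention scores with Nx.
# Shapes are {batch, heads, seq, head_dim}; nothing here is taken from this PR.
query = Nx.iota({1, 2, 4, 8}, type: :f32)
key = Nx.iota({1, 2, 4, 8}, type: :f32)

attention_scale = nil
scale = attention_scale || 1 / :math.sqrt(Nx.axis_size(query, -1))

scores =
  query
  |> Nx.dot([3], [0, 1], key, [3], [0, 1])
  |> Nx.multiply(scale)
# scores has shape {1, 2, 4, 4}, ready for masking and softmax
```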