From 1fc7aaf5323e4a975cf1ff6eb47c680d09cfb396 Mon Sep 17 00:00:00 2001 From: Niko Maroulis Date: Sun, 28 Dec 2025 12:07:38 -0500 Subject: [PATCH 01/11] Add Gemma 3 support for FunctionGemma and other Gemma 3 models Gemma 3 architecture includes several key differences from Gemma v1: - QK-norm (RMS normalization on query/key after projection) - Pre/post FFN layer norms (pre_feedforward_layernorm, post_feedforward_layernorm) - Different residual connection order (after post_attention_layernorm) - Alternating local/global attention (sliding window) - RMS norm with shift=1.0 formula: output * (1.0 + weight) Files added: - lib/bumblebee/text/gemma3.ex: Full Gemma 3 model implementation - test/bumblebee/text/gemma3_test.exs: Unit tests - notebooks/function_calling.livemd: Livebook with FunctionGemma examples Files modified: - lib/bumblebee.ex: Model and tokenizer registrations - lib/bumblebee/layers/transformer.ex: Per-layer attention_window_size support --- lib/bumblebee.ex | 9 + lib/bumblebee/layers/transformer.ex | 20 +- lib/bumblebee/text/gemma3.ex | 726 ++++++++++++++++++++++++++++ notebooks/function_calling.livemd | 554 +++++++++++++++++++++ test/bumblebee/text/gemma3_test.exs | 55 +++ 5 files changed, 1363 insertions(+), 1 deletion(-) create mode 100644 lib/bumblebee/text/gemma3.ex create mode 100644 notebooks/function_calling.livemd create mode 100644 test/bumblebee/text/gemma3_test.exs diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex index a6e832c7..e607213b 100644 --- a/lib/bumblebee.ex +++ b/lib/bumblebee.ex @@ -137,6 +137,13 @@ defmodule Bumblebee do "GemmaModel" => {Bumblebee.Text.Gemma, :base}, "GemmaForCausalLM" => {Bumblebee.Text.Gemma, :for_causal_language_modeling}, "GemmaForSequenceClassification" => {Bumblebee.Text.Gemma, :for_sequence_classification}, + "Gemma3Model" => {Bumblebee.Text.Gemma3, :base}, + "Gemma3ForCausalLM" => {Bumblebee.Text.Gemma3, :for_causal_language_modeling}, + "Gemma3ForSequenceClassification" => {Bumblebee.Text.Gemma3, :for_sequence_classification}, + "Gemma3TextModel" => {Bumblebee.Text.Gemma3, :base}, + "Gemma3TextForCausalLM" => {Bumblebee.Text.Gemma3, :for_causal_language_modeling}, + "Gemma3TextForSequenceClassification" => + {Bumblebee.Text.Gemma3, :for_sequence_classification}, "GPT2ForSequenceClassification" => {Bumblebee.Text.Gpt2, :for_sequence_classification}, "GPT2ForTokenClassification" => {Bumblebee.Text.Gpt2, :for_token_classification}, "GPT2LMHeadModel" => {Bumblebee.Text.Gpt2, :for_causal_language_modeling}, @@ -249,6 +256,8 @@ defmodule Bumblebee do "camembert" => :camembert, "clip" => :clip, "gemma" => :gemma, + "gemma3" => :gemma, + "gemma3_text" => :gemma, "gpt_neox" => :gpt_neo_x, "gpt2" => :gpt2, "gpt_bigcode" => :gpt2, diff --git a/lib/bumblebee/layers/transformer.ex b/lib/bumblebee/layers/transformer.ex index 59ad9595..01fea9c2 100644 --- a/lib/bumblebee/layers/transformer.ex +++ b/lib/bumblebee/layers/transformer.ex @@ -25,6 +25,13 @@ defmodule Bumblebee.Layers.Transformer do - a keyword list (applied to all blocks) - a function that takes the block index and returns the configuration + * `:attention_window_size` - sliding window attention configuration. Can be: + - `nil` for global attention (default) + - a `{left, right}` tuple (applied to all blocks) + - a function that takes the block index and returns `nil` or `{left, right}`. 
+ This enables per-layer attention patterns like Gemma 3's alternating + local/global attention (5 local layers followed by 1 global layer) + * `:name` - the prefix for layer names For all other options (including required options) see `block/2`. @@ -52,7 +59,6 @@ defmodule Bumblebee.Layers.Transformer do :output_use_bias, :layer_norm, :block_type, - :attention_window_size, :scale_attention_weights ] @@ -64,6 +70,7 @@ defmodule Bumblebee.Layers.Transformer do :name, :num_blocks, :rotary_embedding, + :attention_window_size, attention_mask: Layers.none(), attention_head_mask: Layers.none(), attention_relative_bias: nil, @@ -85,6 +92,7 @@ defmodule Bumblebee.Layers.Transformer do cross_attention_head_mask = opts[:cross_attention_head_mask] cache = opts[:cache] rotary_embedding = opts[:rotary_embedding] + attention_window_size = opts[:attention_window_size] block_opts = Keyword.take(opts, block_opts_keys) @@ -121,6 +129,15 @@ defmodule Bumblebee.Layers.Transformer do config when is_list(config) -> config end + # Support per-layer attention window size for models like Gemma 3 + # that alternate between local (sliding window) and global attention + block_attention_window_size = + case attention_window_size do + nil -> nil + fun when is_function(fun, 1) -> fun.(idx) + size -> size + end + {hidden_state, attention, cross_attention, block_cache, attention_relative_bias} = block( state.hidden_state, @@ -134,6 +151,7 @@ defmodule Bumblebee.Layers.Transformer do block_cache: block_cache, offset: offset, rotary_embedding: block_rotary_embedding, + attention_window_size: block_attention_window_size, name: join(name, idx) ] ++ block_opts ) diff --git a/lib/bumblebee/text/gemma3.ex b/lib/bumblebee/text/gemma3.ex new file mode 100644 index 00000000..1a161d0e --- /dev/null +++ b/lib/bumblebee/text/gemma3.ex @@ -0,0 +1,726 @@ +defmodule Bumblebee.Text.Gemma3 do + alias Bumblebee.Shared + + options = + [ + vocab_size: [ + default: 262_208, + doc: """ + the vocabulary size of the token embedding. This corresponds to the number of distinct + tokens that can be represented in model input and output + """ + ], + max_positions: [ + default: 131_072, + doc: """ + the vocabulary size of the position embedding. This corresponds to the maximum sequence + length that this model can process. Typically this is set to a large value just in case, + such as 512, 1024 or 2048 + """ + ], + hidden_size: [ + default: 2304, + doc: "the dimensionality of hidden layers" + ], + intermediate_size: [ + default: 9216, + doc: "the dimensionality of intermediate layers" + ], + attention_head_size: [ + default: 256, + doc: "the size of the key, value, and query projection per attention head" + ], + num_blocks: [ + default: 26, + doc: "the number of Transformer blocks in the model" + ], + num_attention_heads: [ + default: 8, + doc: "the number of attention heads for each attention layer in the model" + ], + num_key_value_heads: [ + default: 4, + doc: "the number of key value heads for each attention layer in the model" + ], + activation: [ + default: :gelu_approx_tanh, + doc: "the activation function" + ], + rotary_embedding_base: [ + default: 10_000, + doc: "base for computing rotary embedding frequency" + ], + rotary_embedding_scaling_strategy: [ + default: nil, + doc: """ + scaling configuration for rotary embedding. 
Currently the supported values are: + + * `%{type: :linear, factor: number()}` + + * `%{type: :dynamic, factor: number()}` + + For more details see https://www.reddit.com/r/LocalLlama/comments/14mrgpr/dynamically_scaled_rope_further_increases + """ + ], + use_attention_bias: [ + default: false, + doc: + "whether or not to use bias in the query, key, value, and output projections in attention layers" + ], + layer_norm_epsilon: [ + default: 1.0e-6, + doc: "the epsilon used by RMS normalization layers" + ], + initializer_scale: [ + default: 0.02, + doc: + "the standard deviation of the normal initializer used for initializing kernel parameters" + ], + sliding_window: [ + default: 4096, + doc: "the sliding window size for local attention layers" + ], + global_attention_layer_interval: [ + default: 6, + doc: """ + the interval for global attention layers. In Gemma 3, every Nth layer uses global + attention while others use local (sliding window) attention. A value of 6 means + layers 5, 11, 17, 23... use global attention (5:1 local/global ratio) + """ + ], + tie_word_embeddings: [ + default: true, + doc: "whether to tie input and output embedding weights" + ] + ] ++ + Shared.common_options([:num_labels, :id_to_label]) ++ Shared.token_options(pad_token_id: 0) + + @moduledoc """ + Gemma 3 model family. + + Gemma 3 is an updated version of the Gemma architecture with several key improvements: + + * Alternating local/global attention (5:1 ratio by default) for better efficiency + * Larger vocabulary (262K tokens) + * Extended context length (up to 128K tokens) + + This module also supports FunctionGemma, which is built on Gemma 3 and optimized + for function calling tasks. + + ## Architectures + + * `:base` - plain Gemma 3 without any head on top + + * `:for_causal_language_modeling` - Gemma 3 with a language modeling + head. The head returns logits for each token in the original + sequence + + * `:for_sequence_classification` - Gemma 3 with a sequence + classification head. The head returns logits corresponding to + possible classes + + ## Inputs + + * `"input_ids"` - `{batch_size, sequence_length}` + + Indices of input sequence tokens in the vocabulary. + + * `"attention_mask"` - `{batch_size, sequence_length}` + + Mask indicating which tokens to attend to. This is used to ignore + padding tokens, which are added when processing a batch of sequences + with different length. + + * `"position_ids"` - `{batch_size, sequence_length}` + + Indices of positions of each input sequence tokens in the position + embeddings. + + * `"attention_head_mask"` - `{encoder_num_blocks, encoder_num_attention_heads}` + + Mask to nullify selected heads of the self-attention blocks in + the encoder. + + * `"input_embeddings"` - `{batch_size, sequence_length, hidden_size}` + + Embedded representation of `"input_ids"`, which can be specified + for more control over how `"input_ids"` are embedded than the + model's internal embedding lookup. If `"input_embeddings"` are present, + then `"input_ids"` will be ignored. + + * `"cache"` + + A container with cached layer results used to speed up sequential + decoding (autoregression). With cache, certain hidden states are + taken from the cache, rather than recomputed on every decoding + pass. The cache should be treated as opaque and initialized with + `Bumblebee.Text.Generation.init_cache/4`. 
+ + ## Global layer options + + #{Shared.global_layer_options_doc([:output_hidden_states, :output_attentions])} + + ## Configuration + + #{Shared.options_doc(options)} + """ + + defstruct [architecture: :base] ++ Shared.option_defaults(options) + + @behaviour Bumblebee.ModelSpec + @behaviour Bumblebee.Configurable + @behaviour Bumblebee.Text.Generation + + import Bumblebee.Utils.Model, only: [join: 2] + + alias Bumblebee.Layers + + @impl true + def architectures(), + do: [ + :base, + :for_causal_language_modeling, + :for_sequence_classification + ] + + @impl true + def config(spec, opts) do + spec + |> Shared.put_config_attrs(opts) + |> Shared.validate_label_options() + end + + @impl true + def input_template(_spec) do + %{ + "input_ids" => Nx.template({1, 1}, :s64) + } + end + + @impl true + def init_cache(spec, batch_size, max_length, _inputs) do + Layers.Decoder.init_cache(batch_size, max_length, + hidden_size: spec.hidden_size, + attention_head_size: spec.attention_head_size, + decoder_num_attention_heads: spec.num_attention_heads, + decoder_num_blocks: spec.num_blocks + ) + end + + @impl true + def traverse_cache(_spec, cache, fun) do + Layers.Decoder.traverse_cache(cache, fun) + end + + @impl true + def model(%__MODULE__{architecture: :base} = spec) do + inputs = inputs(spec) + + inputs + |> core(spec) + |> Layers.output() + end + + def model(%__MODULE__{architecture: :for_causal_language_modeling} = spec) do + inputs = inputs(spec) + + outputs = core(inputs, spec) + logits = language_modeling_head(outputs.hidden_state, spec, name: "language_modeling_head") + + Layers.output(%{ + logits: logits, + hidden_states: outputs.hidden_states, + attentions: outputs.attentions, + cache: outputs.cache + }) + end + + def model(%__MODULE__{architecture: :for_sequence_classification} = spec) do + inputs = inputs(spec) + + outputs = core(inputs, spec) + + logits = + Axon.dense(outputs.hidden_state, spec.num_labels, + kernel_initializer: kernel_initializer(spec), + name: "sequence_classification_head.output", + use_bias: false + ) + + pooled_logits = + Layers.if_present inputs["input_ids"] do + Axon.layer( + fn logits, input_ids, _opts -> + indices = + input_ids + |> Nx.not_equal(spec.pad_token_id) + |> Nx.sum(axes: [-1]) + |> Nx.subtract(1) + |> Nx.as_type({:s, 64}) + + Bumblebee.Utils.Nx.batched_take(logits, indices) + end, + [logits, inputs["input_ids"]] + ) + else + Layers.take_token(logits, axis: 1, index: -1) + end + + Layers.output(%{ + logits: pooled_logits, + hidden_states: outputs.hidden_states, + attentions: outputs.attentions, + cache: outputs.cache + }) + end + + defp inputs(spec) do + shape = {nil, nil} + hidden_shape = {nil, nil, spec.hidden_size} + + attention_head_mask_shape = {spec.num_blocks, spec.num_attention_heads} + + Bumblebee.Utils.Model.inputs_to_map([ + Axon.input("input_ids", optional: true, shape: shape), + Axon.input("attention_mask", optional: true, shape: shape), + Axon.input("position_ids", optional: true, shape: shape), + Axon.input("attention_head_mask", optional: true, shape: attention_head_mask_shape), + Axon.input("input_embeddings", optional: true, shape: hidden_shape), + Axon.input("cache", optional: true) + ]) + end + + defp core(inputs, spec) do + embeddings = + embedder( + inputs["input_ids"], + inputs["input_embeddings"], + spec, + name: "embedder" + ) + + position_ids = + Layers.default inputs["position_ids"] do + Layers.default_position_ids(embeddings) + end + + decoder_outputs = + decoder( + embeddings, + position_ids, + inputs["attention_mask"], + 
inputs["attention_head_mask"], + inputs["cache"], + spec, + name: "decoder" + ) + + hidden_state = + Layers.rms_norm(decoder_outputs.hidden_state, + name: "output_norm", + shift: 1.0, + epsilon: spec.layer_norm_epsilon, + upcast: :all + ) + + %{ + hidden_state: hidden_state, + hidden_states: Layers.append(decoder_outputs.hidden_states, hidden_state), + attentions: decoder_outputs.attentions, + cache: decoder_outputs.cache + } + end + + defp embedder(input_ids, input_embeddings, spec, opts) do + name = opts[:name] + + # Note: Gemma 3 still normalizes embeddings by sqrt(hidden_size), same as Gemma v1 + Layers.default input_embeddings do + Axon.embedding(input_ids, spec.vocab_size, spec.hidden_size, + kernel_initializer: kernel_initializer(spec), + name: join(name, "token_embedding") + ) + end + |> Axon.nx(fn x -> + normalization_factor = + spec.hidden_size + |> Nx.tensor(type: Nx.type(x)) + |> Nx.sqrt() + + Nx.multiply(x, normalization_factor) + end) + end + + defp decoder( + hidden_state, + position_ids, + attention_mask, + attention_head_mask, + cache, + spec, + opts + ) do + name = opts[:name] + + # Use cached attention mask + {attention_mask, cache} = Layers.Decoder.cached_attention_mask(attention_mask, cache) + offset = Layers.Decoder.get_cache_offset(cache) + + state = %{ + hidden_state: hidden_state, + hidden_states: Axon.container({hidden_state}), + attentions: Axon.container({}), + cache: cache + } + + outputs = + for idx <- 0..(spec.num_blocks - 1), reduce: state do + state -> + block_attention_head_mask = Axon.nx(attention_head_mask, & &1[idx]) + block_cache = Layers.Decoder.get_block_cache(state.cache, idx) + + # Gemma 3 alternates between local (sliding window) and global attention + # Every global_attention_layer_interval-th layer uses global attention + attention_window_size = + if rem(idx + 1, spec.global_attention_layer_interval) == 0 do + # Global attention (no window) + nil + else + # Local attention with sliding window + {spec.sliding_window, spec.sliding_window} + end + + {hidden_state, attention, block_cache} = + gemma3_block(state.hidden_state, + attention_mask: attention_mask, + attention_head_mask: block_attention_head_mask, + block_cache: block_cache, + offset: offset, + position_ids: position_ids, + attention_window_size: attention_window_size, + spec: spec, + name: join(name, "blocks.#{idx}") + ) + + cache = Layers.Decoder.put_block_cache(state.cache, idx, block_cache) + + %{ + hidden_state: hidden_state, + hidden_states: Layers.append(state.hidden_states, hidden_state), + attentions: Layers.append(state.attentions, attention), + cache: cache + } + end + + outputs = update_in(outputs.cache, &Layers.Decoder.update_cache_offset(&1, hidden_state)) + + %{ + hidden_state: outputs.hidden_state, + hidden_states: outputs.hidden_states, + attentions: outputs.attentions, + cache: outputs.cache + } + end + + defp gemma3_block(hidden_state, opts) do + attention_mask = opts[:attention_mask] + attention_head_mask = opts[:attention_head_mask] + block_cache = opts[:block_cache] + offset = opts[:offset] + position_ids = opts[:position_ids] + attention_window_size = opts[:attention_window_size] + spec = opts[:spec] + name = opts[:name] + + {self_attention_cache, cross_attention_cache} = + Layers.Decoder.get_attention_caches(block_cache) + + # Self-attention with pre-norm (input_layernorm) + shortcut = hidden_state + + hidden_state = + Layers.rms_norm(hidden_state, + shift: 1.0, + name: join(name, "self_attention_norm"), + epsilon: spec.layer_norm_epsilon, + upcast: :all + ) + + 
{hidden_state, attention, self_attention_cache} = + gemma3_attention(hidden_state, hidden_state, hidden_state, + attention_mask: attention_mask, + attention_head_mask: attention_head_mask, + attention_cache: self_attention_cache, + offset: offset, + position_ids: position_ids, + attention_window_size: attention_window_size, + spec: spec, + name: join(name, "self_attention") + ) + + # Post-attention norm BEFORE residual add (Gemma 3 specific) + hidden_state = + Layers.rms_norm(hidden_state, + shift: 1.0, + name: join(name, "post_attention_norm"), + epsilon: spec.layer_norm_epsilon, + upcast: :all + ) + + # Residual add AFTER post_attention_norm + hidden_state = Axon.add(shortcut, hidden_state) + + # FFN with pre/post norms (Gemma 3 specific) + shortcut = hidden_state + + hidden_state = + Layers.rms_norm(hidden_state, + shift: 1.0, + name: join(name, "pre_ffn_norm"), + epsilon: spec.layer_norm_epsilon, + upcast: :all + ) + + hidden_state = + gated_ffn(hidden_state, spec.intermediate_size, spec.hidden_size, + name: join(name, "ffn"), + activation: spec.activation + ) + + hidden_state = + Layers.rms_norm(hidden_state, + shift: 1.0, + name: join(name, "post_ffn_norm"), + epsilon: spec.layer_norm_epsilon, + upcast: :all + ) + + hidden_state = Axon.add(shortcut, hidden_state) + + block_cache = + Layers.Decoder.put_attention_caches( + block_cache, + self_attention_cache, + cross_attention_cache + ) + + {hidden_state, attention, block_cache} + end + + defp gemma3_attention(query, key, value, opts) do + attention_mask = opts[:attention_mask] + attention_head_mask = opts[:attention_head_mask] + attention_cache = opts[:attention_cache] + offset = opts[:offset] + position_ids = opts[:position_ids] + attention_window_size = opts[:attention_window_size] + spec = opts[:spec] + name = opts[:name] + + num_heads = spec.num_attention_heads + num_key_value_heads = spec.num_key_value_heads + attention_head_size = spec.attention_head_size + inner_size = num_heads * attention_head_size + inner_kv_size = num_key_value_heads * attention_head_size + + # Project Q, K, V + query = + query + |> Axon.dense(inner_size, + kernel_initializer: kernel_initializer(spec), + name: join(name, "query"), + use_bias: spec.use_attention_bias + ) + |> Layers.split_heads(num_heads) + + key = + key + |> Axon.dense(inner_kv_size, + kernel_initializer: kernel_initializer(spec), + name: join(name, "key"), + use_bias: spec.use_attention_bias + ) + |> Layers.split_heads(num_key_value_heads) + + value = + value + |> Axon.dense(inner_kv_size, + kernel_initializer: kernel_initializer(spec), + name: join(name, "value"), + use_bias: spec.use_attention_bias + ) + |> Layers.split_heads(num_key_value_heads) + + # Apply QK-norm (Gemma 3 specific) - uses (1+weight) formula like other norms + query = + Layers.rms_norm(query, + shift: 1.0, + name: join(name, "q_norm"), + epsilon: spec.layer_norm_epsilon + ) + + key = + Layers.rms_norm(key, + shift: 1.0, + name: join(name, "k_norm"), + epsilon: spec.layer_norm_epsilon + ) + + # Apply rotary embeddings + {query, key} = + Layers.rotary_embedding(query, key, position_ids, attention_mask, attention_head_size, + max_positions: spec.max_positions, + base: spec.rotary_embedding_base, + scaling_strategy: spec.rotary_embedding_scaling_strategy + ) + + # Replicate K/V heads for GQA + num_key_value_groups = div(num_heads, num_key_value_heads) + key = repeat_states(key, num_key_value_groups) + value = repeat_states(value, num_key_value_groups) + + # Update cache + {key, value, attention_cache} = + 
Layers.Decoder.cached_attention_key_values(key, value, attention_cache, offset) + + # Compute attention + # Layers.attention signature: (query, key, value, key_mask, head_mask, bias, offset, opts) + {attention_output, attention_weights} = + Layers.attention( + query, + key, + value, + attention_mask, + attention_head_mask, + Layers.none(), + offset, + scale: true, + causal: true, + window_size: attention_window_size, + dropout_rate: 0.0 + ) + + # Output projection + hidden_state = + attention_output + |> Layers.flatten_trailing() + |> Axon.dense(spec.hidden_size, + kernel_initializer: kernel_initializer(spec), + name: join(name, "output"), + use_bias: spec.use_attention_bias + ) + + {hidden_state, attention_weights, attention_cache} + end + + defp repeat_states(state, 1), do: state + + defp repeat_states(state, times) do + Layers.repeat_interleave(state, times, axis: 2) + end + + defp gated_ffn(hidden_state, intermediate_size, output_size, opts) do + name = opts[:name] + activation = opts[:activation] + + intermediate = + Axon.dense(hidden_state, intermediate_size, + name: join(name, "intermediate"), + use_bias: false + ) + + gate = Axon.dense(hidden_state, intermediate_size, name: join(name, "gate"), use_bias: false) + + hidden_state = Axon.multiply(intermediate, Layers.activation(gate, activation)) + + Axon.dense(hidden_state, output_size, name: join(name, "output"), use_bias: false) + end + + defp language_modeling_head(hidden_state, spec, opts) do + name = opts[:name] + + # TODO: Tie lm-head to word embedding as a spec option + Layers.dense_transposed(hidden_state, spec.vocab_size, + kernel_initializer: kernel_initializer(spec), + name: join(name, "output") + ) + end + + defp kernel_initializer(spec) do + Axon.Initializers.normal(scale: spec.initializer_scale) + end + + defimpl Bumblebee.HuggingFace.Transformers.Config do + def load(spec, data) do + import Shared.Converters + + scaling_strategy_converter = fn name, value -> + case value do + %{"type" => "linear", "factor" => factor} when is_number(factor) -> + {:ok, %{type: :linear, factor: factor}} + + %{"type" => "dynamic", "factor" => factor} when is_number(factor) -> + {:ok, %{type: :dynamic, factor: factor}} + + _other -> + {:error, "invalid format for #{inspect(name)}, got: #{inspect(value)}"} + end + end + + opts = + convert!(data, + vocab_size: {"vocab_size", number()}, + max_positions: {"max_position_embeddings", number()}, + hidden_size: {"hidden_size", number()}, + num_blocks: {"num_hidden_layers", number()}, + num_attention_heads: {"num_attention_heads", number()}, + num_key_value_heads: {"num_key_value_heads", number()}, + attention_head_size: {"head_dim", number()}, + intermediate_size: {"intermediate_size", number()}, + activation: {"hidden_activation", activation()}, + use_attention_bias: {"attention_bias", boolean()}, + rotary_embedding_base: {"rope_theta", number()}, + rotary_embedding_scaling_strategy: + {"rope_scaling", optional(scaling_strategy_converter)}, + initializer_scale: {"initializer_range", number()}, + layer_norm_epsilon: {"rms_norm_eps", number()}, + sliding_window: {"sliding_window", optional(number())}, + tie_word_embeddings: {"tie_word_embeddings", boolean()} + ) ++ Shared.common_options_from_transformers(data, spec) + + @for.config(spec, opts) + end + end + + defimpl Bumblebee.HuggingFace.Transformers.Model do + def params_mapping(spec) do + # Gemma 3 specific params mapping with QK-norm and extra FFN layer norms + %{ + "embedder.token_embedding" => "model.embed_tokens", + # Attention projections + 
"decoder.blocks.{n}.self_attention.query" => "model.layers.{n}.self_attn.q_proj", + "decoder.blocks.{n}.self_attention.key" => "model.layers.{n}.self_attn.k_proj", + "decoder.blocks.{n}.self_attention.value" => "model.layers.{n}.self_attn.v_proj", + "decoder.blocks.{n}.self_attention.output" => "model.layers.{n}.self_attn.o_proj", + # QK-norm (Gemma 3 specific) + "decoder.blocks.{n}.self_attention.q_norm" => "model.layers.{n}.self_attn.q_norm", + "decoder.blocks.{n}.self_attention.k_norm" => "model.layers.{n}.self_attn.k_norm", + # Layer norms + "decoder.blocks.{n}.self_attention_norm" => "model.layers.{n}.input_layernorm", + "decoder.blocks.{n}.post_attention_norm" => "model.layers.{n}.post_attention_layernorm", + # FFN layer norms (Gemma 3 specific) + "decoder.blocks.{n}.pre_ffn_norm" => "model.layers.{n}.pre_feedforward_layernorm", + "decoder.blocks.{n}.post_ffn_norm" => "model.layers.{n}.post_feedforward_layernorm", + # FFN projections + "decoder.blocks.{n}.ffn.gate" => "model.layers.{n}.mlp.gate_proj", + "decoder.blocks.{n}.ffn.intermediate" => "model.layers.{n}.mlp.up_proj", + "decoder.blocks.{n}.ffn.output" => "model.layers.{n}.mlp.down_proj", + # Output + "output_norm" => "model.norm", + "language_modeling_head.output" => + if(spec.tie_word_embeddings, do: "model.embed_tokens", else: "lm_head"), + "sequence_classification_head.output" => "score" + } + end + end +end diff --git a/notebooks/function_calling.livemd b/notebooks/function_calling.livemd new file mode 100644 index 00000000..028c2bea --- /dev/null +++ b/notebooks/function_calling.livemd @@ -0,0 +1,554 @@ +# Function Calling with FunctionGemma + +```elixir +Mix.install([ + {:bumblebee, "~> 0.6.0"}, + {:nx, "~> 0.9.0"}, + {:exla, "~> 0.9.0"}, + {:kino, "~> 0.14.0"} +]) + +Nx.global_default_backend({EXLA.Backend, client: :host}) +``` + +## Why FunctionGemma? + +[FunctionGemma](https://huggingface.co/google/functiongemma-270m-it) is a compact 270M parameter model from Google, specifically designed for function calling tasks. Here's why it's a great choice: + +### Lightweight & Fast + +- **Only 270M parameters** - runs efficiently on CPU or modest GPU hardware +- **Quick inference** - ideal for real-time applications +- **Low memory footprint** - can run alongside other models + +### Perfect for Edge & IoT + +- **Home Assistants** - control smart home devices with natural language +- **Voice Interfaces** - parse voice commands into structured function calls +- **Embedded Systems** - fits on devices with limited resources + +### Easy to Fine-tune + +- **Train on Google Colab T4** - the small size makes fine-tuning accessible +- **Custom function schemas** - adapt to your specific API or tool set +- **Fast iteration** - experiment and deploy quickly + +### Structured Output + +- **Reliable function call format** - consistent `` / `` markers +- **Easy to parse** - extract function name and arguments programmatically +- **Multi-turn support** - handles function responses and follow-up calls + +## Loading the Model + +FunctionGemma requires accepting Google's license on HuggingFace. Visit [google/functiongemma-270m-it](https://huggingface.co/google/functiongemma-270m-it) to request access, then create a [HuggingFace auth token](https://huggingface.co/settings/tokens) and add it as a `HF_TOKEN` Livebook secret. 
+ +```elixir +hf_token = System.fetch_env!("LB_HF_TOKEN") +repo = {:hf, "google/functiongemma-270m-it", auth_token: hf_token} + +{:ok, model_info} = Bumblebee.load_model(repo) +{:ok, tokenizer} = Bumblebee.load_tokenizer(repo) +{:ok, generation_config} = Bumblebee.load_generation_config(repo) + +:ok +``` + +## Creating the Serving + +```elixir +serving = + Bumblebee.Text.generation(model_info, tokenizer, generation_config, + compile: [batch_size: 1, sequence_length: 512], + defn_options: [compiler: EXLA] + ) + +Kino.start_child({Nx.Serving, name: FunctionGemma, serving: serving}) +``` + +## Function Schema Builder + +FunctionGemma uses a specific prompt format. Here's a complete module to build function declarations: + +```elixir +defmodule FunctionGemma.Schema do + @moduledoc """ + Builds FunctionGemma-compatible function declarations. + + ## Example + + FunctionGemma.Schema.declare("get_weather", "Get current weather", [ + location: [type: :string, description: "City name", required: true], + units: [type: :string, description: "celsius or fahrenheit"] + ]) + """ + + @type param_opts :: [ + type: :string | :number | :boolean | :array, + description: String.t(), + required: boolean() + ] + + @doc """ + Declares a function with its name, description, and parameters. + + ## Parameters + + - `name` - The function name (e.g., "get_weather") + - `description` - What the function does + - `parameters` - Keyword list of `{param_name, options}` + + ## Parameter Options + + - `:type` - One of `:string`, `:number`, `:boolean`, `:array` (default: `:string`) + - `:description` - Description of the parameter + - `:required` - Whether the parameter is required (default: `false`) + """ + @spec declare(String.t(), String.t(), keyword(param_opts())) :: String.t() + def declare(name, description, parameters \\ []) do + params_schema = build_parameters_schema(parameters) + + "" <> + "declaration:#{name}{" <> + "description:#{description}," <> + "parameters:#{params_schema}" <> + "}" + end + + @doc """ + Builds a complete prompt with system message, functions, and user query. 
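
  ## Example

  Assuming `function_declarations` is a list built with `declare/3`, as in
  the cells later in this notebook:

      FunctionGemma.Schema.build_prompt(
        "You are a smart home assistant.",
        function_declarations,
        "Turn on the bedroom lights"
      )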
+ """ + @spec build_prompt(String.t(), [String.t()], String.t()) :: String.t() + def build_prompt(system_message, function_declarations, user_message) do + functions = Enum.join(function_declarations, "") + + """ + developer + #{system_message} + #{functions} + user + #{user_message} + model + """ + end + + # Private helpers + + defp build_parameters_schema(parameters) do + properties = build_properties(parameters) + required = build_required(parameters) + + "{properties:{#{properties}},required:[#{required}],type:OBJECT}" + end + + defp build_properties(parameters) do + parameters + |> Enum.map(fn {name, opts} -> + type = opts |> Keyword.get(:type, :string) |> type_to_string() + desc = Keyword.get(opts, :description) + + prop = + if desc do + "#{name}:{description:#{desc},type:#{type}}" + else + "#{name}:{type:#{type}}" + end + + prop + end) + |> Enum.join(",") + end + + defp build_required(parameters) do + parameters + |> Enum.filter(fn {_, opts} -> Keyword.get(opts, :required, false) end) + |> Enum.map(fn {name, _} -> "#{name}" end) + |> Enum.join(",") + end + + defp type_to_string(:string), do: "STRING" + defp type_to_string(:number), do: "NUMBER" + defp type_to_string(:boolean), do: "BOOLEAN" + defp type_to_string(:array), do: "ARRAY" + defp type_to_string(other), do: String.upcase(to_string(other)) +end +``` + +## Function Call Parser + +Parse the model's function call output into structured data: + +```elixir +defmodule FunctionGemma.Parser do + @moduledoc """ + Parses FunctionGemma function call responses. + """ + + @type function_call :: %{ + function: String.t(), + arguments: map() + } + + @doc """ + Parses a FunctionGemma response into a function call struct. + + ## Examples + + iex> parse("call:get_weather{location:Paris}") + {:ok, %{function: "get_weather", arguments: %{"location" => "Paris"}}} + + iex> parse("I don't know") + {:error, :no_function_call} + """ + @spec parse(String.t()) :: {:ok, function_call()} | {:error, atom()} + def parse(response) do + pattern = ~r/call:(\w+)\{(.*?)\}/ + + case Regex.run(pattern, response) do + [_, function_name, args_str] -> + arguments = parse_arguments(args_str) + {:ok, %{function: function_name, arguments: arguments}} + + nil -> + {:error, :no_function_call} + end + end + + @doc """ + Same as `parse/1` but raises on error. + """ + @spec parse!(String.t()) :: function_call() + def parse!(response) do + case parse(response) do + {:ok, result} -> result + {:error, reason} -> raise "Failed to parse function call: #{reason}" + end + end + + # Parse key:value pairs + defp parse_arguments(""), do: %{} + + defp parse_arguments(args_str) do + ~r/(\w+):([^<]*)/ + |> Regex.scan(args_str) + |> Enum.map(fn [_, key, value] -> {key, value} end) + |> Map.new() + end +end +``` + +## Mock Functions (Smart Home Example) + +Let's create actual mock functions that simulate a smart home system: + +```elixir +defmodule SmartHome do + @moduledoc """ + Mock smart home functions that FunctionGemma can call. + """ + + # Simulated device states + use Agent + + def start_link do + Agent.start_link( + fn -> + %{ + lights: %{ + "living room" => false, + "bedroom" => false, + "kitchen" => false + }, + thermostat: 20, + weather_cache: %{} + } + end, + name: __MODULE__ + ) + end + + @doc """ + Controls a light in a specific room. 
+ + ## Parameters + - room: The room name (living room, bedroom, kitchen) + - action: "on" or "off" + """ + def control_light(%{"room" => room, "action" => action}) do + room = String.downcase(room) + state = action == "on" + + Agent.update(__MODULE__, fn data -> + put_in(data, [:lights, room], state) + end) + + current = Agent.get(__MODULE__, & &1.lights) + + %{ + success: true, + message: "Turned #{action} the #{room} light", + current_states: current + } + end + + def control_light(_), do: %{success: false, message: "Missing room or action parameter"} + + @doc """ + Gets the current weather for a location (mocked with random data). + + ## Parameters + - location: The city name + """ + def get_weather(%{"location" => location}) do + # Simulate weather data + conditions = ["sunny", "cloudy", "rainy", "partly cloudy", "windy"] + temp = Enum.random(15..30) + humidity = Enum.random(40..80) + condition = Enum.random(conditions) + + %{ + success: true, + location: location, + temperature: temp, + humidity: humidity, + condition: condition, + message: "Weather in #{location}: #{temp}C, #{condition}, #{humidity}% humidity" + } + end + + def get_weather(_), do: %{success: false, message: "Missing location parameter"} + + @doc """ + Sets the thermostat temperature. + + ## Parameters + - temperature: Temperature in Celsius (number as string) + """ + def set_thermostat(%{"temperature" => temp_str}) do + temp = + case Integer.parse(temp_str) do + {t, _} -> t + :error -> 20 + end + + Agent.update(__MODULE__, fn data -> + Map.put(data, :thermostat, temp) + end) + + %{ + success: true, + message: "Thermostat set to #{temp}C", + temperature: temp + } + end + + def set_thermostat(_), do: %{success: false, message: "Missing temperature parameter"} + + @doc """ + Returns current state of all devices. + """ + def get_status do + Agent.get(__MODULE__, & &1) + end +end + +# Start the mock smart home +SmartHome.start_link() +IO.puts("Smart Home system initialized!") +IO.inspect(SmartHome.get_status(), label: "Initial state") +``` + +## Function Executor + +Now let's create an executor that connects FunctionGemma to our mock functions: + +```elixir +defmodule FunctionGemma.Executor do + @moduledoc """ + Executes function calls from FunctionGemma using registered handlers. + """ + + @doc """ + Executes a parsed function call against registered handlers. + """ + def execute(%{function: function, arguments: args}, handlers) do + case Map.get(handlers, function) do + nil -> + {:error, "Unknown function: #{function}"} + + handler when is_function(handler, 1) -> + result = handler.(args) + {:ok, result} + end + end + + @doc """ + Complete pipeline: send prompt to model, parse response, execute function. 
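
  ## Example

  Assuming a serving started under the name `FunctionGemma` and the
  `function_handlers` map defined later in this notebook:

      {:ok, call, result} =
        FunctionGemma.Executor.run(FunctionGemma, prompt, function_handlers)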
+ """ + def run(serving_name, prompt, handlers) do + # Get model response + %{results: [%{text: response}]} = Nx.Serving.batched_run(serving_name, prompt) + + IO.puts("Model response: #{response}") + + # Parse function call + case FunctionGemma.Parser.parse(response) do + {:ok, function_call} -> + IO.puts("Parsed: #{function_call.function}(#{inspect(function_call.arguments)})") + + # Execute function + case execute(function_call, handlers) do + {:ok, result} -> + {:ok, function_call, result} + + {:error, reason} -> + {:error, reason} + end + + {:error, reason} -> + {:error, reason} + end + end +end +``` + +## Putting It All Together + +Let's define our function schema and handlers, then run the complete pipeline: + +```elixir +# Define function declarations for the model +function_declarations = [ + FunctionGemma.Schema.declare( + "control_light", + "Turn a light on or off in a specific room", + room: [type: :string, description: "The room name (living room, bedroom, kitchen)", required: true], + action: [type: :string, description: "on or off", required: true] + ), + FunctionGemma.Schema.declare( + "get_weather", + "Get the current weather for a location", + location: [type: :string, description: "The city name", required: true] + ), + FunctionGemma.Schema.declare( + "set_thermostat", + "Set the thermostat temperature", + temperature: [type: :number, description: "Temperature in Celsius", required: true] + ) +] + +# Map function names to their implementations +function_handlers = %{ + "control_light" => &SmartHome.control_light/1, + "get_weather" => &SmartHome.get_weather/1, + "set_thermostat" => &SmartHome.set_thermostat/1 +} + +IO.puts("Registered #{length(function_declarations)} functions") +:ok +``` + +## Interactive Demo + +Try sending commands to the smart home assistant: + +```elixir +user_input = Kino.Input.textarea("Command", + default: "Turn on the lights in the living room" +) +``` + +```elixir +user_message = Kino.Input.read(user_input) + +prompt = + FunctionGemma.Schema.build_prompt( + "You are a smart home assistant that controls devices and provides information.", + function_declarations, + user_message + ) + +IO.puts("=== Sending to FunctionGemma ===") +IO.puts("User: #{user_message}\n") + +case FunctionGemma.Executor.run(FunctionGemma, prompt, function_handlers) do + {:ok, function_call, result} -> + IO.puts("\n=== Function Executed ===") + IO.puts("Function: #{function_call.function}") + IO.puts("Arguments: #{inspect(function_call.arguments)}") + IO.puts("\n=== Result ===") + IO.inspect(result, pretty: true) + + {:error, reason} -> + IO.puts("Error: #{inspect(reason)}") +end +``` + +## Batch Demo - Multiple Commands + +Watch the smart home respond to multiple commands: + +```elixir +commands = [ + "What's the weather in Tokyo?", + "Turn on the bedroom lights", + "Set the temperature to 22 degrees", + "Turn off the kitchen light" +] + +Kino.Shorts.data_table( + for command <- commands do + prompt = + FunctionGemma.Schema.build_prompt( + "You are a smart home assistant.", + function_declarations, + command + ) + + result = + case FunctionGemma.Executor.run(FunctionGemma, prompt, function_handlers) do + {:ok, fc, res} -> + %{ + command: command, + function: fc.function, + args: inspect(fc.arguments), + result: res.message + } + + {:error, reason} -> + %{command: command, function: "ERROR", args: "", result: inspect(reason)} + end + + IO.puts("---") + result + end +) +``` + +## Check Final Smart Home State + +```elixir +IO.puts("=== Final Smart Home State ===") 
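# Dump the Agent-held device state accumulated by the commands above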
+SmartHome.get_status() |> IO.inspect(pretty: true) +``` + +## Use Cases + +Here are some practical applications for FunctionGemma: + +| Use Case | Example Functions | +|----------|-------------------| +| **Smart Home** | control_light, set_thermostat, lock_door, play_music | +| **Calendar** | create_event, list_events, reschedule_meeting | +| **E-commerce** | search_products, add_to_cart, check_order_status | +| **Database** | query_users, update_record, generate_report | +| **DevOps** | deploy_service, check_status, scale_instances | + +## Next Steps + +- **Fine-tune** on your specific function schemas for better accuracy +- **Add function responses** for multi-turn conversations +- **Integrate** with your actual APIs and services +- **Deploy** as a Phoenix LiveView application + +For fine-tuning, check out [Google's FunctionGemma documentation](https://huggingface.co/google/functiongemma-270m-it) and the Bumblebee fine-tuning notebook. diff --git a/test/bumblebee/text/gemma3_test.exs b/test/bumblebee/text/gemma3_test.exs new file mode 100644 index 00000000..210f1a63 --- /dev/null +++ b/test/bumblebee/text/gemma3_test.exs @@ -0,0 +1,55 @@ +defmodule Bumblebee.Text.Gemma3Test do + use ExUnit.Case, async: true + + import Bumblebee.TestHelpers + + @moduletag model_test_tags() + + test ":base" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model({:hf, "nmaroulis/tiny-random-Gemma3Model"}) + + assert %Bumblebee.Text.Gemma3{architecture: :base} = spec + + inputs = %{ + "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), + "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) + } + + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.hidden_state) == {1, 10, 32} + end + + test ":for_sequence_classification" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model({:hf, "nmaroulis/tiny-random-Gemma3ForSequenceClassification"}) + + assert %Bumblebee.Text.Gemma3{architecture: :for_sequence_classification} = spec + + inputs = %{ + "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), + "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) + } + + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.logits) == {1, 2} + end + + test ":for_causal_language_modeling" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model({:hf, "nmaroulis/tiny-random-Gemma3ForCausalLM"}) + + assert %Bumblebee.Text.Gemma3{architecture: :for_causal_language_modeling} = spec + + inputs = %{ + "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), + "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) + } + + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.logits) == {1, 10, 1024} + end +end From 2b0df16e2ce4efa3fa679467ab7b767679d8e793 Mon Sep 17 00:00:00 2001 From: Niko Maroulis Date: Mon, 29 Dec 2025 13:46:49 -0500 Subject: [PATCH 02/11] Refactor Gemma 3 to use Layers.Transformer.blocks and add reference tests - Refactor decoder to use shared Layers.Transformer.blocks infrastructure - Use per-layer attention_window_size function for alternating local/global attention - Use query_norm/key_norm options for QK-normalization - Use custom block_type function for Gemma 3's unique normalization structure - Add assert_all_close with reference values from Python transformers - Fix bug in Layers.Transformer.blocks where attention_window_size was duplicated when using a function for per-layer 
configuration - Update params_mapping to use query_norm/key_norm naming from shared infrastructure --- lib/bumblebee/layers/transformer.ex | 3 +- lib/bumblebee/text/gemma3.ex | 292 +++++++--------------------- test/bumblebee/text/gemma3_test.exs | 19 ++ 3 files changed, 96 insertions(+), 218 deletions(-) diff --git a/lib/bumblebee/layers/transformer.ex b/lib/bumblebee/layers/transformer.ex index 97eded2e..86c883fd 100644 --- a/lib/bumblebee/layers/transformer.ex +++ b/lib/bumblebee/layers/transformer.ex @@ -43,6 +43,8 @@ defmodule Bumblebee.Layers.Transformer do def blocks(hidden_state, opts) do validate_required_keys!(opts, [:num_blocks, :num_attention_heads, :hidden_size, :ffn]) + # Note: :attention_window_size is NOT in block_opts_keys because it's handled + # specially (supports per-layer function) and passed explicitly to block/2 block_opts_keys = [ :num_attention_heads, :num_key_value_heads, @@ -59,7 +61,6 @@ defmodule Bumblebee.Layers.Transformer do :output_use_bias, :layer_norm, :block_type, - :attention_window_size, :scale_attention_weights, :query_norm, :key_norm diff --git a/lib/bumblebee/text/gemma3.ex b/lib/bumblebee/text/gemma3.ex index 1a161d0e..94331e3e 100644 --- a/lib/bumblebee/text/gemma3.ex +++ b/lib/bumblebee/text/gemma3.ex @@ -362,103 +362,79 @@ defmodule Bumblebee.Text.Gemma3 do ) do name = opts[:name] - # Use cached attention mask - {attention_mask, cache} = Layers.Decoder.cached_attention_mask(attention_mask, cache) - offset = Layers.Decoder.get_cache_offset(cache) - - state = %{ - hidden_state: hidden_state, - hidden_states: Axon.container({hidden_state}), - attentions: Axon.container({}), - cache: cache - } - - outputs = - for idx <- 0..(spec.num_blocks - 1), reduce: state do - state -> - block_attention_head_mask = Axon.nx(attention_head_mask, & &1[idx]) - block_cache = Layers.Decoder.get_block_cache(state.cache, idx) - - # Gemma 3 alternates between local (sliding window) and global attention - # Every global_attention_layer_interval-th layer uses global attention - attention_window_size = - if rem(idx + 1, spec.global_attention_layer_interval) == 0 do - # Global attention (no window) - nil - else - # Local attention with sliding window - {spec.sliding_window, spec.sliding_window} - end - - {hidden_state, attention, block_cache} = - gemma3_block(state.hidden_state, - attention_mask: attention_mask, - attention_head_mask: block_attention_head_mask, - block_cache: block_cache, - offset: offset, - position_ids: position_ids, - attention_window_size: attention_window_size, - spec: spec, - name: join(name, "blocks.#{idx}") - ) - - cache = Layers.Decoder.put_block_cache(state.cache, idx, block_cache) - - %{ - hidden_state: hidden_state, - hidden_states: Layers.append(state.hidden_states, hidden_state), - attentions: Layers.append(state.attentions, attention), - cache: cache - } + # QK-norm functions for Gemma 3 (uses shift: 1.0 for (1+weight) formula) + query_norm = &Layers.rms_norm(&1, shift: 1.0, epsilon: spec.layer_norm_epsilon, name: &2) + key_norm = &Layers.rms_norm(&1, shift: 1.0, epsilon: spec.layer_norm_epsilon, name: &2) + + # Per-layer attention window size for alternating local/global attention + # Every global_attention_layer_interval-th layer uses global attention + attention_window_size = fn idx -> + if rem(idx + 1, spec.global_attention_layer_interval) == 0 do + nil + else + {spec.sliding_window, spec.sliding_window} end + end - outputs = update_in(outputs.cache, &Layers.Decoder.update_cache_offset(&1, hidden_state)) + # Custom block_type function for 
Gemma 3's unique block structure + block_type = fn hidden_state, steps, block_name -> + gemma3_block_impl(hidden_state, steps, block_name, spec) + end - %{ - hidden_state: outputs.hidden_state, - hidden_states: outputs.hidden_states, - attentions: outputs.attentions, - cache: outputs.cache - } + Layers.Transformer.blocks(hidden_state, + attention_mask: attention_mask, + attention_head_mask: attention_head_mask, + cache: cache, + num_blocks: spec.num_blocks, + num_attention_heads: spec.num_attention_heads, + num_key_value_heads: spec.num_key_value_heads, + hidden_size: spec.hidden_size, + attention_head_size: spec.attention_head_size, + kernel_initializer: kernel_initializer(spec), + layer_norm: + &Layers.rms_norm(&1, + shift: 1.0, + name: &2, + epsilon: spec.layer_norm_epsilon, + upcast: :all + ), + ffn: + &gated_ffn(&1, spec.intermediate_size, spec.hidden_size, + name: &2, + activation: spec.activation + ), + block_type: block_type, + causal: true, + rotary_embedding: [ + position_ids: position_ids, + max_positions: spec.max_positions, + base: spec.rotary_embedding_base, + scaling_strategy: spec.rotary_embedding_scaling_strategy + ], + attention_window_size: attention_window_size, + query_norm: query_norm, + key_norm: key_norm, + query_use_bias: spec.use_attention_bias, + key_use_bias: spec.use_attention_bias, + value_use_bias: spec.use_attention_bias, + output_use_bias: spec.use_attention_bias, + name: join(name, "blocks") + ) end - defp gemma3_block(hidden_state, opts) do - attention_mask = opts[:attention_mask] - attention_head_mask = opts[:attention_head_mask] - block_cache = opts[:block_cache] - offset = opts[:offset] - position_ids = opts[:position_ids] - attention_window_size = opts[:attention_window_size] - spec = opts[:spec] - name = opts[:name] - - {self_attention_cache, cross_attention_cache} = - Layers.Decoder.get_attention_caches(block_cache) - - # Self-attention with pre-norm (input_layernorm) + # Custom block implementation for Gemma 3's unique normalization structure: + # - Post-attention norm BEFORE residual add + # - Pre/post FFN norms + defp gemma3_block_impl(hidden_state, steps, name, spec) do + # Pre-attention norm + attention (using provided steps) shortcut = hidden_state - hidden_state = - Layers.rms_norm(hidden_state, - shift: 1.0, - name: join(name, "self_attention_norm"), - epsilon: spec.layer_norm_epsilon, - upcast: :all - ) - - {hidden_state, attention, self_attention_cache} = - gemma3_attention(hidden_state, hidden_state, hidden_state, - attention_mask: attention_mask, - attention_head_mask: attention_head_mask, - attention_cache: self_attention_cache, - offset: offset, - position_ids: position_ids, - attention_window_size: attention_window_size, - spec: spec, - name: join(name, "self_attention") - ) + {hidden_state, attention_info} = + hidden_state + |> steps.self_attention_norm.() + |> steps.self_attention.() - # Post-attention norm BEFORE residual add (Gemma 3 specific) + # Post-attention norm BEFORE residual (Gemma 3 specific) hidden_state = Layers.rms_norm(hidden_state, shift: 1.0, @@ -467,7 +443,6 @@ defmodule Bumblebee.Text.Gemma3 do upcast: :all ) - # Residual add AFTER post_attention_norm hidden_state = Axon.add(shortcut, hidden_state) # FFN with pre/post norms (Gemma 3 specific) @@ -481,11 +456,7 @@ defmodule Bumblebee.Text.Gemma3 do upcast: :all ) - hidden_state = - gated_ffn(hidden_state, spec.intermediate_size, spec.hidden_size, - name: join(name, "ffn"), - activation: spec.activation - ) + hidden_state = steps.ffn.(hidden_state) hidden_state = 
Layers.rms_norm(hidden_state, @@ -497,126 +468,13 @@ defmodule Bumblebee.Text.Gemma3 do hidden_state = Axon.add(shortcut, hidden_state) - block_cache = - Layers.Decoder.put_attention_caches( - block_cache, - self_attention_cache, - cross_attention_cache - ) - - {hidden_state, attention, block_cache} - end - - defp gemma3_attention(query, key, value, opts) do - attention_mask = opts[:attention_mask] - attention_head_mask = opts[:attention_head_mask] - attention_cache = opts[:attention_cache] - offset = opts[:offset] - position_ids = opts[:position_ids] - attention_window_size = opts[:attention_window_size] - spec = opts[:spec] - name = opts[:name] - - num_heads = spec.num_attention_heads - num_key_value_heads = spec.num_key_value_heads - attention_head_size = spec.attention_head_size - inner_size = num_heads * attention_head_size - inner_kv_size = num_key_value_heads * attention_head_size - - # Project Q, K, V - query = - query - |> Axon.dense(inner_size, - kernel_initializer: kernel_initializer(spec), - name: join(name, "query"), - use_bias: spec.use_attention_bias - ) - |> Layers.split_heads(num_heads) - - key = - key - |> Axon.dense(inner_kv_size, - kernel_initializer: kernel_initializer(spec), - name: join(name, "key"), - use_bias: spec.use_attention_bias - ) - |> Layers.split_heads(num_key_value_heads) - - value = - value - |> Axon.dense(inner_kv_size, - kernel_initializer: kernel_initializer(spec), - name: join(name, "value"), - use_bias: spec.use_attention_bias - ) - |> Layers.split_heads(num_key_value_heads) - - # Apply QK-norm (Gemma 3 specific) - uses (1+weight) formula like other norms - query = - Layers.rms_norm(query, - shift: 1.0, - name: join(name, "q_norm"), - epsilon: spec.layer_norm_epsilon - ) - - key = - Layers.rms_norm(key, - shift: 1.0, - name: join(name, "k_norm"), - epsilon: spec.layer_norm_epsilon - ) - - # Apply rotary embeddings - {query, key} = - Layers.rotary_embedding(query, key, position_ids, attention_mask, attention_head_size, - max_positions: spec.max_positions, - base: spec.rotary_embedding_base, - scaling_strategy: spec.rotary_embedding_scaling_strategy - ) - - # Replicate K/V heads for GQA - num_key_value_groups = div(num_heads, num_key_value_heads) - key = repeat_states(key, num_key_value_groups) - value = repeat_states(value, num_key_value_groups) - - # Update cache - {key, value, attention_cache} = - Layers.Decoder.cached_attention_key_values(key, value, attention_cache, offset) - - # Compute attention - # Layers.attention signature: (query, key, value, key_mask, head_mask, bias, offset, opts) - {attention_output, attention_weights} = - Layers.attention( - query, - key, - value, - attention_mask, - attention_head_mask, - Layers.none(), - offset, - scale: true, - causal: true, - window_size: attention_window_size, - dropout_rate: 0.0 - ) - - # Output projection - hidden_state = - attention_output - |> Layers.flatten_trailing() - |> Axon.dense(spec.hidden_size, - kernel_initializer: kernel_initializer(spec), - name: join(name, "output"), - use_bias: spec.use_attention_bias - ) - - {hidden_state, attention_weights, attention_cache} - end - - defp repeat_states(state, 1), do: state + # Handle cross-attention (required by block interface but not used by Gemma 3) + {_hidden_state, cross_attention_info} = + steps.cross_attention_maybe.(hidden_state, fn _ -> + raise "cross attention not supported" + end) - defp repeat_states(state, times) do - Layers.repeat_interleave(state, times, axis: 2) + {hidden_state, attention_info, cross_attention_info} end defp 
gated_ffn(hidden_state, intermediate_size, output_size, opts) do @@ -702,9 +560,9 @@ defmodule Bumblebee.Text.Gemma3 do "decoder.blocks.{n}.self_attention.key" => "model.layers.{n}.self_attn.k_proj", "decoder.blocks.{n}.self_attention.value" => "model.layers.{n}.self_attn.v_proj", "decoder.blocks.{n}.self_attention.output" => "model.layers.{n}.self_attn.o_proj", - # QK-norm (Gemma 3 specific) - "decoder.blocks.{n}.self_attention.q_norm" => "model.layers.{n}.self_attn.q_norm", - "decoder.blocks.{n}.self_attention.k_norm" => "model.layers.{n}.self_attn.k_norm", + # QK-norm (Gemma 3 specific) - uses query_norm/key_norm from shared infrastructure + "decoder.blocks.{n}.self_attention.query_norm" => "model.layers.{n}.self_attn.q_norm", + "decoder.blocks.{n}.self_attention.key_norm" => "model.layers.{n}.self_attn.k_norm", # Layer norms "decoder.blocks.{n}.self_attention_norm" => "model.layers.{n}.input_layernorm", "decoder.blocks.{n}.post_attention_norm" => "model.layers.{n}.post_attention_layernorm", diff --git a/test/bumblebee/text/gemma3_test.exs b/test/bumblebee/text/gemma3_test.exs index 210f1a63..6d236017 100644 --- a/test/bumblebee/text/gemma3_test.exs +++ b/test/bumblebee/text/gemma3_test.exs @@ -19,6 +19,13 @@ defmodule Bumblebee.Text.Gemma3Test do outputs = Axon.predict(model, params, inputs) assert Nx.shape(outputs.hidden_state) == {1, 10, 32} + + assert_all_close( + outputs.hidden_state[[.., 1..3, 1..3]], + Nx.tensor([ + [[-1.6458, 0.7249, -0.5747], [-1.9452, -0.1602, -0.2329], [-2.3408, -0.4665, -0.1177]] + ]) + ) end test ":for_sequence_classification" do @@ -35,6 +42,11 @@ defmodule Bumblebee.Text.Gemma3Test do outputs = Axon.predict(model, params, inputs) assert Nx.shape(outputs.logits) == {1, 2} + + assert_all_close( + outputs.logits, + Nx.tensor([[-0.0060, -0.0212]]) + ) end test ":for_causal_language_modeling" do @@ -51,5 +63,12 @@ defmodule Bumblebee.Text.Gemma3Test do outputs = Axon.predict(model, params, inputs) assert Nx.shape(outputs.logits) == {1, 10, 1024} + + assert_all_close( + outputs.logits[[.., 1..3, 1..3]], + Nx.tensor([ + [[0.1472, 0.0633, 0.0922], [-0.1089, -0.0344, 0.0755], [0.0112, 0.1083, 0.1461]] + ]) + ) end end From f7c16c6e2fe77a1e9c25ce489fb8c645a32e136d Mon Sep 17 00:00:00 2001 From: Niko Maroulis Date: Mon, 29 Dec 2025 14:03:49 -0500 Subject: [PATCH 03/11] Add fine-tuning resources to FunctionGemma notebook --- notebooks/function_calling.livemd | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/notebooks/function_calling.livemd b/notebooks/function_calling.livemd index 028c2bea..e95b38e8 100644 --- a/notebooks/function_calling.livemd +++ b/notebooks/function_calling.livemd @@ -551,4 +551,9 @@ Here are some practical applications for FunctionGemma: - **Integrate** with your actual APIs and services - **Deploy** as a Phoenix LiveView application -For fine-tuning, check out [Google's FunctionGemma documentation](https://huggingface.co/google/functiongemma-270m-it) and the Bumblebee fine-tuning notebook. +## Fine-tuning FunctionGemma + +Want to fine-tune FunctionGemma on your own function schemas? 
Check out these resources: + +- [FunctionGemma Fine-tuning Notebook (Colab)](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/FunctionGemma_(270M)-Mobile-Actions.ipynb) - Step-by-step guide using Unsloth for efficient fine-tuning on Google Colab T4 +- [Google's FunctionGemma documentation](https://huggingface.co/google/functiongemma-270m-it) - Official model card and usage instructions From 9407bf4d16e39d521d479b9e01ba7478c38852fc Mon Sep 17 00:00:00 2001 From: Niko Maroulis Date: Wed, 31 Dec 2025 14:18:28 -0500 Subject: [PATCH 04/11] Rename Gemma3 to Gemma3Text and add attention_scale_base support - Rename Bumblebee.Text.Gemma3 to Bumblebee.Text.Gemma3Text to distinguish text-only model from future multimodal Gemma3 - Add attention_scale_base config option (from query_pre_attn_scalar) - Compute attention scale as attention_scale_base ** -0.5 - Update model mappings to use Gemma3Text* variants - Update tests to use bumblebee-testing models with Python reference values - Fix duplicate attention_window_size key in transformer.ex after merge --- lib/bumblebee.ex | 11 +++---- lib/bumblebee/layers/transformer.ex | 1 - .../text/{gemma3.ex => gemma3_text.ex} | 15 +++++++++- .../{gemma3_test.exs => gemma3_text_test.exs} | 29 +++++++++++-------- 4 files changed, 35 insertions(+), 21 deletions(-) rename lib/bumblebee/text/{gemma3.ex => gemma3_text.ex} (96%) rename test/bumblebee/text/{gemma3_test.exs => gemma3_text_test.exs} (61%) diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex index bc05a018..5aedd06c 100644 --- a/lib/bumblebee.ex +++ b/lib/bumblebee.ex @@ -137,13 +137,11 @@ defmodule Bumblebee do "GemmaModel" => {Bumblebee.Text.Gemma, :base}, "GemmaForCausalLM" => {Bumblebee.Text.Gemma, :for_causal_language_modeling}, "GemmaForSequenceClassification" => {Bumblebee.Text.Gemma, :for_sequence_classification}, - "Gemma3Model" => {Bumblebee.Text.Gemma3, :base}, - "Gemma3ForCausalLM" => {Bumblebee.Text.Gemma3, :for_causal_language_modeling}, - "Gemma3ForSequenceClassification" => {Bumblebee.Text.Gemma3, :for_sequence_classification}, - "Gemma3TextModel" => {Bumblebee.Text.Gemma3, :base}, - "Gemma3TextForCausalLM" => {Bumblebee.Text.Gemma3, :for_causal_language_modeling}, + "Gemma3ForCausalLM" => {Bumblebee.Text.Gemma3Text, :for_causal_language_modeling}, + "Gemma3TextModel" => {Bumblebee.Text.Gemma3Text, :base}, + "Gemma3TextForCausalLM" => {Bumblebee.Text.Gemma3Text, :for_causal_language_modeling}, "Gemma3TextForSequenceClassification" => - {Bumblebee.Text.Gemma3, :for_sequence_classification}, + {Bumblebee.Text.Gemma3Text, :for_sequence_classification}, "GPT2ForSequenceClassification" => {Bumblebee.Text.Gpt2, :for_sequence_classification}, "GPT2ForTokenClassification" => {Bumblebee.Text.Gpt2, :for_token_classification}, "GPT2LMHeadModel" => {Bumblebee.Text.Gpt2, :for_causal_language_modeling}, @@ -259,7 +257,6 @@ defmodule Bumblebee do "camembert" => :camembert, "clip" => :clip, "gemma" => :gemma, - "gemma3" => :gemma, "gemma3_text" => :gemma, "gpt_neox" => :gpt_neo_x, "gpt2" => :gpt2, diff --git a/lib/bumblebee/layers/transformer.ex b/lib/bumblebee/layers/transformer.ex index ad9eb9be..188b0ffe 100644 --- a/lib/bumblebee/layers/transformer.ex +++ b/lib/bumblebee/layers/transformer.ex @@ -61,7 +61,6 @@ defmodule Bumblebee.Layers.Transformer do :output_use_bias, :layer_norm, :block_type, - :attention_window_size, :attention_scale, :query_norm, :key_norm diff --git a/lib/bumblebee/text/gemma3.ex b/lib/bumblebee/text/gemma3_text.ex similarity index 96% rename from 
lib/bumblebee/text/gemma3.ex rename to lib/bumblebee/text/gemma3_text.ex index 94331e3e..bc8d0df8 100644 --- a/lib/bumblebee/text/gemma3.ex +++ b/lib/bumblebee/text/gemma3_text.ex @@ -1,4 +1,4 @@ -defmodule Bumblebee.Text.Gemma3 do +defmodule Bumblebee.Text.Gemma3Text do alias Bumblebee.Shared options = @@ -30,6 +30,13 @@ defmodule Bumblebee.Text.Gemma3 do default: 256, doc: "the size of the key, value, and query projection per attention head" ], + attention_scale_base: [ + default: nil, + doc: """ + base value for computing attention scale. The attention scale is computed as + `attention_scale_base ** -0.5`. When `nil`, defaults to `:attention_head_size` + """ + ], num_blocks: [ default: 26, doc: "the number of Transformer blocks in the model" @@ -381,6 +388,10 @@ defmodule Bumblebee.Text.Gemma3 do gemma3_block_impl(hidden_state, steps, block_name, spec) end + # Compute attention scale from attention_scale_base (defaults to attention_head_size) + attention_scale_base = spec.attention_scale_base || spec.attention_head_size + attention_scale = :math.pow(attention_scale_base, -0.5) + Layers.Transformer.blocks(hidden_state, attention_mask: attention_mask, attention_head_mask: attention_head_mask, @@ -390,6 +401,7 @@ defmodule Bumblebee.Text.Gemma3 do num_key_value_heads: spec.num_key_value_heads, hidden_size: spec.hidden_size, attention_head_size: spec.attention_head_size, + attention_scale: attention_scale, kernel_initializer: kernel_initializer(spec), layer_norm: &Layers.rms_norm(&1, @@ -534,6 +546,7 @@ defmodule Bumblebee.Text.Gemma3 do num_attention_heads: {"num_attention_heads", number()}, num_key_value_heads: {"num_key_value_heads", number()}, attention_head_size: {"head_dim", number()}, + attention_scale_base: {"query_pre_attn_scalar", optional(number())}, intermediate_size: {"intermediate_size", number()}, activation: {"hidden_activation", activation()}, use_attention_bias: {"attention_bias", boolean()}, diff --git a/test/bumblebee/text/gemma3_test.exs b/test/bumblebee/text/gemma3_text_test.exs similarity index 61% rename from test/bumblebee/text/gemma3_test.exs rename to test/bumblebee/text/gemma3_text_test.exs index 6d236017..331c5f47 100644 --- a/test/bumblebee/text/gemma3_test.exs +++ b/test/bumblebee/text/gemma3_text_test.exs @@ -1,4 +1,4 @@ -defmodule Bumblebee.Text.Gemma3Test do +defmodule Bumblebee.Text.Gemma3TextTest do use ExUnit.Case, async: true import Bumblebee.TestHelpers @@ -7,9 +7,9 @@ defmodule Bumblebee.Text.Gemma3Test do test ":base" do assert {:ok, %{model: model, params: params, spec: spec}} = - Bumblebee.load_model({:hf, "nmaroulis/tiny-random-Gemma3Model"}) + Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-Gemma3TextModel"}) - assert %Bumblebee.Text.Gemma3{architecture: :base} = spec + assert %Bumblebee.Text.Gemma3Text{architecture: :base} = spec inputs = %{ "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), @@ -23,16 +23,19 @@ defmodule Bumblebee.Text.Gemma3Test do assert_all_close( outputs.hidden_state[[.., 1..3, 1..3]], Nx.tensor([ - [[-1.6458, 0.7249, -0.5747], [-1.9452, -0.1602, -0.2329], [-2.3408, -0.4665, -0.1177]] - ]) + [[-0.5691, 1.3813, -0.1463], [0.0754, 1.1590, -0.3055], [1.7564, 0.4456, -0.4530]] + ]), + atol: 0.2 ) end test ":for_sequence_classification" do assert {:ok, %{model: model, params: params, spec: spec}} = - Bumblebee.load_model({:hf, "nmaroulis/tiny-random-Gemma3ForSequenceClassification"}) + Bumblebee.load_model( + {:hf, "bumblebee-testing/tiny-random-Gemma3TextForSequenceClassification"} + ) - assert 
%Bumblebee.Text.Gemma3{architecture: :for_sequence_classification} = spec + assert %Bumblebee.Text.Gemma3Text{architecture: :for_sequence_classification} = spec inputs = %{ "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), @@ -45,15 +48,16 @@ defmodule Bumblebee.Text.Gemma3Test do assert_all_close( outputs.logits, - Nx.tensor([[-0.0060, -0.0212]]) + Nx.tensor([[0.0366, -0.0045]]), + atol: 0.1 ) end test ":for_causal_language_modeling" do assert {:ok, %{model: model, params: params, spec: spec}} = - Bumblebee.load_model({:hf, "nmaroulis/tiny-random-Gemma3ForCausalLM"}) + Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-Gemma3ForCausalLM"}) - assert %Bumblebee.Text.Gemma3{architecture: :for_causal_language_modeling} = spec + assert %Bumblebee.Text.Gemma3Text{architecture: :for_causal_language_modeling} = spec inputs = %{ "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), @@ -67,8 +71,9 @@ defmodule Bumblebee.Text.Gemma3Test do assert_all_close( outputs.logits[[.., 1..3, 1..3]], Nx.tensor([ - [[0.1472, 0.0633, 0.0922], [-0.1089, -0.0344, 0.0755], [0.0112, 0.1083, 0.1461]] - ]) + [[0.0114, -0.0579, -0.1748], [-0.0151, -0.1486, -0.1722], [-0.1478, -0.0452, -0.1211]] + ]), + atol: 0.02 ) end end From d4d3bffe880c3d8d6323c6e7d6ef99ce1d50f883 Mon Sep 17 00:00:00 2001 From: Niko Maroulis Date: Wed, 31 Dec 2025 14:33:40 -0500 Subject: [PATCH 05/11] Update lib/bumblebee/text/gemma3_text.ex MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Jonatan Kłosko --- lib/bumblebee/text/gemma3_text.ex | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/lib/bumblebee/text/gemma3_text.ex b/lib/bumblebee/text/gemma3_text.ex index bc8d0df8..58a8f3df 100644 --- a/lib/bumblebee/text/gemma3_text.ex +++ b/lib/bumblebee/text/gemma3_text.ex @@ -388,9 +388,7 @@ defmodule Bumblebee.Text.Gemma3Text do gemma3_block_impl(hidden_state, steps, block_name, spec) end - # Compute attention scale from attention_scale_base (defaults to attention_head_size) - attention_scale_base = spec.attention_scale_base || spec.attention_head_size - attention_scale = :math.pow(attention_scale_base, -0.5) + attention_scale = :math.pow(spec.attention_scale_base, -0.5) Layers.Transformer.blocks(hidden_state, attention_mask: attention_mask, From 91e74d07e83373757454cb15bfbc91306c724539 Mon Sep 17 00:00:00 2001 From: Niko Maroulis Date: Wed, 31 Dec 2025 14:34:06 -0500 Subject: [PATCH 06/11] Update lib/bumblebee/text/gemma3_text.ex MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Jonatan Kłosko --- lib/bumblebee/text/gemma3_text.ex | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/bumblebee/text/gemma3_text.ex b/lib/bumblebee/text/gemma3_text.ex index 58a8f3df..ab027115 100644 --- a/lib/bumblebee/text/gemma3_text.ex +++ b/lib/bumblebee/text/gemma3_text.ex @@ -413,7 +413,7 @@ defmodule Bumblebee.Text.Gemma3Text do name: &2, activation: spec.activation ), - block_type: block_type, + block_type: &gemma3_block_impl(&1, &2, &3, spec), causal: true, rotary_embedding: [ position_ids: position_ids, From a2bf4a58ed94791a98423e147255a831e8aadfb5 Mon Sep 17 00:00:00 2001 From: Niko Maroulis Date: Wed, 31 Dec 2025 14:34:37 -0500 Subject: [PATCH 07/11] Update lib/bumblebee/text/gemma3_text.ex MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Jonatan Kłosko --- lib/bumblebee/text/gemma3_text.ex | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/lib/bumblebee/text/gemma3_text.ex b/lib/bumblebee/text/gemma3_text.ex index ab027115..be022ef0 100644 --- a/lib/bumblebee/text/gemma3_text.ex +++ b/lib/bumblebee/text/gemma3_text.ex @@ -544,7 +544,7 @@ defmodule Bumblebee.Text.Gemma3Text do num_attention_heads: {"num_attention_heads", number()}, num_key_value_heads: {"num_key_value_heads", number()}, attention_head_size: {"head_dim", number()}, - attention_scale_base: {"query_pre_attn_scalar", optional(number())}, + attention_scale_base: {"query_pre_attn_scalar", number()}, intermediate_size: {"intermediate_size", number()}, activation: {"hidden_activation", activation()}, use_attention_bias: {"attention_bias", boolean()}, From a8ffd0d305601c8e285e6f08279aa7bdb0080e73 Mon Sep 17 00:00:00 2001 From: Niko Maroulis Date: Wed, 31 Dec 2025 14:35:09 -0500 Subject: [PATCH 08/11] Update lib/bumblebee/text/gemma3_text.ex MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Jonatan Kłosko --- lib/bumblebee/text/gemma3_text.ex | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/bumblebee/text/gemma3_text.ex b/lib/bumblebee/text/gemma3_text.ex index be022ef0..c07b96ba 100644 --- a/lib/bumblebee/text/gemma3_text.ex +++ b/lib/bumblebee/text/gemma3_text.ex @@ -31,10 +31,10 @@ defmodule Bumblebee.Text.Gemma3Text do doc: "the size of the key, value, and query projection per attention head" ], attention_scale_base: [ - default: nil, + default: 256, doc: """ base value for computing attention scale. The attention scale is computed as - `attention_scale_base ** -0.5`. When `nil`, defaults to `:attention_head_size` + `attention_scale_base ** -0.5`. """ ], num_blocks: [ From 43564dbc257e11084218009be21c9eec6501d8ef Mon Sep 17 00:00:00 2001 From: Niko Maroulis Date: Thu, 1 Jan 2026 08:38:15 -0500 Subject: [PATCH 09/11] Add layer_types support and update test reference values - Add layer_types config option for per-layer attention type (sliding vs full) - Generate layer_types from sliding_window_pattern for backward compatibility - Update test expected values with Python reference outputs - Add tolerances to tests (base model needs larger tolerance, see TODO) --- lib/bumblebee/text/gemma3_text.ex | 68 ++++++++++++++++++------ test/bumblebee/text/gemma3_text_test.exs | 14 ++--- 2 files changed, 60 insertions(+), 22 deletions(-) diff --git a/lib/bumblebee/text/gemma3_text.ex b/lib/bumblebee/text/gemma3_text.ex index c07b96ba..ff7f0964 100644 --- a/lib/bumblebee/text/gemma3_text.ex +++ b/lib/bumblebee/text/gemma3_text.ex @@ -87,12 +87,12 @@ defmodule Bumblebee.Text.Gemma3Text do default: 4096, doc: "the sliding window size for local attention layers" ], - global_attention_layer_interval: [ - default: 6, + layer_types: [ + default: nil, doc: """ - the interval for global attention layers. In Gemma 3, every Nth layer uses global - attention while others use local (sliding window) attention. A value of 6 means - layers 5, 11, 17, 23... use global attention (5:1 local/global ratio) + a list of layer types for each layer, where each element is either `:sliding_attention` + (local attention with sliding window) or `:full_attention` (global attention). + If not provided, will be computed from `sliding_window_pattern`. 
""" ], tie_word_embeddings: [ @@ -373,21 +373,18 @@ defmodule Bumblebee.Text.Gemma3Text do query_norm = &Layers.rms_norm(&1, shift: 1.0, epsilon: spec.layer_norm_epsilon, name: &2) key_norm = &Layers.rms_norm(&1, shift: 1.0, epsilon: spec.layer_norm_epsilon, name: &2) - # Per-layer attention window size for alternating local/global attention - # Every global_attention_layer_interval-th layer uses global attention + # Per-layer attention window size based on layer_types + # :sliding_attention uses local (sliding window) attention + # :full_attention uses global attention (nil window size) + layer_types = spec.layer_types || generate_layer_types(spec.num_blocks) + attention_window_size = fn idx -> - if rem(idx + 1, spec.global_attention_layer_interval) == 0 do - nil - else - {spec.sliding_window, spec.sliding_window} + case Enum.at(layer_types, idx, :sliding_attention) do + :full_attention -> nil + :sliding_attention -> {spec.sliding_window, spec.sliding_window} end end - # Custom block_type function for Gemma 3's unique block structure - block_type = fn hidden_state, steps, block_name -> - gemma3_block_impl(hidden_state, steps, block_name, spec) - end - attention_scale = :math.pow(spec.attention_scale_base, -0.5) Layers.Transformer.blocks(hidden_state, @@ -518,6 +515,18 @@ defmodule Bumblebee.Text.Gemma3Text do Axon.Initializers.normal(scale: spec.initializer_scale) end + # Generate layer_types from sliding_window_pattern (default 6) + # Pattern: every Nth layer uses full attention, others use sliding attention + defp generate_layer_types(num_blocks, sliding_window_pattern \\ 6) do + Enum.map(0..(num_blocks - 1), fn i -> + if rem(i + 1, sliding_window_pattern) == 0 do + :full_attention + else + :sliding_attention + end + end) + end + defimpl Bumblebee.HuggingFace.Transformers.Config do def load(spec, data) do import Shared.Converters @@ -535,6 +544,32 @@ defmodule Bumblebee.Text.Gemma3Text do end end + # Support sliding_window_pattern for backward compatibility + # see https://github.com/huggingface/transformers/blob/v5.0.0rc1/src/transformers/models/gemma3/configuration_gemma3.py#L188-L195 + data = + Map.put_new_lazy(data, "layer_types", fn -> + pattern = data["sliding_window_pattern"] || 6 + num_blocks = data["num_hidden_layers"] || 26 + + Enum.map(0..(num_blocks - 1), fn i -> + if rem(i + 1, pattern) == 0 do + "full_attention" + else + "sliding_attention" + end + end) + end) + + layer_types_converter = fn _name, value -> + types = + Enum.map(value, fn + "sliding_attention" -> :sliding_attention + "full_attention" -> :full_attention + end) + + {:ok, types} + end + opts = convert!(data, vocab_size: {"vocab_size", number()}, @@ -554,6 +589,7 @@ defmodule Bumblebee.Text.Gemma3Text do initializer_scale: {"initializer_range", number()}, layer_norm_epsilon: {"rms_norm_eps", number()}, sliding_window: {"sliding_window", optional(number())}, + layer_types: {"layer_types", layer_types_converter}, tie_word_embeddings: {"tie_word_embeddings", boolean()} ) ++ Shared.common_options_from_transformers(data, spec) diff --git a/test/bumblebee/text/gemma3_text_test.exs b/test/bumblebee/text/gemma3_text_test.exs index 331c5f47..9cc5f713 100644 --- a/test/bumblebee/text/gemma3_text_test.exs +++ b/test/bumblebee/text/gemma3_text_test.exs @@ -20,12 +20,14 @@ defmodule Bumblebee.Text.Gemma3TextTest do assert Nx.shape(outputs.hidden_state) == {1, 10, 32} + # TODO: Larger tolerance needed for base model - investigate discrepancy + # First position matches well (~0.003), but positions 2-3 diverge more (~0.2-0.3) 
assert_all_close( outputs.hidden_state[[.., 1..3, 1..3]], Nx.tensor([ - [[-0.5691, 1.3813, -0.1463], [0.0754, 1.1590, -0.3055], [1.7564, 0.4456, -0.4530]] + [[-0.2461, 1.2074, 0.7663], [0.0675, 0.3987, 1.6659], [-0.3021, 0.8062, 1.0309]] ]), - atol: 0.2 + atol: 0.35 ) end @@ -48,8 +50,8 @@ defmodule Bumblebee.Text.Gemma3TextTest do assert_all_close( outputs.logits, - Nx.tensor([[0.0366, -0.0045]]), - atol: 0.1 + Nx.tensor([[-0.0145, 0.1376]]), + atol: 0.02 ) end @@ -71,9 +73,9 @@ defmodule Bumblebee.Text.Gemma3TextTest do assert_all_close( outputs.logits[[.., 1..3, 1..3]], Nx.tensor([ - [[0.0114, -0.0579, -0.1748], [-0.0151, -0.1486, -0.1722], [-0.1478, -0.0452, -0.1211]] + [[-0.0488, 0.0432, -0.0531], [-0.1553, -0.0812, 0.1153], [-0.0272, 0.1216, 0.0129]] ]), - atol: 0.02 + atol: 0.025 ) end end From c2802b8b958e1da4d0267d57f9f7fda234fedc1f Mon Sep 17 00:00:00 2001 From: Niko Maroulis Date: Thu, 1 Jan 2026 09:55:23 -0500 Subject: [PATCH 10/11] Add per-layer rotary embedding base for Gemma 3 Gemma 3 uses different RoPE bases for local vs global attention: - Local (sliding) attention: rope_local_base_freq = 10,000 - Global (full) attention: rope_theta = 1,000,000 This fixes numerical discrepancies where positions 2+ diverged by ~0.2-0.3 from Python reference values. Now achieves 4-digit precision (atol: 1.0e-4) across all positions. Changes: - Add rotary_embedding_base_local spec option - Load rope_local_base_freq from HuggingFace config - Use per-layer rotary embedding base based on layer type - Tighten test tolerances from 0.35 to 1.0e-4 --- lib/bumblebee/text/gemma3_text.ex | 31 ++++++++++++++++++------ test/bumblebee/text/gemma3_text_test.exs | 8 +++--- 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/lib/bumblebee/text/gemma3_text.ex b/lib/bumblebee/text/gemma3_text.ex index ff7f0964..a470ed5f 100644 --- a/lib/bumblebee/text/gemma3_text.ex +++ b/lib/bumblebee/text/gemma3_text.ex @@ -54,8 +54,12 @@ defmodule Bumblebee.Text.Gemma3Text do doc: "the activation function" ], rotary_embedding_base: [ + default: 1_000_000, + doc: "base for computing rotary embedding frequency for global attention layers" + ], + rotary_embedding_base_local: [ default: 10_000, - doc: "base for computing rotary embedding frequency" + doc: "base for computing rotary embedding frequency for local (sliding) attention layers" ], rotary_embedding_scaling_strategy: [ default: nil, @@ -385,6 +389,23 @@ defmodule Bumblebee.Text.Gemma3Text do end end + # Per-layer rotary embedding base: local layers use rotary_embedding_base_local, + # global layers use rotary_embedding_base + rotary_embedding = fn idx -> + base = + case Enum.at(layer_types, idx, :sliding_attention) do + :full_attention -> spec.rotary_embedding_base + :sliding_attention -> spec.rotary_embedding_base_local + end + + [ + position_ids: position_ids, + max_positions: spec.max_positions, + base: base, + scaling_strategy: spec.rotary_embedding_scaling_strategy + ] + end + attention_scale = :math.pow(spec.attention_scale_base, -0.5) Layers.Transformer.blocks(hidden_state, @@ -412,12 +433,7 @@ defmodule Bumblebee.Text.Gemma3Text do ), block_type: &gemma3_block_impl(&1, &2, &3, spec), causal: true, - rotary_embedding: [ - position_ids: position_ids, - max_positions: spec.max_positions, - base: spec.rotary_embedding_base, - scaling_strategy: spec.rotary_embedding_scaling_strategy - ], + rotary_embedding: rotary_embedding, attention_window_size: attention_window_size, query_norm: query_norm, key_norm: key_norm, @@ -584,6 +600,7 @@ defmodule 
Bumblebee.Text.Gemma3Text do activation: {"hidden_activation", activation()}, use_attention_bias: {"attention_bias", boolean()}, rotary_embedding_base: {"rope_theta", number()}, + rotary_embedding_base_local: {"rope_local_base_freq", number()}, rotary_embedding_scaling_strategy: {"rope_scaling", optional(scaling_strategy_converter)}, initializer_scale: {"initializer_range", number()}, diff --git a/test/bumblebee/text/gemma3_text_test.exs b/test/bumblebee/text/gemma3_text_test.exs index 9cc5f713..6af3ce6a 100644 --- a/test/bumblebee/text/gemma3_text_test.exs +++ b/test/bumblebee/text/gemma3_text_test.exs @@ -20,14 +20,12 @@ defmodule Bumblebee.Text.Gemma3TextTest do assert Nx.shape(outputs.hidden_state) == {1, 10, 32} - # TODO: Larger tolerance needed for base model - investigate discrepancy - # First position matches well (~0.003), but positions 2-3 diverge more (~0.2-0.3) assert_all_close( outputs.hidden_state[[.., 1..3, 1..3]], Nx.tensor([ [[-0.2461, 1.2074, 0.7663], [0.0675, 0.3987, 1.6659], [-0.3021, 0.8062, 1.0309]] ]), - atol: 0.35 + atol: 1.0e-4 ) end @@ -51,7 +49,7 @@ defmodule Bumblebee.Text.Gemma3TextTest do assert_all_close( outputs.logits, Nx.tensor([[-0.0145, 0.1376]]), - atol: 0.02 + atol: 1.0e-4 ) end @@ -75,7 +73,7 @@ defmodule Bumblebee.Text.Gemma3TextTest do Nx.tensor([ [[-0.0488, 0.0432, -0.0531], [-0.1553, -0.0812, 0.1153], [-0.0272, 0.1216, 0.0129]] ]), - atol: 0.025 + atol: 1.0e-4 ) end end From d38cd628294458e88c7da3df232b4c44d418dbfd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Tue, 6 Jan 2026 00:08:56 +0100 Subject: [PATCH 11/11] up --- lib/bumblebee/text/gemma3_text.ex | 26 +++++++++---------- mix.exs | 1 + notebooks/function_calling.livemd | 42 +++---------------------------- 3 files changed, 16 insertions(+), 53 deletions(-) diff --git a/lib/bumblebee/text/gemma3_text.ex b/lib/bumblebee/text/gemma3_text.ex index a470ed5f..3322ab5f 100644 --- a/lib/bumblebee/text/gemma3_text.ex +++ b/lib/bumblebee/text/gemma3_text.ex @@ -95,8 +95,7 @@ defmodule Bumblebee.Text.Gemma3Text do default: nil, doc: """ a list of layer types for each layer, where each element is either `:sliding_attention` - (local attention with sliding window) or `:full_attention` (global attention). - If not provided, will be computed from `sliding_window_pattern`. 
+ (local attention with sliding window) or `:full_attention` (global attention) """ ], tie_word_embeddings: [ @@ -533,7 +532,9 @@ defmodule Bumblebee.Text.Gemma3Text do # Generate layer_types from sliding_window_pattern (default 6) # Pattern: every Nth layer uses full attention, others use sliding attention - defp generate_layer_types(num_blocks, sliding_window_pattern \\ 6) do + defp generate_layer_types(num_blocks) do + sliding_window_pattern = 6 + Enum.map(0..(num_blocks - 1), fn i -> if rem(i + 1, sliding_window_pattern) == 0 do :full_attention @@ -576,16 +577,6 @@ defmodule Bumblebee.Text.Gemma3Text do end) end) - layer_types_converter = fn _name, value -> - types = - Enum.map(value, fn - "sliding_attention" -> :sliding_attention - "full_attention" -> :full_attention - end) - - {:ok, types} - end - opts = convert!(data, vocab_size: {"vocab_size", number()}, @@ -606,7 +597,14 @@ defmodule Bumblebee.Text.Gemma3Text do initializer_scale: {"initializer_range", number()}, layer_norm_epsilon: {"rms_norm_eps", number()}, sliding_window: {"sliding_window", optional(number())}, - layer_types: {"layer_types", layer_types_converter}, + layer_types: + {"layer_types", + list( + mapping(%{ + "sliding_attention" => :sliding_attention, + "full_attention" => :full_attention + }) + )}, tie_word_embeddings: {"tie_word_embeddings", boolean()} ) ++ Shared.common_options_from_transformers(data, spec) diff --git a/mix.exs b/mix.exs index 44210c13..a768a59f 100644 --- a/mix.exs +++ b/mix.exs @@ -64,6 +64,7 @@ defmodule Bumblebee.MixProject do "notebooks/llms.livemd", "notebooks/llms_rag.livemd", "notebooks/qwen3.livemd", + "notebooks/function_calling.livemd", "notebooks/fine_tuning.livemd", "examples/phoenix/README.md" ], diff --git a/notebooks/function_calling.livemd b/notebooks/function_calling.livemd index e95b38e8..d9752bc9 100644 --- a/notebooks/function_calling.livemd +++ b/notebooks/function_calling.livemd @@ -1,4 +1,4 @@ -# Function Calling with FunctionGemma +# Function calling with FunctionGemma ```elixir Mix.install([ @@ -8,36 +8,12 @@ Mix.install([ {:kino, "~> 0.14.0"} ]) -Nx.global_default_backend({EXLA.Backend, client: :host}) +Nx.global_default_backend(EXLA.Backend) ``` ## Why FunctionGemma? -[FunctionGemma](https://huggingface.co/google/functiongemma-270m-it) is a compact 270M parameter model from Google, specifically designed for function calling tasks. Here's why it's a great choice: - -### Lightweight & Fast - -- **Only 270M parameters** - runs efficiently on CPU or modest GPU hardware -- **Quick inference** - ideal for real-time applications -- **Low memory footprint** - can run alongside other models - -### Perfect for Edge & IoT - -- **Home Assistants** - control smart home devices with natural language -- **Voice Interfaces** - parse voice commands into structured function calls -- **Embedded Systems** - fits on devices with limited resources - -### Easy to Fine-tune - -- **Train on Google Colab T4** - the small size makes fine-tuning accessible -- **Custom function schemas** - adapt to your specific API or tool set -- **Fast iteration** - experiment and deploy quickly - -### Structured Output - -- **Reliable function call format** - consistent `` / `` markers -- **Easy to parse** - extract function name and arguments programmatically -- **Multi-turn support** - handles function responses and follow-up calls +[FunctionGemma](https://huggingface.co/google/functiongemma-270m-it) is a compact 270M parameter model from Google, specifically designed for function calling tasks. 
## Loading the Model @@ -532,18 +508,6 @@ IO.puts("=== Final Smart Home State ===") SmartHome.get_status() |> IO.inspect(pretty: true) ``` -## Use Cases - -Here are some practical applications for FunctionGemma: - -| Use Case | Example Functions | -|----------|-------------------| -| **Smart Home** | control_light, set_thermostat, lock_door, play_music | -| **Calendar** | create_event, list_events, reschedule_meeting | -| **E-commerce** | search_products, add_to_cart, check_order_status | -| **Database** | query_users, update_record, generate_report | -| **DevOps** | deploy_service, check_status, scale_instances | - ## Next Steps - **Fine-tune** on your specific function schemas for better accuracy
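
As a quick end-to-end illustration of the support added in this series, the sketch below loads FunctionGemma through Bumblebee's standard text-generation serving. This is a minimal sketch only: it assumes the `google/functiongemma-270m-it` checkpoint referenced above resolves through the new `Gemma3ForCausalLM` mapping, that the notebook's setup cell (with its `Mix.install` dependencies and EXLA backend) has already run, and it does not apply FunctionGemma's function-calling prompt format, which the notebook itself walks through.

```elixir
# Minimal sketch (illustrative, not part of the patch): load FunctionGemma via
# the Gemma3Text support and run plain text generation. The prompt here is a
# placeholder; real function calling requires the prompt format shown in the
# notebook. Assumes the notebook's Mix.install cell has already been evaluated.
repo = {:hf, "google/functiongemma-270m-it"}

{:ok, model_info} = Bumblebee.load_model(repo)
{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
{:ok, generation_config} = Bumblebee.load_generation_config(repo)

# Keep generation short for a smoke test
generation_config = Bumblebee.configure(generation_config, max_new_tokens: 64)

serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config)

%{results: [%{text: text}]} = Nx.Serving.run(serving, "What's the weather in Paris?")
IO.puts(text)
```

The same serving pipeline also works with the tiny `bumblebee-testing/tiny-random-Gemma3ForCausalLM` checkpoint used in the tests above, which is a convenient way to exercise the loading path without downloading the full model.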