Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions lib/bumblebee.ex
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,8 @@ defmodule Bumblebee do
:params_filename,
:log_params_diff,
:backend,
:type
:type,
:preserve_source_types
])

with {:ok, repo_files} <- get_repo_files(repository),
Expand Down Expand Up @@ -654,7 +655,7 @@ defmodule Bumblebee do
[
params_mapping: params_mapping,
loader_fun: loader_fun
] ++ Keyword.take(opts, [:backend, :log_params_diff])
] ++ Keyword.take(opts, [:backend, :log_params_diff, :preserve_source_types])

params = Bumblebee.Conversion.PyTorchParams.load_params!(model, input_template, paths, opts)
{:ok, params}
Expand Down
39 changes: 31 additions & 8 deletions lib/bumblebee/conversion/pytorch_params.ex
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ defmodule Bumblebee.Conversion.PyTorchParams do
and loads the params file. Defaults to
`Bumblebee.Conversion.PyTorchLoader.load!/1`

* `:preserve_source_types` - when `true`, preserves FP8 types from the
source file instead of converting them to the model's expected type.
This is useful for loading quantized models that use FP8 weights.
Defaults to `false`

"""
@spec load_params!(Axon.t(), map(), Path.t() | list(Path.t()), keyword()) :: %Axon.ModelState{}
def load_params!(model, input_template, path, opts \\ []) do
Expand All @@ -36,6 +41,7 @@ defmodule Bumblebee.Conversion.PyTorchParams do
|> Keyword.validate!([
:log_params_diff,
:backend,
:preserve_source_types,
params_mapping: %{},
loader_fun: &Bumblebee.Conversion.PyTorchLoader.load!/1
])
Expand All @@ -58,7 +64,17 @@ defmodule Bumblebee.Conversion.PyTorchParams do
model_state = Axon.trace_init(model, input_template)

params_expr = model_state.data
{params, diff} = init_params(model, params_expr, pytorch_state, opts[:params_mapping])
preserve_source_types = opts[:preserve_source_types] || false

{params, diff} =
init_params(
model,
params_expr,
pytorch_state,
opts[:params_mapping],
preserve_source_types
)

model_state = %{model_state | data: params}

params_complete? = diff.missing == [] and diff.mismatched == []
Expand Down Expand Up @@ -95,15 +111,20 @@ defmodule Bumblebee.Conversion.PyTorchParams do
Nx.Container.impl_for(value) != nil
end

defp init_params(model, params_expr, pytorch_state, params_mapping) do
defp init_params(model, params_expr, pytorch_state, params_mapping, preserve_source_types) do
layers =
model
|> Utils.Axon.nodes_with_names()
|> Enum.filter(fn {layer, _name} -> layer.parameters != [] end)

prefixes = infer_prefixes(layers, pytorch_state, params_mapping)

diff = %{missing: [], mismatched: [], used_keys: []}
diff = %{
missing: [],
mismatched: [],
used_keys: [],
preserve_source_types: preserve_source_types
}

{params, diff} =
layers
Expand Down Expand Up @@ -155,7 +176,7 @@ defmodule Bumblebee.Conversion.PyTorchParams do

case verify_param_shape(param_expr, value) do
:ok ->
value = ensure_type(param_expr, value)
value = ensure_type(param_expr, value, diff.preserve_source_types)
{value, diff}

{:error, expected, actual} ->
Expand Down Expand Up @@ -507,11 +528,13 @@ defmodule Bumblebee.Conversion.PyTorchParams do
Utils.Nx.map(expr, &Nx.shape/1)
end

defp ensure_type(param_expr, value) do
defp ensure_type(param_expr, value, preserve_source_types) do
Utils.Nx.zip_with(param_expr, value, fn expr, tensor ->
case {Nx.type(expr), Nx.type(tensor)} do
{type, type} -> tensor
{expected, _actual} -> Nx.as_type(tensor, expected)
case {Nx.type(expr), Nx.type(tensor), preserve_source_types} do
{type, type, _} -> tensor
# Preserve FP8 E4M3FN types when preserve_source_types is enabled
{_expected, {:f8_e4m3fn, 8}, true} -> tensor
Comment on lines +535 to +536
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We likely don't want to do this here, because Axon may cast and it can lead to inconsistent behaviour (see #311). Ideally we want to apply an Axon.MixedPrecision policy, but we cannot determine it upfront. Also Axon policies apply per layer, but in this case we may have a layer where each param has different type. I need to think about the best way to address it and the loading API we should have.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@seanmor5 do you have any thoughts on how to handle layers where parameters have different types, as part of Axon.MixedPrecision?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude suggested something like the following:

Allow the params and compute fields of a policy to be either:

  • A type tuple (current behavior) — e.g., {:f, 16} — applies uniformly to all params
  • A map — e.g., %{"kernel" => {:f8_e4m3fn, 8}, :default => {:f, 32}} — per-parameter types
  • nil — no casting (current behavior)

The output field stays as-is (type or nil) since it applies to layer outputs, not individual params.

API Example

Existing API (unchanged)

policy = Axon.MixedPrecision.create_policy(params: {:bf, 16}, compute: {:bf, 16}, output: {:f, 32})

New: per-parameter types

policy = Axon.MixedPrecision.create_policy(
params: %{"kernel" => {:f8_e4m3fn, 8}, :default => {:f, 32}},
compute: %{"kernel" => nil, "scale_inv" => nil, :default => {:f, 32}},
output: {:f, 32}
)

A nil value for a specific param in the map means "don't cast this parameter" (the layer function handles it). The :default key provides the fallback type for params not
explicitly listed. If :default is absent, unlisted params are not cast.

Implementation

  1. lib/axon/mixed_precision.ex — Add cast/4 and update cast/3

Add a new cast/4 function that accepts a param_name:

def cast(%Policy{} = policy, tensor_or_container, variable_type, param_name)

Uses a resolve_type/2 helper to look up the type:

defp resolve_type(nil, _name), do: nil
defp resolve_type(%{} = map, name), do: Map.get(map, name, Map.get(map, :default))
defp resolve_type(type, _name), do: type

Update cast/3 to handle map values by falling back to :default (for layer input/output casts where no param name is available).

  1. lib/axon/mixed_precision/policy.ex — Update Inspect

Handle map values in the inspect protocol. When a field is a map, display it as {per-param} or show the map contents.

  1. lib/axon/compiler.ex — Two changes

a) Init path (line 1076, layer_init_fun):

Currently destructures %{params: dtype} and passes a single dtype to init_param. Change init_param to resolve per-parameter types:

line 1130, change from:

dtype = dtype || Nx.type(template)

to:

dtype = resolve_param_type(params_policy, name) || Nx.type(template)

Add resolve_param_type/2 (same logic as resolve_type above).

b) Predict path (line 920-926, parameter casting):

Pass the parameter name through to safe_policy_cast:

line 926, change from:

safe_policy_cast(maybe_freeze(param, frz), policy, :compute)

to:

safe_policy_cast(maybe_freeze(param, frz), policy, :compute, v)

Add an optional param_name argument to safe_policy_cast that delegates to cast/4.

  1. Tests

Add compiler tests for:

  • Per-parameter init types (kernel gets one type, bias gets another)
  • Per-parameter compute types (some params cast, others preserved)
  • :default key behavior
  • Backward compatibility (existing uniform type policies still work)

Files to Modify

┌────────────────────────────────────┬───────────────────────────────────────────────────┐
│ File │ Change │
├────────────────────────────────────┼───────────────────────────────────────────────────┤
│ lib/axon/mixed_precision.ex │ Add cast/4, update cast/3, add resolve_type/2 │
├────────────────────────────────────┼───────────────────────────────────────────────────┤
│ lib/axon/mixed_precision/policy.ex │ Update Inspect for map values │
├────────────────────────────────────┼───────────────────────────────────────────────────┤
│ lib/axon/compiler.ex │ Per-param type resolution in init + predict paths │
├────────────────────────────────────┼───────────────────────────────────────────────────┤
│ test/axon/compiler_test.exs │ Tests for per-parameter mixed precision │
└────────────────────────────────────┴───────────────────────────────────────────────────┘

Verification

  1. mix test — all existing tests pass (backward compat)
  2. mix test test/axon/compiler_test.exs — new per-parameter tests pass
  3. Manual verification: create a model with per-param policy, confirm params have expected types after init and during compute

{expected, _actual, _} -> Nx.as_type(tensor, expected)
end
end)
end
Expand Down
122 changes: 122 additions & 0 deletions lib/bumblebee/layers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,128 @@ defmodule Bumblebee.Layers do
|> Nx.add(bias)
end

@doc """
Adds an FP8-aware dense layer to the network.

This layer supports optional scale_inv parameter for FP8 quantized weights.
When scale_inv is provided, it's applied to the matmul output to account
for FP8 quantization scaling.

The kernel parameter uses standard dense layout (transposed from PyTorch).

## Options

* `:name` - layer name

* `:kernel_initializer` - initializer for `kernel` weights.
Defaults to `:glorot_uniform`

* `:use_bias` - whether the layer should add bias to the output.
Defaults to `false`

* `:block_size` - the block size used for FP8 quantization.
Defaults to 128

"""
def fp8_aware_dense(%Axon{} = x, units, opts \\ []) do
  opts =
    Keyword.validate!(opts, [
      :name,
      kernel_initializer: :glorot_uniform,
      use_bias: false,
      block_size: 128
    ])

  name = opts[:name]
  block_size = opts[:block_size]

  kernel_shape = &Axon.Shape.dense_kernel(&1, units)
  bias_shape = &Axon.Shape.dense_bias(&1, units)

  # scale_inv is laid out as [in_blocks, out_blocks], mirroring the
  # transposed (dense-layout) kernel. Block counts are rounded up so
  # dimensions that are not exact multiples of block_size still get a
  # (partial) trailing block.
  scale_shape = fn input_shape ->
    in_features = elem(input_shape, tuple_size(input_shape) - 1)
    in_blocks = div(in_features + block_size - 1, block_size)
    out_blocks = div(units + block_size - 1, block_size)
    {in_blocks, out_blocks}
  end

  kernel = Axon.param("kernel", kernel_shape, initializer: opts[:kernel_initializer])

  # All-ones scales make the layer behave exactly like a plain dense layer;
  # FP8 checkpoints overwrite them with the real inverse scales on load.
  scale_inv = Axon.param("scale_inv", scale_shape, initializer: :ones)

  {inputs, op} =
    if opts[:use_bias] do
      bias = Axon.param("bias", bias_shape, initializer: :zeros)
      {[x, kernel, scale_inv, bias], &fp8_aware_dense_impl(&1, &2, &3, &4, &5, block_size)}
    else
      # No bias param: pass nil through so the impl skips the add.
      {[x, kernel, scale_inv], &fp8_aware_dense_impl(&1, &2, &3, nil, &4, block_size)}
    end

  Axon.layer(op, inputs, name: name, op_name: :fp8_aware_dense)
end

deftransformp fp8_aware_dense_impl(x, kernel, scale_inv, bias, _opts, block_size) do
  # Undo the FP8 block quantization on the kernel, then contract the last
  # axis of x against the first axis of the dequantized kernel:
  # x [batch, seq_len, in_features] . kernel [in_features, out_features]
  # -> [batch, seq_len, out_features]
  result =
    kernel
    |> dequantize_kernel(scale_inv, block_size)
    |> then(&Nx.dot(x, [-1], &1, [0]))

  # bias is nil when the layer was built without :use_bias.
  if bias, do: Nx.add(result, bias), else: result
end

defp dequantize_kernel(kernel, scale_inv, block_size) do
  # kernel: [in_features, out_features]
  # scale_inv: [in_blocks, out_blocks] — one inverse scale per
  # block_size x block_size tile of the kernel (transposed PyTorch layout).
  {in_features, out_features} = Nx.shape(kernel)
  {in_blocks, out_blocks} = Nx.shape(scale_inv)

  padded_rows = in_blocks * block_size
  padded_cols = out_blocks * block_size

  # Blow each scale entry up into a block_size x block_size tile: insert
  # singleton axes next to each block axis, broadcast them out, flatten,
  # then trim any padding (dimensions not exact multiples of block_size).
  scales =
    scale_inv
    |> Nx.reshape({in_blocks, 1, out_blocks, 1})
    |> Nx.broadcast({in_blocks, block_size, out_blocks, block_size})
    |> Nx.reshape({padded_rows, padded_cols})
    |> Nx.slice([0, 0], [in_features, out_features])

  # Multiply in f32 so the low-precision FP8 values are dequantized into a
  # type with enough range/precision for the subsequent matmul.
  kernel
  |> Nx.as_type({:f, 32})
  |> Nx.multiply(scales)
end

@doc """
Adds a 1-dimensional convolution layer to the network.

Expand Down
30 changes: 23 additions & 7 deletions lib/bumblebee/layers/transformer.ex
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ defmodule Bumblebee.Layers.Transformer do
:block_type,
:attention_scale,
:query_norm,
:key_norm
:key_norm,
:attention_dense
]

opts =
Expand Down Expand Up @@ -354,7 +355,8 @@ defmodule Bumblebee.Layers.Transformer do
attention_scale: nil,
rotary_embedding: nil,
query_norm: nil,
key_norm: nil
key_norm: nil,
attention_dense: nil
])

name = opts[:name]
Expand Down Expand Up @@ -386,6 +388,7 @@ defmodule Bumblebee.Layers.Transformer do
rotary_embedding = opts[:rotary_embedding]
query_norm = opts[:query_norm]
key_norm = opts[:key_norm]
attention_dense = opts[:attention_dense]

ffn_fun =
case ffn do
Expand Down Expand Up @@ -446,6 +449,7 @@ defmodule Bumblebee.Layers.Transformer do
rotary_embedding: rotary_embedding,
query_norm: query_norm,
key_norm: key_norm,
attention_dense: attention_dense,
name: join(name, "self_attention")
)

Expand Down Expand Up @@ -491,6 +495,7 @@ defmodule Bumblebee.Layers.Transformer do
attention_window_size: attention_window_size,
attention_scale: attention_scale,
rotary_embedding: rotary_embedding,
attention_dense: attention_dense,
name: join(name, "cross_attention")
)

Expand Down Expand Up @@ -772,7 +777,8 @@ defmodule Bumblebee.Layers.Transformer do
output_use_bias: true,
rotary_embedding: nil,
query_norm: nil,
key_norm: nil
key_norm: nil,
attention_dense: nil
])

attention_mask = opts[:attention_mask]
Expand All @@ -792,6 +798,7 @@ defmodule Bumblebee.Layers.Transformer do
rotary_embedding = opts[:rotary_embedding]
query_norm = opts[:query_norm]
key_norm = opts[:key_norm]
attention_dense = opts[:attention_dense]

query_use_bias = opts[:query_use_bias]
key_use_bias = opts[:key_use_bias]
Expand All @@ -804,9 +811,18 @@ defmodule Bumblebee.Layers.Transformer do
inner_size = num_heads * attention_head_size
inner_kv_size = num_key_value_heads * attention_head_size

# Helper to create dense layer, using custom attention_dense if provided
dense_fn = fn input, units, dense_opts ->
if attention_dense do
attention_dense.(input, units, dense_opts)
else
Axon.dense(input, units, dense_opts)
end
end

query =
query
|> Axon.dense(inner_size,
|> dense_fn.(inner_size,
kernel_initializer: kernel_initializer,
name: join(name, "query"),
use_bias: query_use_bias
Expand All @@ -815,7 +831,7 @@ defmodule Bumblebee.Layers.Transformer do

key =
key
|> Axon.dense(inner_kv_size,
|> dense_fn.(inner_kv_size,
kernel_initializer: kernel_initializer,
name: join(name, "key"),
use_bias: key_use_bias
Expand All @@ -824,7 +840,7 @@ defmodule Bumblebee.Layers.Transformer do

value =
value
|> Axon.dense(inner_kv_size,
|> dense_fn.(inner_kv_size,
kernel_initializer: kernel_initializer,
name: join(name, "value"),
use_bias: value_use_bias
Expand Down Expand Up @@ -937,7 +953,7 @@ defmodule Bumblebee.Layers.Transformer do
attention_output =
attention_output
|> Layers.flatten_trailing()
|> Axon.dense(hidden_size,
|> dense_fn.(hidden_size,
kernel_initializer: kernel_initializer,
name: join(name, "output"),
use_bias: output_use_bias
Expand Down
Loading
Loading