Conversation
Add comprehensive FP8 quantized model support for models like Qwen3-FP8. This enables loading and running FP8 models with per-block scale factors.
Changes:
bumblebee.ex:
- Add :preserve_source_types option to load_model/2 to keep FP8 types
pytorch_params.ex:
- Pass preserve_source_types through param loading pipeline
- Modify ensure_type/3 to preserve FP8 types when option is set
layers.ex:
- Add fp8_aware_dense/3 layer that handles FP8 quantized weights
- Implements block-wise dequantization using scale_inv parameter
- Automatically falls back to identity scaling for non-FP8 models
layers/transformer.ex:
- Add :attention_dense option to blocks/2, block/2, multi_head_attention/4
- Allows custom dense function for Q, K, V, and output projections
text/qwen3.ex:
- Update decoder to use fp8_aware_dense for attention via attention_dense
- Update gated_ffn to use fp8_aware_dense for FFN layers
- Add scale_inv to params_mapping for all attention and FFN layers
The implementation supports both:
- Pre-dequantization: Convert FP8 -> F32 before loading
- Native FP8: Load FP8 weights directly, apply scale_inv at runtime
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
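For illustration, here is a minimal Nx sketch of the block-wise dequantization idea described above. This is not the PR's actual fp8_aware_dense code; the module name is made up, and it assumes a rank-2 kernel, a 128x128 block size, and an Nx build that supports the FP8 type.

```elixir
# Illustrative sketch only: each scale_inv entry scales one 128x128 block of
# the quantized kernel. Not the PR's fp8_aware_dense implementation.
defmodule FP8Sketch do
  @block 128

  def dequantize(kernel, scale_inv) do
    {rows, cols} = Nx.shape(kernel)
    {out_blocks, in_blocks} = Nx.shape(scale_inv)

    # Expand each scale entry to a 128x128 block, then crop to the kernel shape
    scales =
      scale_inv
      |> Nx.reshape({out_blocks, 1, in_blocks, 1})
      |> Nx.broadcast({out_blocks, @block, in_blocks, @block})
      |> Nx.reshape({out_blocks * @block, in_blocks * @block})
      |> Nx.slice([0, 0], [rows, cols])

    kernel |> Nx.as_type(:f32) |> Nx.multiply(scales)
  end
end
```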
Update dependencies to use git versions of nx and safetensors which
have the new FP8 type representation: {:f8_e4m3fn, 8} instead of
{:f, 8, :e4m3fn}.
Changes:
- Update mix.exs to use git deps for nx, exla, torchx, and safetensors
- Update FP8 type detection pattern in pytorch_params.ex
- Add TODO comments noting deps should be switched back to hex when released
Tested with Qwen/Qwen3-4B-Instruct-2507-FP8 model - loads and generates
correctly with preserve_source_types: true.
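As a rough sketch of what the git-based deps could look like in mix.exs. The branches/refs are not specified in this PR, so none are pinned here, and the elixir-nx monorepo layout with :sparse checkouts is an assumption:

```elixir
# Hypothetical deps sketch; refs are placeholders, not the ones used in this PR.
defp deps do
  [
    # TODO: switch back to Hex releases once FP8 support is published
    {:nx, github: "elixir-nx/nx", sparse: "nx", override: true},
    {:exla, github: "elixir-nx/nx", sparse: "exla", override: true},
    {:torchx, github: "elixir-nx/nx", sparse: "torchx", override: true},
    {:safetensors, github: "elixir-nx/safetensors", override: true}
  ]
end
```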
Add a new section demonstrating how to load and use FP8 quantized Qwen3 models with preserve_source_types: true option. Updated introduction and summary to reflect the new capability.
To generate the FP8 tiny model:
"""Generate a tiny FP8 Qwen3 model for testing Bumblebee's FP8 support.

This creates a minimal model with:
- FP8 E4M3FN weights for linear layers
- Corresponding weight_scale_inv tensors (128x128 block scaling)
- Saved in safetensors format

Usage:
    python generate_fp8_qwen3.py
    # Then upload to HuggingFace: huggingface-cli upload roulis/tiny-fp8-qwen3 ./tiny-fp8-qwen3
"""
import torch
import json
import os
from safetensors.torch import save_file
# Tiny model config matching existing tiny-random-Qwen3ForCausalLM
CONFIG = {
    "architectures": ["Qwen3ForCausalLM"],
    "hidden_size": 32,
    "intermediate_size": 64,
    "num_attention_heads": 4,
    "num_hidden_layers": 2,
    "num_key_value_heads": 2,
    "vocab_size": 1024,
    "head_dim": 8,  # hidden_size / num_attention_heads
    "rms_norm_eps": 1e-6,
    "rope_theta": 1000000.0,
    "max_position_embeddings": 512,
    "torch_dtype": "float8_e4m3fn",
    "model_type": "qwen3",
    "use_qk_norm": True,
    "tie_word_embeddings": True,
    "quantization_config": {
        "quant_method": "fp8",
        "weight_block_size": [128, 128]
    }
}
BLOCK_SIZE = 128
def create_fp8_weight(shape, seed=42):
    """Create a random FP8 E4M3FN weight tensor."""
    torch.manual_seed(seed)
    # Create random values in valid FP8 E4M3FN range (-448 to 448)
    weight_f32 = torch.randn(shape) * 0.1
    weight_fp8 = weight_f32.to(torch.float8_e4m3fn)
    return weight_fp8
def create_scale_inv(weight_shape):
    """Create scale_inv tensor for block-wise dequantization.

    Shape: [ceil(out_features/128), ceil(in_features/128)]
    For testing, use scale of 1.0 (identity) so dequantized = original.
    """
    out_features, in_features = weight_shape
    out_blocks = (out_features + BLOCK_SIZE - 1) // BLOCK_SIZE
    in_blocks = (in_features + BLOCK_SIZE - 1) // BLOCK_SIZE
    # Use 1.0 for identity scaling (easier to verify in tests)
    return torch.ones(out_blocks, in_blocks, dtype=torch.float32)
def generate_model():
    hidden_size = CONFIG["hidden_size"]
    intermediate_size = CONFIG["intermediate_size"]
    num_heads = CONFIG["num_attention_heads"]
    num_kv_heads = CONFIG["num_key_value_heads"]
    head_dim = CONFIG["head_dim"]
    vocab_size = CONFIG["vocab_size"]
    num_layers = CONFIG["num_hidden_layers"]

    tensors = {}
    seed = 0

    # Embedding (not quantized)
    tensors["model.embed_tokens.weight"] = torch.randn(vocab_size, hidden_size)

    for layer_idx in range(num_layers):
        prefix = f"model.layers.{layer_idx}"

        # Self-attention projections (FP8 quantized)
        q_size = num_heads * head_dim
        kv_size = num_kv_heads * head_dim

        # Q projection
        tensors[f"{prefix}.self_attn.q_proj.weight"] = create_fp8_weight((q_size, hidden_size), seed)
        seed += 1
        tensors[f"{prefix}.self_attn.q_proj.weight_scale_inv"] = create_scale_inv((q_size, hidden_size))

        # K projection
        tensors[f"{prefix}.self_attn.k_proj.weight"] = create_fp8_weight((kv_size, hidden_size), seed)
        seed += 1
        tensors[f"{prefix}.self_attn.k_proj.weight_scale_inv"] = create_scale_inv((kv_size, hidden_size))

        # V projection
        tensors[f"{prefix}.self_attn.v_proj.weight"] = create_fp8_weight((kv_size, hidden_size), seed)
        seed += 1
        tensors[f"{prefix}.self_attn.v_proj.weight_scale_inv"] = create_scale_inv((kv_size, hidden_size))

        # O projection
        tensors[f"{prefix}.self_attn.o_proj.weight"] = create_fp8_weight((hidden_size, q_size), seed)
        seed += 1
        tensors[f"{prefix}.self_attn.o_proj.weight_scale_inv"] = create_scale_inv((hidden_size, q_size))

        # QK norms (not quantized)
        tensors[f"{prefix}.self_attn.q_norm.weight"] = torch.ones(head_dim)
        tensors[f"{prefix}.self_attn.k_norm.weight"] = torch.ones(head_dim)

        # MLP (FP8 quantized)
        tensors[f"{prefix}.mlp.gate_proj.weight"] = create_fp8_weight((intermediate_size, hidden_size), seed)
        seed += 1
        tensors[f"{prefix}.mlp.gate_proj.weight_scale_inv"] = create_scale_inv((intermediate_size, hidden_size))

        tensors[f"{prefix}.mlp.up_proj.weight"] = create_fp8_weight((intermediate_size, hidden_size), seed)
        seed += 1
        tensors[f"{prefix}.mlp.up_proj.weight_scale_inv"] = create_scale_inv((intermediate_size, hidden_size))

        tensors[f"{prefix}.mlp.down_proj.weight"] = create_fp8_weight((hidden_size, intermediate_size), seed)
        seed += 1
        tensors[f"{prefix}.mlp.down_proj.weight_scale_inv"] = create_scale_inv((hidden_size, intermediate_size))

        # Layer norms (not quantized)
        tensors[f"{prefix}.input_layernorm.weight"] = torch.ones(hidden_size)
        tensors[f"{prefix}.post_attention_layernorm.weight"] = torch.ones(hidden_size)

    # Final norm (not quantized)
    tensors["model.norm.weight"] = torch.ones(hidden_size)

    # LM head is tied to the embeddings (tie_word_embeddings=True), so no
    # separate lm_head.weight tensor is saved; it is not quantized either.
    return tensors
def main():
    output_dir = "tiny-fp8-qwen3"
    os.makedirs(output_dir, exist_ok=True)

    # Generate model tensors
    tensors = generate_model()

    # Save as safetensors
    save_file(tensors, os.path.join(output_dir, "model.safetensors"))

    # Save config
    with open(os.path.join(output_dir, "config.json"), "w") as f:
        json.dump(CONFIG, f, indent=2)

    print(f"Model saved to {output_dir}/")
    print(f"Total tensors: {len(tensors)}")
    print("\nTo upload to HuggingFace:")
    print(f"  huggingface-cli upload roulis/tiny-fp8-qwen3 {output_dir}")


if __name__ == "__main__":
    main()
- Add fp8_aware_dense layer unit tests
- Add FP8 Qwen3 model loading test using roulis/tiny-fp8-qwen3
- Include Python script to generate tiny FP8 test models
# Preserve FP8 E4M3FN types when preserve_source_types is enabled
{_expected, {:f8_e4m3fn, 8}, true} -> tensor
We likely don't want to do this here, because Axon may cast and it can lead to inconsistent behaviour (see #311). Ideally we want to apply an Axon.MixedPrecision policy, but we cannot determine it upfront. Also Axon policies apply per layer, but in this case we may have a layer where each param has different type. I need to think about the best way to address it and the loading API we should have.
@seanmor5 do you have any thoughts on how to handle layers where parameters have different types, as part of Axon.MixedPrecision?
Claude suggested something like the following:
Allow the params and compute fields of a policy to be either:
- A type tuple (current behavior) — e.g., {:f, 16} — applies uniformly to all params
- A map — e.g., %{"kernel" => {:f8_e4m3fn, 8}, :default => {:f, 32}} — per-parameter types
- nil — no casting (current behavior)
The output field stays as-is (type or nil) since it applies to layer outputs, not individual params.
API Example
Existing API (unchanged)
policy = Axon.MixedPrecision.create_policy(params: {:bf, 16}, compute: {:bf, 16}, output: {:f, 32})
New: per-parameter types
policy =
  Axon.MixedPrecision.create_policy(
    params: %{"kernel" => {:f8_e4m3fn, 8}, :default => {:f, 32}},
    compute: %{"kernel" => nil, "scale_inv" => nil, :default => {:f, 32}},
    output: {:f, 32}
  )
A nil value for a specific param in the map means "don't cast this parameter" (the layer function handles it). The :default key provides the fallback type for params not
explicitly listed. If :default is absent, unlisted params are not cast.
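A free-standing sketch of how this lookup could behave, not actual Axon code: the params and policy map are made up for the example, and casting to the FP8 type assumes the git nx mentioned above.

```elixir
# Stand-alone illustration of the proposed per-parameter type resolution.
resolve_type = fn
  nil, _name -> nil
  %{} = map, name -> Map.get(map, name, Map.get(map, :default))
  type, _name -> type
end

params_policy = %{"kernel" => {:f8_e4m3fn, 8}, :default => {:f, 32}}

# Hypothetical layer params
params = %{
  "kernel" => Nx.iota({2, 2}, type: :f32),
  "bias" => Nx.tensor([0.0, 0.0])
}

cast_params =
  Map.new(params, fn {name, tensor} ->
    case resolve_type.(params_policy, name) do
      nil -> {name, tensor}
      type -> {name, Nx.as_type(tensor, type)}
    end
  end)
```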
Implementation
- lib/axon/mixed_precision.ex — Add cast/4 and update cast/3
Add a new cast/4 function that accepts a param_name:
def cast(%Policy{} = policy, tensor_or_container, variable_type, param_name)
Uses a resolve_type/2 helper to look up the type:
defp resolve_type(nil, _name), do: nil
defp resolve_type(%{} = map, name), do: Map.get(map, name, Map.get(map, :default))
defp resolve_type(type, _name), do: type
Update cast/3 to handle map values by falling back to :default (for layer input/output casts where no param name is available).
- lib/axon/mixed_precision/policy.ex — Update Inspect
Handle map values in the inspect protocol. When a field is a map, display it as {per-param} or show the map contents.
- lib/axon/compiler.ex — Two changes
a) Init path (line 1076, layer_init_fun):
Currently destructures %{params: dtype} and passes a single dtype to init_param. Change init_param to resolve per-parameter types:
line 1130, change from:
dtype = dtype || Nx.type(template)
to:
dtype = resolve_param_type(params_policy, name) || Nx.type(template)
Add resolve_param_type/2 (same logic as resolve_type above).
b) Predict path (line 920-926, parameter casting):
Pass the parameter name through to safe_policy_cast:
line 926, change from:
safe_policy_cast(maybe_freeze(param, frz), policy, :compute)
to:
safe_policy_cast(maybe_freeze(param, frz), policy, :compute, v)
Add an optional param_name argument to safe_policy_cast that delegates to cast/4.
- Tests
Add compiler tests for:
- Per-parameter init types (kernel gets one type, bias gets another)
- Per-parameter compute types (some params cast, others preserved)
- :default key behavior
- Backward compatibility (existing uniform type policies still work)
Files to Modify
┌────────────────────────────────────┬───────────────────────────────────────────────────┐
│ File │ Change │
├────────────────────────────────────┼───────────────────────────────────────────────────┤
│ lib/axon/mixed_precision.ex │ Add cast/4, update cast/3, add resolve_type/2 │
├────────────────────────────────────┼───────────────────────────────────────────────────┤
│ lib/axon/mixed_precision/policy.ex │ Update Inspect for map values │
├────────────────────────────────────┼───────────────────────────────────────────────────┤
│ lib/axon/compiler.ex │ Per-param type resolution in init + predict paths │
├────────────────────────────────────┼───────────────────────────────────────────────────┤
│ test/axon/compiler_test.exs │ Tests for per-parameter mixed precision │
└────────────────────────────────────┴───────────────────────────────────────────────────┘
Verification
- mix test — all existing tests pass (backward compat)
- mix test test/axon/compiler_test.exs — new per-parameter tests pass
- Manual verification: create a model with per-param policy, confirm params have expected types after init and during compute
Summary
This PR adds support for loading and running FP8 (8-bit floating point) quantized models natively in Bumblebee. FP8 models use approximately half the memory
of BF16 models while maintaining good inference quality.
Changes
Core FP8 Support
- preserve_source_types option to Bumblebee.load_model/2 to keep FP8 weights in their native format
- dequantize_kernel/3 function in Bumblebee.Layers for runtime FP8 → F32 conversion using scale_inv tensors
- Support for the new FP8 type representation {:f8_e4m3fn, 8}
Qwen3 FP8 Integration
- params_mapping entries for the FP8 weight scales (weight_scale_inv) in the Qwen3 architecture
Dependencies
- Git versions of nx, exla, torchx, and safetensors for FP8 type support
Documentation
- New section demonstrating how to load and use FP8 quantized Qwen3 models with preserve_source_types: true
Usage
Loading an FP8 Model
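A sketch, not copied from the PR's docs: only preserve_source_types: true and the tested repository come from this PR; the serving setup follows Bumblebee's usual text-generation flow.

```elixir
# Load the FP8 weights in their native type instead of casting on load
repo = {:hf, "Qwen/Qwen3-4B-Instruct-2507-FP8"}

{:ok, model_info} = Bumblebee.load_model(repo, preserve_source_types: true)
{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
{:ok, generation_config} = Bumblebee.load_generation_config(repo)

serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config)
Nx.Serving.run(serving, "What is FP8 quantization?")
```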
Supported FP8 Models
Notes