From ed6b793b19fe37ade07a55f70f801cf09183036b Mon Sep 17 00:00:00 2001 From: Torsten Scholak Date: Sun, 14 Dec 2025 02:06:30 +0000 Subject: [PATCH 01/25] Refactor Apriel2 cache and add Qwen2 converter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Cache improvements: - Add methods to _AttentionCache and _SSMCache (reset, reorder, crop, batch_repeat, batch_select, is_initialized, batch_size) - Add _iter_caches() helper to flatten stochastic layer dicts - Simplify Apriel2Cache methods using new abstractions - Fix sliding window attention mask sizes (cumulative_length tracking) - Localize KDA tuple handling in _SSMCache Test improvements: - Split tests into contract tests (vs HuggingFace) and Apriel2-specific - Add shared fixtures to conftest.py - Add edge case tests for SSM tuple operations - Remove duplicated fixture definitions Qwen2 converter: - Add Qwen2/Qwen2.5 to Apriel2 config conversion - Add weight mapping plan for Qwen2 models 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- fast_llm_external_models/apriel2/cache.py | 311 ++-- .../apriel2/conversion/qwen2/__init__.py | 6 + .../apriel2/conversion/qwen2/config.py | 81 ++ .../apriel2/conversion/qwen2/plan.py | 113 ++ fast_llm_external_models/apriel2/convert.py | 9 +- .../tests/test_apriel2/conftest.py | 187 +++ .../tests/test_apriel2/test_cache.py | 1258 ----------------- .../test_cache_apriel2_specific.py | 342 +++++ .../test_apriel2/test_cache_contracts.py | 592 ++++++++ 9 files changed, 1499 insertions(+), 1400 deletions(-) create mode 100644 fast_llm_external_models/apriel2/conversion/qwen2/__init__.py create mode 100644 fast_llm_external_models/apriel2/conversion/qwen2/config.py create mode 100644 fast_llm_external_models/apriel2/conversion/qwen2/plan.py delete mode 100644 fast_llm_external_models/tests/test_apriel2/test_cache.py create mode 100644 fast_llm_external_models/tests/test_apriel2/test_cache_apriel2_specific.py create mode 100644 fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py diff --git a/fast_llm_external_models/apriel2/cache.py b/fast_llm_external_models/apriel2/cache.py index 86c67a085..32db547b9 100644 --- a/fast_llm_external_models/apriel2/cache.py +++ b/fast_llm_external_models/apriel2/cache.py @@ -4,14 +4,18 @@ class _AttentionCache: - __slots__ = ["key", "value", "window"] + __slots__ = ["key", "value", "window", "cumulative_length"] def __init__(self, window=None): self.key = None self.value = None self.window = window + self.cumulative_length = 0 def update(self, key, value): + new_tokens = key.shape[-2] + self.cumulative_length += new_tokens + if self.key is None: if self.window and key.shape[-2] > self.window: self.key = key[..., -self.window :, :].contiguous() @@ -35,6 +39,40 @@ def _window(self, cache, new): return cache return torch.cat([cache, new], -2)[..., -self.window :, :].contiguous() + def reset(self): + self.key = None + self.value = None + self.cumulative_length = 0 + + def reorder(self, beam_idx): + if self.key is not None: + self.key = self.key.index_select(0, beam_idx.to(self.key.device)) + self.value = self.value.index_select(0, beam_idx.to(self.value.device)) + + def crop(self, max_length): + if self.key is not None: + self.key = self.key[..., :max_length, :] + self.value = self.value[..., :max_length, :] + self.cumulative_length = self.key.shape[-2] + + def batch_repeat(self, repeats): + if self.key is not None: + self.key = self.key.repeat_interleave(repeats, dim=0) + 
self.value = self.value.repeat_interleave(repeats, dim=0) + + def batch_select(self, indices): + if self.key is not None: + self.key = self.key.index_select(0, indices.to(self.key.device)) + self.value = self.value.index_select(0, indices.to(self.value.device)) + + @property + def is_initialized(self): + return self.key is not None + + @property + def batch_size(self): + return self.key.shape[0] if self.key is not None else None + class _SSMCache: __slots__ = ["conv", "recurrent"] @@ -43,6 +81,52 @@ def __init__(self): self.conv = None self.recurrent = None + def reset(self): + self.conv = None + self.recurrent = None + + def reorder(self, beam_idx): + if self.conv is not None: + if isinstance(self.conv, tuple): + self.conv = tuple(c.index_select(0, beam_idx.to(c.device)) for c in self.conv) + else: + self.conv = self.conv.index_select(0, beam_idx.to(self.conv.device)) + if self.recurrent is not None: + self.recurrent = self.recurrent.index_select(0, beam_idx.to(self.recurrent.device)) + + def crop(self, max_length): + pass # SSM caches don't have sequence dimension to crop + + def batch_repeat(self, repeats): + if self.conv is not None: + if isinstance(self.conv, tuple): + self.conv = tuple(c.repeat_interleave(repeats, dim=0) for c in self.conv) + else: + self.conv = self.conv.repeat_interleave(repeats, dim=0) + if self.recurrent is not None: + self.recurrent = self.recurrent.repeat_interleave(repeats, dim=0) + + def batch_select(self, indices): + if self.conv is not None: + if isinstance(self.conv, tuple): + self.conv = tuple(c.index_select(0, indices.to(c.device)) for c in self.conv) + else: + self.conv = self.conv.index_select(0, indices.to(self.conv.device)) + if self.recurrent is not None: + self.recurrent = self.recurrent.index_select(0, indices.to(self.recurrent.device)) + + @property + def is_initialized(self): + return self.conv is not None + + @property + def batch_size(self): + if self.conv is None: + return None + if isinstance(self.conv, tuple): + return self.conv[0].shape[0] + return self.conv.shape[0] + class _DummyCacheLayer: pass @@ -93,14 +177,19 @@ def set_active_mixer(self, layer_idx, mixer_name): self.active_mixers[layer_idx] = mixer_name def get_seq_length(self, layer_idx=0): + """Returns the cumulative sequence length of tokens seen by the cache. + + For sliding window caches, this returns the total tokens seen (not just cached). + This matches HuggingFace's DynamicSlidingWindowLayer behavior. + """ layer = self.layers[layer_idx] if isinstance(layer, dict): mixer = self.active_mixers[layer_idx] if mixer and isinstance(layer[mixer], _AttentionCache): - return layer[mixer].key.shape[-2] if layer[mixer].key is not None else 0 + return layer[mixer].cumulative_length return 0 if isinstance(layer, _AttentionCache): - return layer.key.shape[-2] if layer.key is not None else 0 + return layer.cumulative_length return 0 def get_max_cache_shape(self, layer_idx=0): @@ -114,22 +203,61 @@ def get_max_cache_shape(self, layer_idx=0): return None def get_mask_sizes(self, cache_position, layer_idx): + """Return the length and offset of the cache, used to generate the attention mask. 
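+
+        For example, a sliding-window layer with window=8 that has already seen
+        cumulative_length=10 tokens reports kv_offset = max(10 - 8 + 1, 0) = 3 and
+        kv_length = (8 - 1) + 1 = 8 for a 1-token query (see the sliding-window
+        formula below).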
+ + For standard (non-sliding) attention: + kv_offset = 0 (KV[0] corresponds to sequence position 0) + kv_length = cumulative_length + query_length + + For sliding window attention: + kv_offset = max(cumulative_length - window + 1, 0) + kv_length = min(cumulative_length, window - 1) + query_length + + For SSM/linear layers: + kv_offset = 0, kv_length = query_length (no KV cache to attend to) + """ query_length = cache_position.shape[0] - past_seen_tokens = self.get_seq_length(layer_idx) - kv_length = query_length + past_seen_tokens - kv_offset = past_seen_tokens - return kv_length, kv_offset + layer = self.layers[layer_idx] + + # Handle stochastic layers by getting the active mixer's cache + if isinstance(layer, dict): + mixer = self.active_mixers[layer_idx] + if mixer is None: + # No active mixer set, return defaults + return query_length, 0 + cache = layer[mixer] + else: + cache = layer + + # SSM layers don't have KV cache for attention mask purposes + if isinstance(cache, _SSMCache): + return query_length, 0 + + # Attention cache - check if sliding window + if isinstance(cache, _AttentionCache): + cumulative = cache.cumulative_length + window = cache.window + + if window is not None: + # Sliding window attention + kv_offset = max(cumulative - window + 1, 0) + if cumulative >= window: + kv_length = window - 1 + query_length + else: + kv_length = cumulative + query_length + else: + # Full attention + kv_offset = 0 + kv_length = cumulative + query_length + + return kv_length, kv_offset + + # Fallback + return query_length, 0 @property def has_previous_state(self): - for layer in self.layers: - if isinstance(layer, dict): - for cache in layer.values(): - if isinstance(cache, _SSMCache) and cache.conv is not None: - return True - elif isinstance(layer, _SSMCache) and layer.conv is not None: - return True - return False + return any(isinstance(cache, _SSMCache) and cache.conv is not None for cache in self._iter_caches()) @property def key_cache(self): @@ -147,101 +275,33 @@ def conv_states(self): def recurrent_states(self): return _LayerListAccessor(self, "recurrent") - def reorder_cache(self, beam_idx): - for i, layer in enumerate(self.layers): + def _iter_caches(self): + """Iterate over all leaf cache objects (flattening stochastic layer dicts).""" + for layer in self.layers: if isinstance(layer, dict): - for cache in layer.values(): - self._reorder_cache_obj(cache, beam_idx) + yield from layer.values() else: - self._reorder_cache_obj(layer, beam_idx) + yield layer - def _reorder_cache_obj(self, cache, beam_idx): - if isinstance(cache, _AttentionCache): - if cache.key is not None: - cache.key = cache.key.index_select(0, beam_idx.to(cache.key.device)) - cache.value = cache.value.index_select(0, beam_idx.to(cache.value.device)) - elif isinstance(cache, _SSMCache): - if cache.conv is not None: - # Handle both single tensor (GDN/Mamba) and tuple (KDA) conv states - if isinstance(cache.conv, tuple): - cache.conv = tuple(c.index_select(0, beam_idx.to(c.device)) for c in cache.conv) - else: - cache.conv = cache.conv.index_select(0, beam_idx.to(cache.conv.device)) - if cache.recurrent is not None: - cache.recurrent = cache.recurrent.index_select(0, beam_idx.to(cache.recurrent.device)) + def reorder_cache(self, beam_idx): + for cache in self._iter_caches(): + cache.reorder(beam_idx) def reset(self): - for layer in self.layers: - if isinstance(layer, dict): - for cache in layer.values(): - self._reset_cache_obj(cache) - else: - self._reset_cache_obj(layer) - - def _reset_cache_obj(self, cache): - if 
isinstance(cache, _AttentionCache): - cache.key = None - cache.value = None - elif isinstance(cache, _SSMCache): - cache.conv = None - cache.recurrent = None + for cache in self._iter_caches(): + cache.reset() def crop(self, max_length): - for layer in self.layers: - if isinstance(layer, dict): - for cache in layer.values(): - if isinstance(cache, _AttentionCache) and cache.key is not None: - cache.key = cache.key[..., :max_length, :] - cache.value = cache.value[..., :max_length, :] - elif isinstance(layer, _AttentionCache) and layer.key is not None: - layer.key = layer.key[..., :max_length, :] - layer.value = layer.value[..., :max_length, :] + for cache in self._iter_caches(): + cache.crop(max_length) def batch_repeat_interleave(self, repeats): - for layer in self.layers: - if isinstance(layer, dict): - for cache in layer.values(): - self._batch_repeat_cache_obj(cache, repeats) - else: - self._batch_repeat_cache_obj(layer, repeats) - - def _batch_repeat_cache_obj(self, cache, repeats): - if isinstance(cache, _AttentionCache): - if cache.key is not None: - cache.key = cache.key.repeat_interleave(repeats, dim=0) - cache.value = cache.value.repeat_interleave(repeats, dim=0) - elif isinstance(cache, _SSMCache): - if cache.conv is not None: - # Handle both single tensor (GDN/Mamba) and tuple (KDA) conv states - if isinstance(cache.conv, tuple): - cache.conv = tuple(c.repeat_interleave(repeats, dim=0) for c in cache.conv) - else: - cache.conv = cache.conv.repeat_interleave(repeats, dim=0) - if cache.recurrent is not None: - cache.recurrent = cache.recurrent.repeat_interleave(repeats, dim=0) + for cache in self._iter_caches(): + cache.batch_repeat(repeats) def batch_select_indices(self, indices): - for layer in self.layers: - if isinstance(layer, dict): - for cache in layer.values(): - self._batch_select_cache_obj(cache, indices) - else: - self._batch_select_cache_obj(layer, indices) - - def _batch_select_cache_obj(self, cache, indices): - if isinstance(cache, _AttentionCache): - if cache.key is not None: - cache.key = cache.key.index_select(0, indices.to(cache.key.device)) - cache.value = cache.value.index_select(0, indices.to(cache.value.device)) - elif isinstance(cache, _SSMCache): - if cache.conv is not None: - # Handle both single tensor (GDN/Mamba) and tuple (KDA) conv states - if isinstance(cache.conv, tuple): - cache.conv = tuple(c.index_select(0, indices.to(c.device)) for c in cache.conv) - else: - cache.conv = cache.conv.index_select(0, indices.to(cache.conv.device)) - if cache.recurrent is not None: - cache.recurrent = cache.recurrent.index_select(0, indices.to(cache.recurrent.device)) + for cache in self._iter_caches(): + cache.batch_select(indices) @property def is_compileable(self): @@ -249,19 +309,7 @@ def is_compileable(self): @property def is_initialized(self): - for layer in self.layers: - if isinstance(layer, dict): - for cache in layer.values(): - if isinstance(cache, _AttentionCache) and cache.key is not None: - return True - if isinstance(cache, _SSMCache) and cache.conv is not None: - return True - else: - if isinstance(layer, _AttentionCache) and layer.key is not None: - return True - if isinstance(layer, _SSMCache) and layer.conv is not None: - return True - return False + return any(cache.is_initialized for cache in self._iter_caches()) @property def is_sliding(self): @@ -280,39 +328,20 @@ def is_sliding(self): @property def max_batch_size(self): - for layer in self.layers: - if isinstance(layer, dict): - for cache in layer.values(): - if isinstance(cache, 
_AttentionCache) and cache.key is not None: - return cache.key.shape[0] - if isinstance(cache, _SSMCache) and cache.conv is not None: - # Handle both single tensor and tuple conv states - if isinstance(cache.conv, tuple): - return cache.conv[0].shape[0] - return cache.conv.shape[0] - else: - if isinstance(layer, _AttentionCache) and layer.key is not None: - return layer.key.shape[0] - if isinstance(layer, _SSMCache) and layer.conv is not None: - # Handle both single tensor and tuple conv states - if isinstance(layer.conv, tuple): - return layer.conv[0].shape[0] - return layer.conv.shape[0] + for cache in self._iter_caches(): + bs = cache.batch_size + if bs is not None: + return bs return None @property def max_cache_len(self): - max_len = None - for layer in self.layers: - if isinstance(layer, dict): - for cache in layer.values(): - if isinstance(cache, _AttentionCache): - if cache.window is not None: - max_len = cache.window if max_len is None else min(max_len, cache.window) - elif isinstance(layer, _AttentionCache): - if layer.window is not None: - max_len = layer.window if max_len is None else min(max_len, layer.window) - return max_len + windows = [ + cache.window + for cache in self._iter_caches() + if isinstance(cache, _AttentionCache) and cache.window is not None + ] + return min(windows) if windows else None def __len__(self): return len(self.layers) diff --git a/fast_llm_external_models/apriel2/conversion/qwen2/__init__.py b/fast_llm_external_models/apriel2/conversion/qwen2/__init__.py new file mode 100644 index 000000000..d0a0b8e6e --- /dev/null +++ b/fast_llm_external_models/apriel2/conversion/qwen2/__init__.py @@ -0,0 +1,6 @@ +"""Qwen2/Qwen2.5 to Apriel2 conversion module.""" + +from fast_llm_external_models.apriel2.conversion.qwen2.config import convert_config +from fast_llm_external_models.apriel2.conversion.qwen2.plan import plan_qwen2_to_apriel2 + +__all__ = ["convert_config", "plan_qwen2_to_apriel2"] diff --git a/fast_llm_external_models/apriel2/conversion/qwen2/config.py b/fast_llm_external_models/apriel2/conversion/qwen2/config.py new file mode 100644 index 000000000..36df744c0 --- /dev/null +++ b/fast_llm_external_models/apriel2/conversion/qwen2/config.py @@ -0,0 +1,81 @@ +"""Qwen2/Qwen2.5 to Apriel2 config conversion.""" + + +def convert_config(qwen2_config: dict) -> dict: + """Convert Qwen2/Qwen2.5 config to Apriel2TextConfig format. 
+ + Qwen2.5 architecture: + - Standard transformer with GQA (grouped query attention) + - QKV bias enabled, O bias disabled + - MLP bias disabled + - Gated SwiGLU MLP + - RMSNorm + - RoPE embeddings + + Args: + qwen2_config: HuggingFace Qwen2Config as dict + + Returns: + Apriel2TextConfig-compatible dict + """ + hidden_size = qwen2_config["hidden_size"] + num_attention_heads = qwen2_config["num_attention_heads"] + num_key_value_heads = qwen2_config.get("num_key_value_heads", num_attention_heads) + head_dim = hidden_size // num_attention_heads + + # Qwen2 uses QKV bias but not O bias + # The add_linear_biases in Apriel2 attention config controls all biases uniformly, + # but we can set it to True and the o_proj bias will just be missing from weights + # (handled by strict=False loading or explicit handling in the plan) + + return { + "model_type": "apriel2_text", + "architectures": ["Apriel2ForCausalLM"], + "auto_map": { + "AutoConfig": "configuration_apriel2.Apriel2TextConfig", + "AutoModel": "modeling_apriel2.Apriel2TextModel", + "AutoModelForCausalLM": "modeling_apriel2.Apriel2ForCausalLM", + }, + "hidden_size": hidden_size, + "vocab_size": qwen2_config["vocab_size"], + "tie_word_embeddings": qwen2_config.get("tie_word_embeddings", False), + "decoder": { + "type": "fixed", + "num_blocks": qwen2_config["num_hidden_layers"], + "block": { + "mixer": { + "type": "attention", + "heads": num_attention_heads, + "head_groups": num_key_value_heads, + "head_size": head_dim, + # Qwen2 has QKV bias but not O bias + # We set True and handle O bias separately + "add_linear_biases": True, + "rotary": { + "type": "mistral_1d", + "theta": qwen2_config.get("rope_theta", 1000000.0), + }, + }, + "mlp": { + "type": "mlp", + "intermediate_size": qwen2_config["intermediate_size"], + "activation": qwen2_config.get("hidden_act", "silu"), + "gated": True, + "add_linear_biases": False, + }, + "normalization": { + "type": "rms_norm", + "epsilon": qwen2_config.get("rms_norm_eps", 1e-6), + }, + }, + }, + "head": { + "normalization": { + "type": "rms_norm", + "epsilon": qwen2_config.get("rms_norm_eps", 1e-6), + } + }, + "embeddings": { + "max_position_embeddings": qwen2_config.get("max_position_embeddings", 32768), + }, + } diff --git a/fast_llm_external_models/apriel2/conversion/qwen2/plan.py b/fast_llm_external_models/apriel2/conversion/qwen2/plan.py new file mode 100644 index 000000000..e5ae3e9d8 --- /dev/null +++ b/fast_llm_external_models/apriel2/conversion/qwen2/plan.py @@ -0,0 +1,113 @@ +"""Qwen2/Qwen2.5 to Apriel2 weight conversion plan.""" + +from fast_llm_external_models.apriel2.conversion.expr import ( + Expr, + ExprPlan, + Init, + Ref, + W, +) + + +def plan_qwen2_to_apriel2(qwen2_config: dict) -> ExprPlan: + """Build an expression plan for Qwen2/Qwen2.5 to Apriel2 conversion. + + This is a pure mapping (all Ref expressions) since Qwen2→Apriel2 + is just renaming keys. The weight tensors are identical. + + Key mapping (source keys have "model." 
prefix in safetensors):
+        Qwen2 (safetensor key)                               Apriel2
+        ----------------------                               -------
+        model.embed_tokens.weight                         -> model.embed_tokens.weight
+        model.norm.weight                                 -> model.norm.weight
+        model.layers.{i}.input_layernorm.weight           -> model.decoder.blocks.{i}.input_layernorm.weight
+        model.layers.{i}.post_attention_layernorm.weight  -> model.decoder.blocks.{i}.post_attention_layernorm.weight
+        model.layers.{i}.self_attn.q_proj.weight          -> model.decoder.blocks.{i}.mixer.q_proj.weight
+        model.layers.{i}.self_attn.k_proj.weight          -> model.decoder.blocks.{i}.mixer.k_proj.weight
+        model.layers.{i}.self_attn.v_proj.weight          -> model.decoder.blocks.{i}.mixer.v_proj.weight
+        model.layers.{i}.self_attn.o_proj.weight          -> model.decoder.blocks.{i}.mixer.o_proj.weight
+        model.layers.{i}.mlp.gate_proj.weight             -> model.decoder.blocks.{i}.mlp.gate_proj.weight
+        model.layers.{i}.mlp.up_proj.weight               -> model.decoder.blocks.{i}.mlp.up_proj.weight
+        model.layers.{i}.mlp.down_proj.weight             -> model.decoder.blocks.{i}.mlp.down_proj.weight
+
+    Note: Qwen2 has QKV biases but no O bias. The QKV biases are mapped directly;
+    since Apriel2's add_linear_biases=True applies to all projections uniformly, the
+    missing o_proj bias is initialized to zeros (see the Init entry below).
+
+    Args:
+        qwen2_config: HuggingFace Qwen2Config as dict
+
+    Returns:
+        ExprPlan with Ref mappings plus zero Init entries for the missing o_proj biases
+    """
+    mappings: dict[str, Expr] = {}
+
+    num_layers = qwen2_config["num_hidden_layers"]
+    hidden_size = qwen2_config["hidden_size"]
+
+    # Static mappings (embeddings and final norm)
+    # Note: Qwen2 safetensor keys have "model." prefix
+    static_mappings = [
+        (W("model", "embed_tokens", "weight"), W("model", "embed_tokens", "weight")),
+        (W("model", "norm", "weight"), W("model", "norm", "weight")),
+    ]
+
+    # lm_head - only if not tied
+    if not qwen2_config.get("tie_word_embeddings", False):
+        static_mappings.append(
+            (W("lm_head", "weight"), W("lm_head", "weight"))
+        )
+
+    for src, tgt in static_mappings:
+        mappings[tgt] = Ref(key=src)
+
+    # Layer mappings
+    for layer in range(num_layers):
+        # Source has "model.layers.{i}" prefix
+        qwen_layer = W("model", "layers", layer)
+        apriel_layer = W("model", "decoder", "blocks", layer)
+
+        # Attention projection weights (biases handled below)
+        # Qwen2 has QKV bias but no O bias
+        for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]:
+            src = qwen_layer / "self_attn" / proj / "weight"
+            tgt = apriel_layer / "mixer" / proj / "weight"
+            mappings[tgt] = Ref(key=src)
+
+        # QKV biases (Qwen2 has these, but not O bias)
+        for proj in ["q_proj", "k_proj", "v_proj"]:
+            src = qwen_layer / "self_attn" / proj / "bias"
+            tgt = apriel_layer / "mixer" / proj / "bias"
+            mappings[tgt] = Ref(key=src)
+
+        # O bias - Qwen2 doesn't have this, so initialize to zeros
+        # Shape is hidden_size (d_model)
+        mappings[apriel_layer / "mixer" / "o_proj" / "bias"] = Init(
+            shape=(hidden_size,),
+            init_type="zeros",
+        )
+
+        # MLP projections
+        for proj in ["gate_proj", "up_proj", "down_proj"]:
+            src = qwen_layer / "mlp" / proj / "weight"
+            tgt = apriel_layer / "mlp" / proj / "weight"
+            mappings[tgt] = Ref(key=src)
+
+        # Layer norms
+        mappings[apriel_layer / "input_layernorm" / "weight"] = Ref(
+            key=qwen_layer / "input_layernorm" / "weight"
+        )
+        mappings[apriel_layer / "post_attention_layernorm" / "weight"] = Ref(
+            key=qwen_layer / "post_attention_layernorm" / "weight"
+        )
+
+    return ExprPlan(
+        mappings=mappings,
+        source_format="qwen2",
+        target_format="apriel2",
+        metadata={
+            "num_layers": num_layers,
+            "hidden_size": qwen2_config["hidden_size"],
+            "num_attention_heads": qwen2_config["num_attention_heads"],
"num_key_value_heads": qwen2_config.get("num_key_value_heads", qwen2_config["num_attention_heads"]), + }, + ) diff --git a/fast_llm_external_models/apriel2/convert.py b/fast_llm_external_models/apriel2/convert.py index cbf921b31..05c38c7ce 100644 --- a/fast_llm_external_models/apriel2/convert.py +++ b/fast_llm_external_models/apriel2/convert.py @@ -15,6 +15,7 @@ Supported source formats: - llava: Llava/Pixtral models +- qwen2: Qwen2/Qwen2.5 models - apriel2: Apriel2 models (surgery-only mode - no conversion, just apply surgeries) """ @@ -46,6 +47,7 @@ # Import source-specific converters from fast_llm_external_models.apriel2.conversion import llava as llava_converter +from fast_llm_external_models.apriel2.conversion import qwen2 as qwen2_converter logger = logging.getLogger(__name__) @@ -73,6 +75,7 @@ def _identity_plan(config: dict) -> ExprPlan: # Each entry maps format name to (config_converter, plan_builder) SOURCE_FORMATS: dict[str, tuple[Callable[[dict], dict], Callable[[dict], ExprPlan]]] = { "llava": (llava_converter.convert_config, llava_converter.plan_llava_to_apriel2), + "qwen2": (qwen2_converter.convert_config, qwen2_converter.plan_qwen2_to_apriel2), "apriel2": (_identity_config, _identity_plan), } @@ -88,8 +91,12 @@ def detect_source_format(config: dict) -> str | None: if model_type in ("llava", "pixtral") or "text_config" in config: return "llava" + # Qwen2/Qwen2.5 detection + if model_type == "qwen2": + return "qwen2" + # Apriel2 detection - check for Apriel2-specific structure - if model_type == "apriel2" or "decoder" in config: + if model_type in ("apriel2", "apriel2_text") or "decoder" in config: return "apriel2" return None diff --git a/fast_llm_external_models/tests/test_apriel2/conftest.py b/fast_llm_external_models/tests/test_apriel2/conftest.py index 8585aec65..5c127d97e 100644 --- a/fast_llm_external_models/tests/test_apriel2/conftest.py +++ b/fast_llm_external_models/tests/test_apriel2/conftest.py @@ -7,6 +7,8 @@ import torch from transformers import LlavaConfig, LlavaForConditionalGeneration, MistralConfig +from fast_llm_external_models.apriel2.cache import _AttentionCache, _SSMCache + # Skip marker for tests that require CUDA for Mamba forward pass requires_cuda = pytest.mark.skipif( @@ -1532,3 +1534,188 @@ def torture_surgery_chain(): }, }, ] + + +# ============================================================================= +# Cache Test Fixtures - Tensor Dimensions +# ============================================================================= + + +@pytest.fixture +def batch_size(): + """Default batch size for cache tests.""" + return 2 + + +@pytest.fixture +def num_heads(): + """Default number of attention heads for cache tests.""" + return 4 + + +@pytest.fixture +def head_dim(): + """Default head dimension for cache tests.""" + return 16 + + +@pytest.fixture +def make_kv(batch_size, num_heads, head_dim): + """Factory fixture for creating KV tensors.""" + + def _make_kv(seq_len): + return ( + torch.randn(batch_size, num_heads, seq_len, head_dim), + torch.randn(batch_size, num_heads, seq_len, head_dim), + ) + + return _make_kv + + +# ============================================================================= +# Cache Test Fixtures - HuggingFace Cache Layers +# ============================================================================= + + +@pytest.fixture +def hf_dynamic_layer(): + """HuggingFace DynamicLayer for full attention contract testing.""" + from transformers.cache_utils import DynamicLayer + + return DynamicLayer() + + +@pytest.fixture +def 
hf_sliding_layer(window_size): + """HuggingFace DynamicSlidingWindowLayer for sliding window contract testing.""" + from transformers.cache_utils import DynamicSlidingWindowLayer + + return DynamicSlidingWindowLayer(sliding_window=window_size) + + +# ============================================================================= +# Cache Test Fixtures - Apriel2 Low-level Caches +# ============================================================================= + + +@pytest.fixture +def apriel_attention_cache(): + """Apriel2 attention cache without window (full attention).""" + return _AttentionCache(window=None) + + +@pytest.fixture +def apriel_sliding_cache(window_size): + """Apriel2 attention cache with sliding window.""" + return _AttentionCache(window=window_size) + + +@pytest.fixture +def ssm_cache(): + """Apriel2 SSM cache for Mamba/GDN/KDA layers.""" + return _SSMCache() + + +# ============================================================================= +# Cache Test Fixtures - Apriel2 Configs (Simple Versions) +# ============================================================================= + + +@pytest.fixture +def attention_config(): + """Pure attention config (2 layers, no sliding window).""" + from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config + + return Apriel2Config( + vocab_size=100, + hidden_size=64, + decoder={ + "type": "fixed", + "num_blocks": 2, + "block": { + "mixer": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16}, + "mlp": {"type": "mlp", "intermediate_size": 256}, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + ) + + +@pytest.fixture +def swa_config(): + """Sliding window attention config (2 layers, window=8).""" + from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config + + return Apriel2Config( + vocab_size=100, + hidden_size=64, + decoder={ + "type": "fixed", + "num_blocks": 2, + "block": { + "mixer": { + "type": "attention", + "heads": 4, + "head_groups": 2, + "head_size": 16, + "window_size": 8, + }, + "mlp": {"type": "mlp", "intermediate_size": 256}, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + ) + + +@pytest.fixture +def ssm_config(): + """Pure SSM config (2 layers).""" + from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config + + return Apriel2Config( + vocab_size=100, + hidden_size=64, + decoder={ + "type": "fixed", + "num_blocks": 2, + "block": { + "mixer": {"type": "mamba", "state_size": 16}, + "mlp": {"type": "mlp", "intermediate_size": 256}, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + ) + + +@pytest.fixture +def stochastic_config(): + """Stochastic mixer config with attention and mamba (2 layers).""" + from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config + + return Apriel2Config( + vocab_size=100, + hidden_size=64, + decoder={ + "type": "fixed", + "num_blocks": 2, + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16}, + "mamba": {"type": "mamba", "state_size": 16}, + }, + }, + "mlp": {"type": "mlp", "intermediate_size": 256}, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + ) + + +# Parameterized window size fixture (used by hf_sliding_layer and apriel_sliding_cache) +@pytest.fixture(params=[4, 8, 16, 32]) +def window_size(request): + """Parameterized window sizes for sliding window tests.""" + return request.param 
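Editor's note: to illustrate how the fixtures above are meant to compose, here is a minimal contract-style test sketch. It is hypothetical (the test name and assertions are not part of this patch), exercises only the Apriel2 side, and assumes the apriel_sliding_cache, make_kv, and window_size fixtures defined above; the HuggingFace-facing contract tests would use hf_dynamic_layer / hf_sliding_layer as the reference in the same way.

def test_sliding_cache_tracks_cumulative_length(apriel_sliding_cache, make_kv, window_size):
    # Prefill past the window, then decode a single token.
    key, value = make_kv(window_size + 2)
    apriel_sliding_cache.update(key, value)
    key, value = make_kv(1)
    apriel_sliding_cache.update(key, value)

    # Only `window_size` tokens are stored, but the cache reports every token it
    # has seen, which is what get_seq_length() and get_mask_sizes() rely on.
    assert apriel_sliding_cache.key.shape[-2] == window_size
    assert apriel_sliding_cache.cumulative_length == window_size + 3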
diff --git a/fast_llm_external_models/tests/test_apriel2/test_cache.py b/fast_llm_external_models/tests/test_apriel2/test_cache.py deleted file mode 100644 index ca8158b4f..000000000 --- a/fast_llm_external_models/tests/test_apriel2/test_cache.py +++ /dev/null @@ -1,1258 +0,0 @@ -"""Comprehensive tests for Apriel2Cache. - -Architecture Overview -===================== -Apriel2Cache manages state for autoregressive generation across different mixer types: - -1. **Attention Cache** (_AttentionCache): Stores key/value states - - Supports sliding window (window_size) for SWA - - Efficient roll optimization for single-token decode - -2. **SSM Cache** (_SSMCache): Stores conv and recurrent states - - Used by Mamba, GDN, KDA - - KDA uses tuple conv states (q, k, v), others use single tensor - -3. **Stochastic Mixer Routing**: For layers with multiple mixer options - - Each mixer has independent cache (no sharing) - - active_mixer pointer routes operations to correct sub-cache - - Switching mixers preserves each mixer's independent state - -Cache Invalidation Semantics -============================ -When switching between mixers in a stochastic layer: -- Each mixer maintains its OWN independent history -- Switching does NOT invalidate the previous mixer's cache -- Switching does NOT copy state between mixers -- To invalidate: call reset() explicitly - -This is intentional for training with stochastic sampling where each mixer -should learn from its own history. For inference, main_mixer_name is fixed. - -Test Organization -================= -1. CREATION & PROPERTIES - Cache initialization, config parsing -2. ATTENTION CACHE - Updates, sliding window, concatenation -3. SSM CACHE - Conv states, recurrent states, KDA tuples -4. STOCHASTIC ROUTING - Active mixer, isolation, switching -5. CACHE INVALIDATION - Reset, per-mixer reset, coherence -6. BEAM SEARCH - batch_repeat, reorder, select -7. HF INTEGRATION - get_mask_sizes, indexing, properties -8. GENERATION PATTERNS - Prefill→decode, crop→continue -9. 
ERROR HANDLING - Guards, bounds, invalid operations -""" - -import pytest -import torch - -from fast_llm_external_models.apriel2.cache import ( - Apriel2Cache, - _AttentionCache, - _SSMCache, -) - - -# ============================================================================= -# FIXTURES - Configs and Sample Data -# ============================================================================= - - -@pytest.fixture -def tiny_attention_config(): - """Minimal config with pure attention layers.""" - from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config - - return Apriel2Config( - vocab_size=100, - hidden_size=64, - decoder={ - "type": "fixed", - "num_blocks": 2, - "block": { - "mixer": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16}, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - ) - - -@pytest.fixture -def swa_config(): - """Config with sliding window attention.""" - from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config - - return Apriel2Config( - vocab_size=100, - hidden_size=64, - decoder={ - "type": "fixed", - "num_blocks": 2, - "block": { - "mixer": { - "type": "attention", - "heads": 4, - "head_groups": 2, - "head_size": 16, - "window_size": 8, # Small for testing - }, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - ) - - -@pytest.fixture -def ssm_config(): - """Config with pure SSM layers (mamba).""" - from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config - - return Apriel2Config( - vocab_size=100, - hidden_size=64, - decoder={ - "type": "fixed", - "num_blocks": 2, - "block": { - "mixer": { - "type": "mamba", - "d_inner": 128, - "d_state": 16, - "dt_rank": 4, - "d_conv": 4, - }, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - ) - - -@pytest.fixture -def kda_config(): - """Config with pure KDA layers.""" - from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config - - return Apriel2Config( - vocab_size=100, - hidden_size=64, - decoder={ - "type": "fixed", - "num_blocks": 2, - "block": { - "mixer": { - "type": "kda", - "heads": 4, - "head_dim": 16, - "convolution_layer": {"kernel_size": 4}, - "normalization": {"epsilon": 1e-5}, - }, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - ) - - -@pytest.fixture -def stochastic_config(): - """Config with stochastic mixer (attention + mamba).""" - from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config - - return Apriel2Config( - vocab_size=100, - hidden_size=64, - decoder={ - "type": "pattern", - "num_blocks": 2, - "pattern": ["attn", "stochastic"], - "blocks": { - "attn": { - "mixer": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16}, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - "stochastic": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": { - "attention": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16}, - "mamba": {"type": "mamba", "d_inner": 128, "d_state": 16, "dt_rank": 4, "d_conv": 4}, - }, - }, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - }, - ) - - -@pytest.fixture -def 
all_mixers_config(): - """Config with stochastic mixer containing all 5 mixer types.""" - from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config - - return Apriel2Config( - vocab_size=100, - hidden_size=64, - decoder={ - "type": "pattern", - "num_blocks": 2, - "pattern": ["attn", "all_mixers"], - "blocks": { - "attn": { - "mixer": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16}, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - "all_mixers": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": { - "attention": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16}, - "swa": { - "type": "attention", - "heads": 4, - "head_groups": 2, - "head_size": 16, - "window_size": 1024, - }, - "mamba": {"type": "mamba", "d_inner": 128, "d_state": 16, "dt_rank": 4, "d_conv": 4}, - "gdn": { - "type": "gdn", - "value_heads": 4, - "key_heads": 2, - "key_head_dim": 16, - "value_head_dim": 16, - "convolution_layer": {"kernel_size": 4}, - }, - "kda": { - "type": "kda", - "heads": 4, - "head_dim": 16, - "convolution_layer": {"kernel_size": 4}, - "normalization": {"epsilon": 1e-5}, - }, - }, - }, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - }, - ) - - -@pytest.fixture -def multi_window_config(): - """Config with multiple different window sizes.""" - from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config - - return Apriel2Config( - vocab_size=100, - hidden_size=64, - decoder={ - "type": "pattern", - "num_blocks": 3, - "pattern": ["full", "small_window", "large_window"], - "blocks": { - "full": { - "mixer": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16}, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - "small_window": { - "mixer": { - "type": "attention", - "heads": 4, - "head_groups": 2, - "head_size": 16, - "window_size": 512, - }, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - "large_window": { - "mixer": { - "type": "attention", - "heads": 4, - "head_groups": 2, - "head_size": 16, - "window_size": 2048, - }, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - }, - ) - - -@pytest.fixture -def sample_kv(): - """Sample key/value tensors: [batch=2, heads=4, seq=10, head_dim=16].""" - return torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16) - - -@pytest.fixture -def sample_conv_single(): - """Sample single-tensor conv state: [batch=2, d_inner=128, kernel=4].""" - return torch.randn(2, 128, 4) - - -@pytest.fixture -def sample_conv_tuple(): - """Sample tuple conv state for KDA: (q, k, v) each [batch=2, d=64, kernel=3].""" - return (torch.randn(2, 64, 3), torch.randn(2, 64, 3), torch.randn(2, 64, 3)) - - -@pytest.fixture -def sample_recurrent(): - """Sample recurrent state: [batch=2, heads=4, head_dim=16, d_state=16].""" - return torch.randn(2, 4, 16, 16) - - -# ============================================================================= -# SECTION 1: CACHE CREATION & PROPERTIES -# ============================================================================= - - -class TestCacheCreation: - """Test cache initialization from config.""" - - def test_attention_cache_creation(self, tiny_attention_config): - """Create cache for pure 
attention config.""" - cache = Apriel2Cache(tiny_attention_config) - - assert len(cache) == 2 - assert cache.mixer_types == ["attention", "attention"] - assert all(isinstance(l, _AttentionCache) for l in cache.layers) - - def test_ssm_cache_creation(self, ssm_config): - """Create cache for pure SSM config.""" - cache = Apriel2Cache(ssm_config) - - assert len(cache) == 2 - assert cache.mixer_types == ["mamba", "mamba"] - assert all(isinstance(l, _SSMCache) for l in cache.layers) - - def test_kda_cache_creation(self, kda_config): - """Create cache for pure KDA config.""" - cache = Apriel2Cache(kda_config) - - assert len(cache) == 2 - assert cache.mixer_types == ["kda", "kda"] - assert all(isinstance(l, _SSMCache) for l in cache.layers) - - def test_stochastic_cache_creation(self, stochastic_config): - """Create cache for stochastic mixer config.""" - cache = Apriel2Cache(stochastic_config) - - assert len(cache) == 2 - # Layer 0: pure attention, Layer 1: stochastic (dict) - assert isinstance(cache.layers[0], _AttentionCache) - assert isinstance(cache.layers[1], dict) - assert set(cache.layers[1].keys()) == {"attention", "mamba"} - - def test_swa_window_captured(self, swa_config): - """Verify sliding window size is captured.""" - cache = Apriel2Cache(swa_config) - - assert cache.layers[0].window == 8 - assert cache.is_sliding == [True, True] - - def test_active_mixers_initialized_none(self, stochastic_config): - """Verify active_mixers starts as None for all layers.""" - cache = Apriel2Cache(stochastic_config) - - assert cache.active_mixers == [None, None] - - -class TestCacheProperties: - """Test cache property accessors.""" - - def test_empty_cache_properties(self, tiny_attention_config): - """Test properties of uninitialized cache.""" - cache = Apriel2Cache(tiny_attention_config) - - assert cache.is_initialized == False - assert cache.has_previous_state == False - assert cache.max_batch_size is None - assert cache.max_cache_len is None - assert cache.is_compileable == False - - def test_is_initialized_attention(self, tiny_attention_config, sample_kv): - """is_initialized detects attention cache.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - - assert cache.is_initialized == True - - def test_is_initialized_ssm(self, ssm_config, sample_conv_single): - """is_initialized detects SSM cache.""" - cache = Apriel2Cache(ssm_config) - cache.conv_states[0] = sample_conv_single - - assert cache.is_initialized == True - - def test_has_previous_state_ssm_only(self, ssm_config, sample_conv_single): - """has_previous_state only looks at SSM conv states.""" - cache = Apriel2Cache(ssm_config) - - assert cache.has_previous_state == False - cache.conv_states[0] = sample_conv_single - assert cache.has_previous_state == True - - def test_has_previous_state_ignores_attention(self, tiny_attention_config, sample_kv): - """has_previous_state ignores attention cache.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - - # Attention cache is set, but has_previous_state only checks SSM - assert cache.has_previous_state == False - - def test_max_batch_size_from_attention(self, tiny_attention_config, sample_kv): - """max_batch_size from attention cache.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - - assert cache.max_batch_size == 2 - - def test_max_batch_size_from_ssm(self, ssm_config, sample_conv_single): - """max_batch_size from SSM cache.""" - cache = Apriel2Cache(ssm_config) - 
cache.conv_states[0] = sample_conv_single - - assert cache.max_batch_size == 2 - - def test_max_batch_size_from_kda_tuple(self, kda_config, sample_conv_tuple): - """max_batch_size from KDA tuple conv state.""" - cache = Apriel2Cache(kda_config) - cache.conv_states[0] = sample_conv_tuple - - assert cache.max_batch_size == 2 - - def test_max_cache_len_single_window(self, swa_config): - """max_cache_len with single window size.""" - cache = Apriel2Cache(swa_config) - assert cache.max_cache_len == 8 - - def test_max_cache_len_multiple_windows(self, multi_window_config): - """max_cache_len returns minimum window.""" - cache = Apriel2Cache(multi_window_config) - assert cache.max_cache_len == 512 # min(512, 2048) - - def test_max_cache_len_no_windows(self, tiny_attention_config): - """max_cache_len is None when no windows.""" - cache = Apriel2Cache(tiny_attention_config) - assert cache.max_cache_len is None - - def test_is_sliding_mixed(self, multi_window_config): - """is_sliding reflects per-layer window presence.""" - cache = Apriel2Cache(multi_window_config) - assert cache.is_sliding == [False, True, True] - - -# ============================================================================= -# SECTION 2: ATTENTION CACHE OPERATIONS -# ============================================================================= - - -class TestAttentionCacheBasics: - """Test basic attention cache operations.""" - - def test_update_stores_kv(self, tiny_attention_config, sample_kv): - """update() stores key/value states.""" - cache = Apriel2Cache(tiny_attention_config) - key, value = sample_kv - - k_out, v_out = cache.update(key, value, layer_idx=0) - - torch.testing.assert_close(k_out, key) - torch.testing.assert_close(v_out, value) - assert cache.get_seq_length(0) == 10 - - def test_update_concatenates(self, tiny_attention_config, sample_kv): - """Subsequent updates concatenate.""" - cache = Apriel2Cache(tiny_attention_config) - key, value = sample_kv - - cache.update(key, value, layer_idx=0) - k_out, v_out = cache.update(key, value, layer_idx=0) - - assert k_out.shape[-2] == 20 - assert cache.get_seq_length(0) == 20 - - def test_key_value_cache_accessors(self, tiny_attention_config, sample_kv): - """Test key_cache and value_cache accessors.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - - assert cache.key_cache[0] is not None - assert cache.value_cache[0] is not None - torch.testing.assert_close(cache.key_cache[0], sample_kv[0]) - - -class TestSlidingWindowAttention: - """Test sliding window attention behavior.""" - - def test_initial_within_window(self, swa_config): - """Initial sequence within window is kept.""" - cache = Apriel2Cache(swa_config) - key = torch.randn(2, 4, 5, 16) # seq=5 < window=8 - value = torch.randn(2, 4, 5, 16) - - cache.update(key, value, layer_idx=0) - - assert cache.get_seq_length(0) == 5 - - def test_initial_exceeds_window(self, swa_config): - """Initial sequence > window is truncated to last window tokens.""" - cache = Apriel2Cache(swa_config) - key = torch.arange(12).float().view(1, 1, 12, 1).expand(2, 4, 12, 16) - value = key.clone() - - k_out, v_out = cache.update(key, value, layer_idx=0) - - assert cache.get_seq_length(0) == 8 - # Should keep tokens 4-11 (last 8) - assert k_out[0, 0, 0, 0].item() == 4.0 - - def test_single_token_roll_path(self, swa_config): - """Single token decode with full window uses efficient roll.""" - cache = Apriel2Cache(swa_config) - - # Fill window exactly - key1 = torch.arange(8).float().view(1, 1, 8, 
1).expand(2, 4, 8, 16) - cache.update(key1, key1.clone(), layer_idx=0) - - # Decode single token - key2 = torch.full((2, 4, 1, 16), 8.0) - k_out, _ = cache.update(key2, key2.clone(), layer_idx=0) - - assert cache.get_seq_length(0) == 8 - assert k_out[0, 0, 0, 0].item() == 1.0 # Token 0 rolled out - assert k_out[0, 0, 7, 0].item() == 8.0 # New token at end - - def test_multi_token_cat_slice_path(self, swa_config): - """Multiple tokens use cat+slice path.""" - cache = Apriel2Cache(swa_config) - - # Fill window - key1 = torch.randn(2, 4, 8, 16) - cache.update(key1, key1.clone(), layer_idx=0) - - # Add 3 tokens - key2 = torch.randn(2, 4, 3, 16) - k_out, _ = cache.update(key2, key2.clone(), layer_idx=0) - - assert cache.get_seq_length(0) == 8 - torch.testing.assert_close(k_out[..., -3:, :], key2) - - def test_partial_then_fill_then_overflow(self, swa_config): - """Progressive filling: partial → full → overflow.""" - cache = Apriel2Cache(swa_config) - - cache.update(torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16), layer_idx=0) - assert cache.get_seq_length(0) == 5 - - cache.update(torch.randn(2, 4, 3, 16), torch.randn(2, 4, 3, 16), layer_idx=0) - assert cache.get_seq_length(0) == 8 - - cache.update(torch.randn(2, 4, 2, 16), torch.randn(2, 4, 2, 16), layer_idx=0) - assert cache.get_seq_length(0) == 8 - - def test_contiguous_output(self, swa_config): - """Outputs are contiguous after windowing.""" - cache = Apriel2Cache(swa_config) - - cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0) - cache.update(torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16), layer_idx=0) - - assert cache.layers[0].key.is_contiguous() - assert cache.layers[0].value.is_contiguous() - - -# ============================================================================= -# SECTION 3: SSM CACHE OPERATIONS -# ============================================================================= - - -class TestSSMCacheBasics: - """Test basic SSM cache operations.""" - - def test_conv_states_accessor(self, ssm_config, sample_conv_single): - """Test conv_states accessor.""" - cache = Apriel2Cache(ssm_config) - - cache.conv_states[0] = sample_conv_single - torch.testing.assert_close(cache.conv_states[0], sample_conv_single) - - def test_recurrent_states_accessor(self, ssm_config, sample_recurrent): - """Test recurrent_states accessor.""" - cache = Apriel2Cache(ssm_config) - - cache.recurrent_states[0] = sample_recurrent - torch.testing.assert_close(cache.recurrent_states[0], sample_recurrent) - - def test_ssm_seq_length_always_zero(self, ssm_config, sample_conv_single): - """get_seq_length returns 0 for SSM (no KV cache).""" - cache = Apriel2Cache(ssm_config) - cache.conv_states[0] = sample_conv_single - - assert cache.get_seq_length(0) == 0 - - -class TestKDACache: - """Test KDA-specific cache operations with tuple conv states.""" - - def test_tuple_conv_storage(self, kda_config, sample_conv_tuple): - """KDA stores tuple conv states.""" - cache = Apriel2Cache(kda_config) - - cache.conv_states[0] = sample_conv_tuple - - assert isinstance(cache.conv_states[0], tuple) - assert len(cache.conv_states[0]) == 3 - for i in range(3): - torch.testing.assert_close(cache.conv_states[0][i], sample_conv_tuple[i]) - - def test_tuple_with_recurrent(self, kda_config, sample_conv_tuple, sample_recurrent): - """KDA can have both tuple conv and recurrent states.""" - cache = Apriel2Cache(kda_config) - - cache.conv_states[0] = sample_conv_tuple - cache.recurrent_states[0] = sample_recurrent - - assert 
isinstance(cache.conv_states[0], tuple) - assert cache.recurrent_states[0] is not None - - def test_has_previous_state_detects_tuple(self, kda_config, sample_conv_tuple): - """has_previous_state works with tuple conv states.""" - cache = Apriel2Cache(kda_config) - - assert cache.has_previous_state == False - cache.conv_states[0] = sample_conv_tuple - assert cache.has_previous_state == True - - -# ============================================================================= -# SECTION 4: STOCHASTIC ROUTING -# ============================================================================= - - -class TestStochasticRouting: - """Test stochastic mixer cache routing.""" - - def test_set_active_mixer(self, stochastic_config): - """set_active_mixer sets the pointer.""" - cache = Apriel2Cache(stochastic_config) - - cache.set_active_mixer(1, "attention") - assert cache.active_mixers[1] == "attention" - - cache.set_active_mixer(1, "mamba") - assert cache.active_mixers[1] == "mamba" - - def test_operations_route_to_active(self, stochastic_config, sample_kv): - """Operations route to currently active mixer.""" - cache = Apriel2Cache(stochastic_config) - - cache.set_active_mixer(1, "attention") - cache.update(*sample_kv, layer_idx=1) - attn_len = cache.get_seq_length(1) - - cache.set_active_mixer(1, "mamba") - mamba_len = cache.get_seq_length(1) - - assert attn_len == 10 - assert mamba_len == 0 # Mamba cache is separate and empty - - def test_each_mixer_independent_cache(self, stochastic_config, sample_kv, sample_conv_single): - """Each mixer maintains independent cache.""" - cache = Apriel2Cache(stochastic_config) - - # Fill attention cache - cache.set_active_mixer(1, "attention") - cache.update(*sample_kv, layer_idx=1) - - # Fill mamba cache - cache.set_active_mixer(1, "mamba") - cache.conv_states[1] = sample_conv_single - - # Both preserved - cache.set_active_mixer(1, "attention") - assert cache.get_seq_length(1) == 10 - - cache.set_active_mixer(1, "mamba") - torch.testing.assert_close(cache.conv_states[1], sample_conv_single) - - -class TestMixerSwitching: - """Test behavior when switching between mixers mid-generation.""" - - def test_switch_preserves_previous_state(self, stochastic_config, sample_kv): - """Switching mixers preserves previous mixer's state.""" - cache = Apriel2Cache(stochastic_config) - - cache.set_active_mixer(1, "attention") - cache.update(*sample_kv, layer_idx=1) - original_key = cache.layers[1]["attention"].key.clone() - - # Switch to mamba, do something - cache.set_active_mixer(1, "mamba") - cache.conv_states[1] = torch.randn(2, 128, 4) - - # Switch back - attention unchanged - cache.set_active_mixer(1, "attention") - torch.testing.assert_close(cache.layers[1]["attention"].key, original_key) - - def test_switch_does_not_copy_state(self, stochastic_config, sample_kv): - """Switching does NOT copy state between mixers.""" - cache = Apriel2Cache(stochastic_config) - - # Fill attention with 10 tokens - cache.set_active_mixer(1, "attention") - cache.update(*sample_kv, layer_idx=1) - - # Switch to mamba - it has NO history from attention - cache.set_active_mixer(1, "mamba") - assert cache.conv_states[1] is None - assert cache.recurrent_states[1] is None - - def test_has_previous_state_checks_all_sub_caches(self, stochastic_config): - """has_previous_state checks ALL sub-caches, not just active.""" - cache = Apriel2Cache(stochastic_config) - - cache.set_active_mixer(1, "mamba") - cache.conv_states[1] = torch.randn(2, 128, 4) - - # Even if we switch away, has_previous_state still 
detects it - cache.set_active_mixer(1, "attention") - assert cache.has_previous_state == True - - -class TestAllMixerTypes: - """Test cache isolation across all 5 mixer types.""" - - def test_all_five_mixer_types_isolated(self, all_mixers_config): - """All 5 mixer types maintain isolated caches.""" - cache = Apriel2Cache(all_mixers_config) - layer_idx = 1 # Stochastic layer - - # Fill each mixer's cache - cache.set_active_mixer(layer_idx, "attention") - attn_kv = (torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16)) - cache.update(*attn_kv, layer_idx=layer_idx) - - cache.set_active_mixer(layer_idx, "swa") - swa_kv = (torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16)) - cache.update(*swa_kv, layer_idx=layer_idx) - - cache.set_active_mixer(layer_idx, "mamba") - mamba_conv = torch.randn(2, 128, 4) - cache.conv_states[layer_idx] = mamba_conv - - cache.set_active_mixer(layer_idx, "gdn") - gdn_conv = torch.randn(2, 64, 3) - cache.conv_states[layer_idx] = gdn_conv - - cache.set_active_mixer(layer_idx, "kda") - kda_conv = (torch.randn(2, 64, 3), torch.randn(2, 64, 3), torch.randn(2, 64, 3)) - cache.conv_states[layer_idx] = kda_conv - - # Verify all preserved - cache.set_active_mixer(layer_idx, "attention") - assert cache.get_seq_length(layer_idx) == 10 - - cache.set_active_mixer(layer_idx, "swa") - assert cache.get_seq_length(layer_idx) == 5 - - cache.set_active_mixer(layer_idx, "mamba") - torch.testing.assert_close(cache.conv_states[layer_idx], mamba_conv) - - cache.set_active_mixer(layer_idx, "gdn") - torch.testing.assert_close(cache.conv_states[layer_idx], gdn_conv) - - cache.set_active_mixer(layer_idx, "kda") - assert isinstance(cache.conv_states[layer_idx], tuple) - - -# ============================================================================= -# SECTION 5: CACHE INVALIDATION -# ============================================================================= - - -class TestCacheInvalidation: - """Test cache invalidation and reset semantics. - - Key principle: Each mixer maintains independent state. 
To invalidate: - - reset() clears ALL caches across ALL layers and mixers - - There is no per-mixer reset (by design - each mixer is independent) - """ - - def test_reset_clears_attention(self, tiny_attention_config, sample_kv): - """reset() clears attention cache.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - - cache.reset() - - assert cache.is_initialized == False - assert cache.get_seq_length(0) == 0 - - def test_reset_clears_ssm(self, ssm_config, sample_conv_single, sample_recurrent): - """reset() clears SSM cache.""" - cache = Apriel2Cache(ssm_config) - cache.conv_states[0] = sample_conv_single - cache.recurrent_states[0] = sample_recurrent - - cache.reset() - - assert cache.has_previous_state == False - assert cache.conv_states[0] is None - assert cache.recurrent_states[0] is None - - def test_reset_clears_kda_tuple(self, kda_config, sample_conv_tuple): - """reset() clears KDA tuple conv states.""" - cache = Apriel2Cache(kda_config) - cache.conv_states[0] = sample_conv_tuple - - cache.reset() - - assert cache.conv_states[0] is None - - def test_reset_clears_all_stochastic_mixers(self, all_mixers_config): - """reset() clears ALL mixer caches in stochastic layer.""" - cache = Apriel2Cache(all_mixers_config) - layer_idx = 1 - - # Fill all mixers - cache.set_active_mixer(layer_idx, "attention") - cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=layer_idx) - - cache.set_active_mixer(layer_idx, "mamba") - cache.conv_states[layer_idx] = torch.randn(2, 128, 4) - - cache.set_active_mixer(layer_idx, "kda") - cache.conv_states[layer_idx] = (torch.randn(2, 64, 3),) * 3 - - cache.reset() - - # All cleared - assert cache.layers[layer_idx]["attention"].key is None - assert cache.layers[layer_idx]["mamba"].conv is None - assert cache.layers[layer_idx]["kda"].conv is None - - def test_crop_truncates_attention(self, tiny_attention_config, sample_kv): - """crop() truncates attention cache to max_length.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - - cache.crop(5) - - assert cache.get_seq_length(0) == 5 - - def test_crop_affects_all_layers(self, tiny_attention_config, sample_kv): - """crop() affects all layers.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - cache.update(*sample_kv, layer_idx=1) - - cache.crop(3) - - assert cache.get_seq_length(0) == 3 - assert cache.get_seq_length(1) == 3 - - def test_crop_ignores_ssm(self, ssm_config, sample_conv_single): - """crop() only affects attention, not SSM.""" - cache = Apriel2Cache(ssm_config) - cache.conv_states[0] = sample_conv_single - - cache.crop(5) # Should not crash - - # Conv state unchanged - torch.testing.assert_close(cache.conv_states[0], sample_conv_single) - - -# ============================================================================= -# SECTION 6: BEAM SEARCH OPERATIONS -# ============================================================================= - - -class TestBatchRepeatInterleave: - """Test batch_repeat_interleave for beam search expansion.""" - - def test_repeat_attention(self, tiny_attention_config, sample_kv): - """Repeat attention cache for beam search.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - - cache.batch_repeat_interleave(3) - - assert cache.max_batch_size == 6 # 2 * 3 - - def test_repeat_ssm(self, ssm_config, sample_conv_single, sample_recurrent): - """Repeat SSM cache for beam search.""" - cache = 
Apriel2Cache(ssm_config) - cache.conv_states[0] = sample_conv_single - cache.recurrent_states[0] = sample_recurrent - - cache.batch_repeat_interleave(4) - - assert cache.conv_states[0].shape[0] == 8 # 2 * 4 - assert cache.recurrent_states[0].shape[0] == 8 - - def test_repeat_kda_tuple(self, kda_config, sample_conv_tuple): - """Repeat KDA tuple conv states.""" - cache = Apriel2Cache(kda_config) - cache.conv_states[0] = sample_conv_tuple - - cache.batch_repeat_interleave(3) - - for c in cache.conv_states[0]: - assert c.shape[0] == 6 - - def test_repeat_stochastic_all_mixers(self, all_mixers_config): - """Repeat all mixer caches in stochastic layer.""" - cache = Apriel2Cache(all_mixers_config) - layer_idx = 1 - - cache.set_active_mixer(layer_idx, "attention") - cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=layer_idx) - - cache.set_active_mixer(layer_idx, "mamba") - cache.conv_states[layer_idx] = torch.randn(2, 128, 4) - - cache.batch_repeat_interleave(2) - - cache.set_active_mixer(layer_idx, "attention") - assert cache.layers[layer_idx]["attention"].key.shape[0] == 4 - - cache.set_active_mixer(layer_idx, "mamba") - assert cache.conv_states[layer_idx].shape[0] == 4 - - def test_repeat_skips_none(self, tiny_attention_config): - """Repeat gracefully skips None caches.""" - cache = Apriel2Cache(tiny_attention_config) - # Don't fill anything - - cache.batch_repeat_interleave(3) # Should not crash - - assert cache.max_batch_size is None - - -class TestReorderCache: - """Test reorder_cache for beam search hypothesis selection.""" - - def test_reorder_attention(self, tiny_attention_config, sample_kv): - """Reorder attention cache.""" - cache = Apriel2Cache(tiny_attention_config) - key, value = sample_kv - # Make batches distinguishable - key = torch.arange(2).float().view(2, 1, 1, 1).expand(2, 4, 10, 16) - cache.update(key, key.clone(), layer_idx=0) - - beam_idx = torch.tensor([1, 0]) - cache.reorder_cache(beam_idx) - - assert cache.layers[0].key[0, 0, 0, 0].item() == 1.0 - assert cache.layers[0].key[1, 0, 0, 0].item() == 0.0 - - def test_reorder_ssm(self, ssm_config): - """Reorder SSM cache.""" - cache = Apriel2Cache(ssm_config) - conv = torch.arange(2).float().view(2, 1, 1).expand(2, 128, 4) - cache.conv_states[0] = conv.clone() - - beam_idx = torch.tensor([1, 0]) - cache.reorder_cache(beam_idx) - - assert cache.conv_states[0][0, 0, 0].item() == 1.0 - - def test_reorder_kda_tuple(self, kda_config): - """Reorder KDA tuple conv states.""" - cache = Apriel2Cache(kda_config) - conv_q = torch.arange(2).float().view(2, 1, 1).expand(2, 64, 3) - cache.conv_states[0] = (conv_q.clone(), conv_q.clone(), conv_q.clone()) - - beam_idx = torch.tensor([1, 0]) - cache.reorder_cache(beam_idx) - - for c in cache.conv_states[0]: - assert c[0, 0, 0].item() == 1.0 - - -class TestBatchSelectIndices: - """Test batch_select_indices for beam selection.""" - - def test_select_attention(self, tiny_attention_config, sample_kv): - """Select subset of attention cache.""" - cache = Apriel2Cache(tiny_attention_config) - key = torch.arange(4).float().view(4, 1, 1, 1).expand(4, 4, 10, 16) - cache.update(key, key.clone(), layer_idx=0) - - indices = torch.tensor([0, 3]) - cache.batch_select_indices(indices) - - assert cache.max_batch_size == 2 - assert cache.layers[0].key[0, 0, 0, 0].item() == 0.0 - assert cache.layers[0].key[1, 0, 0, 0].item() == 3.0 - - def test_select_kda_tuple(self, kda_config): - """Select subset of KDA tuple conv states.""" - cache = Apriel2Cache(kda_config) - conv = 
tuple(torch.arange(4).float().view(4, 1, 1).expand(4, 64, 3).clone() for _ in range(3)) - cache.conv_states[0] = conv - - indices = torch.tensor([1, 2]) - cache.batch_select_indices(indices) - - for c in cache.conv_states[0]: - assert c.shape[0] == 2 - assert c[0, 0, 0].item() == 1.0 - - -# ============================================================================= -# SECTION 7: HUGGINGFACE INTEGRATION -# ============================================================================= - - -class TestGetMaskSizes: - """Test get_mask_sizes() for attention mask computation.""" - - def test_empty_cache(self, tiny_attention_config): - """Mask sizes with empty cache.""" - cache = Apriel2Cache(tiny_attention_config) - cache_position = torch.arange(10) - - kv_length, kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0) - - assert kv_length == 10 - assert kv_offset == 0 - - def test_with_cached_tokens(self, tiny_attention_config, sample_kv): - """Mask sizes with cached tokens.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) # 10 tokens - - cache_position = torch.arange(5) - kv_length, kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0) - - assert kv_length == 15 # 10 + 5 - assert kv_offset == 10 - - def test_single_token_decode(self, tiny_attention_config, sample_kv): - """Mask sizes for single token decode.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - - cache_position = torch.arange(1) - kv_length, kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0) - - assert kv_length == 11 - assert kv_offset == 10 - - def test_ssm_returns_query_only(self, ssm_config, sample_conv_single): - """SSM layers return query_length (no KV cache).""" - cache = Apriel2Cache(ssm_config) - cache.conv_states[0] = sample_conv_single - - cache_position = torch.arange(5) - kv_length, kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0) - - assert kv_length == 5 - assert kv_offset == 0 - - -class TestCacheIndexing: - """Test cache[idx] indexing.""" - - def test_attention_returns_kv(self, tiny_attention_config, sample_kv): - """Indexing attention layer returns (key, value).""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - - result = cache[0] - - assert isinstance(result, tuple) - torch.testing.assert_close(result[0], sample_kv[0]) - - def test_empty_returns_empty_tensors(self, tiny_attention_config): - """Indexing empty layer returns empty tensors.""" - cache = Apriel2Cache(tiny_attention_config) - - result = cache[0] - - assert result[0].numel() == 0 - assert result[1].numel() == 0 - - def test_ssm_returns_empty(self, ssm_config, sample_conv_single): - """Indexing SSM layer returns empty (no KV).""" - cache = Apriel2Cache(ssm_config) - cache.conv_states[0] = sample_conv_single - - result = cache[0] - - assert result[0].numel() == 0 - - def test_stochastic_attention_returns_kv(self, stochastic_config, sample_kv): - """Indexing stochastic with attention active returns KV.""" - cache = Apriel2Cache(stochastic_config) - cache.set_active_mixer(1, "attention") - cache.update(*sample_kv, layer_idx=1) - - result = cache[1] - - torch.testing.assert_close(result[0], sample_kv[0]) - - -# ============================================================================= -# SECTION 8: GENERATION PATTERNS -# ============================================================================= - - -class TestGenerationPatterns: - """Test real-world generation patterns.""" - - def 
test_prefill_then_decode(self, tiny_attention_config, sample_kv): - """Prefill with long prompt, then decode token-by-token.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) # Prefill 10 tokens - - for _ in range(5): - new_kv = (torch.randn(2, 4, 1, 16), torch.randn(2, 4, 1, 16)) - cache.update(*new_kv, layer_idx=0) - - assert cache.get_seq_length(0) == 15 - - def test_crop_then_continue(self, tiny_attention_config, sample_kv): - """Crop old context, continue generation.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - cache.update(*sample_kv, layer_idx=0) # 20 tokens - - cache.crop(5) # Keep last 5 - cache.update(torch.randn(2, 4, 3, 16), torch.randn(2, 4, 3, 16), layer_idx=0) - - assert cache.get_seq_length(0) == 8 - - def test_reset_between_generations(self, tiny_attention_config, sample_kv): - """Reset between independent generations.""" - cache = Apriel2Cache(tiny_attention_config) - - # First generation - cache.update(*sample_kv, layer_idx=0) - assert cache.is_initialized == True - - # Reset - cache.reset() - assert cache.is_initialized == False - - # Second generation - cache.update(*sample_kv, layer_idx=0) - assert cache.get_seq_length(0) == 10 - - def test_multi_layer_consistency(self, tiny_attention_config, sample_kv): - """All layers updated consistently.""" - cache = Apriel2Cache(tiny_attention_config) - - for layer_idx in range(2): - cache.update(*sample_kv, layer_idx=layer_idx) - cache.update(torch.randn(2, 4, 1, 16), torch.randn(2, 4, 1, 16), layer_idx=layer_idx) - - for layer_idx in range(2): - assert cache.get_seq_length(layer_idx) == 11 - - -# ============================================================================= -# SECTION 9: ERROR HANDLING -# ============================================================================= - - -class TestErrorHandling: - """Test error conditions and guards.""" - - def test_stochastic_update_without_active_mixer(self, stochastic_config): - """update() on stochastic without active_mixer raises.""" - cache = Apriel2Cache(stochastic_config) - - with pytest.raises(RuntimeError, match="needs active_mixer set"): - cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=1) - - def test_stochastic_accessor_without_active_mixer(self, stochastic_config): - """Accessing stochastic cache without active_mixer raises.""" - cache = Apriel2Cache(stochastic_config) - - with pytest.raises(RuntimeError, match="requires set_active_mixer"): - _ = cache.conv_states[1] - - def test_accessor_error_lists_available_mixers(self, stochastic_config): - """Error message lists available mixers.""" - cache = Apriel2Cache(stochastic_config) - - with pytest.raises(RuntimeError, match="Available mixers:"): - _ = cache.key_cache[1] - - def test_invalid_mixer_name(self, stochastic_config): - """Invalid mixer name raises KeyError on access.""" - cache = Apriel2Cache(stochastic_config) - cache.set_active_mixer(1, "nonexistent") - - with pytest.raises(KeyError): - cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=1) - - def test_layer_idx_out_of_bounds(self, tiny_attention_config): - """Out-of-bounds layer_idx raises IndexError.""" - cache = Apriel2Cache(tiny_attention_config) - - with pytest.raises(IndexError): - cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=999) - - -# ============================================================================= -# SECTION 10: INTERNAL CLASSES -# 
============================================================================= - - -class TestAttentionCacheInternal: - """Test internal _AttentionCache class directly.""" - - def test_unbounded_growth(self): - """No window allows unbounded growth.""" - cache = _AttentionCache(window=None) - - for _ in range(10): - cache.update(torch.randn(2, 4, 100, 16), torch.randn(2, 4, 100, 16)) - - assert cache.key.shape[-2] == 1000 - - def test_window_enforced(self): - """Window caps cache size.""" - cache = _AttentionCache(window=50) - - for _ in range(10): - cache.update(torch.randn(2, 4, 100, 16), torch.randn(2, 4, 100, 16)) - - assert cache.key.shape[-2] == 50 - - -class TestSSMCacheInternal: - """Test internal _SSMCache class directly.""" - - def test_initial_none(self): - """Initial states are None.""" - cache = _SSMCache() - - assert cache.conv is None - assert cache.recurrent is None - - def test_stores_tuple(self): - """Can store tuple (for KDA).""" - cache = _SSMCache() - cache.conv = (torch.randn(2, 64, 3),) * 3 - - assert isinstance(cache.conv, tuple) diff --git a/fast_llm_external_models/tests/test_apriel2/test_cache_apriel2_specific.py b/fast_llm_external_models/tests/test_apriel2/test_cache_apriel2_specific.py new file mode 100644 index 000000000..e0e4db2d3 --- /dev/null +++ b/fast_llm_external_models/tests/test_apriel2/test_cache_apriel2_specific.py @@ -0,0 +1,342 @@ +"""Tests for Apriel2-specific cache behaviors with no HuggingFace equivalent. + +This module tests features unique to Apriel2Cache that cannot be validated +against upstream HF implementations: + +1. Stochastic mixer routing (switching between attention/SSM per layer) +2. Multi-mixer layer support +3. Error handling and guard rails +4. Beam search operations (batch_repeat, reorder, select) +5. 
Crop operation + +Fixtures used from conftest.py: + - stochastic_config: Stochastic mixer config with attention and mamba + - attention_config: Pure attention config + - ssm_config: Pure SSM config +""" + +import pytest +import torch + +from fast_llm_external_models.apriel2.cache import Apriel2Cache, _AttentionCache, _SSMCache + + +# ============================================================================= +# STOCHASTIC MIXER ROUTING +# ============================================================================= + + +class TestStochasticMixerRouting: + """Test routing operations to correct sub-cache in stochastic layers.""" + + def test_set_active_mixer(self, stochastic_config): + """set_active_mixer updates routing for layer.""" + cache = Apriel2Cache(stochastic_config) + + cache.set_active_mixer(0, "attention") + assert cache.active_mixers[0] == "attention" + + cache.set_active_mixer(0, "mamba") + assert cache.active_mixers[0] == "mamba" + + def test_update_routes_to_active_mixer(self, stochastic_config): + """update() stores in correct sub-cache based on active_mixer.""" + cache = Apriel2Cache(stochastic_config) + + # Route to attention + cache.set_active_mixer(0, "attention") + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0) + + # Attention sub-cache should have data + assert cache.layers[0]["attention"].key is not None + # Mamba sub-cache should be empty + assert cache.layers[0]["mamba"].conv is None + + def test_each_mixer_has_independent_cache(self, stochastic_config): + """Each mixer in a stochastic layer has its own independent state.""" + cache = Apriel2Cache(stochastic_config) + + # Store in attention + cache.set_active_mixer(0, "attention") + cache.update(torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16), layer_idx=0) + + # Switch to mamba and store + cache.set_active_mixer(0, "mamba") + cache.layers[0]["mamba"].conv = torch.randn(2, 64, 4) + + # Attention data should be unchanged + assert cache.layers[0]["attention"].cumulative_length == 5 + + def test_switching_preserves_all_states(self, stochastic_config): + """Switching active_mixer doesn't clear other mixer's state.""" + cache = Apriel2Cache(stochastic_config) + + # Build up attention state + cache.set_active_mixer(0, "attention") + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0) + attn_key = cache.layers[0]["attention"].key.clone() + + # Switch to mamba + cache.set_active_mixer(0, "mamba") + + # Attention state preserved + torch.testing.assert_close(cache.layers[0]["attention"].key, attn_key) + + +# ============================================================================= +# ERROR HANDLING +# ============================================================================= + + +class TestErrorHandling: + """Test guard rails and error messages.""" + + def test_update_without_active_mixer_raises(self, stochastic_config): + """update() on stochastic layer without active_mixer raises RuntimeError.""" + cache = Apriel2Cache(stochastic_config) + + with pytest.raises(RuntimeError, match="needs active_mixer set"): + cache.update(torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16), layer_idx=0) + + def test_accessor_without_active_mixer_raises(self, stochastic_config): + """Accessing key_cache/value_cache without active_mixer raises RuntimeError.""" + cache = Apriel2Cache(stochastic_config) + + with pytest.raises(RuntimeError, match="requires set_active_mixer"): + _ = cache.key_cache[0] + + def test_error_message_lists_available_mixers(self, stochastic_config): + 
"""Error message includes list of available mixers.""" + cache = Apriel2Cache(stochastic_config) + + with pytest.raises(RuntimeError, match="attention.*mamba|mamba.*attention"): + _ = cache.key_cache[0] + + +# ============================================================================= +# BEAM SEARCH OPERATIONS +# ============================================================================= + + +class TestBeamSearchOperations: + """Test batch manipulation for beam search.""" + + def test_batch_repeat_interleave_attention(self, attention_config): + """batch_repeat_interleave expands batch dimension.""" + cache = Apriel2Cache(attention_config) + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0) + + cache.batch_repeat_interleave(3) + + assert cache.layers[0].key.shape[0] == 6 # 2 * 3 + + def test_batch_repeat_interleave_ssm(self, ssm_config): + """batch_repeat_interleave works for SSM caches.""" + cache = Apriel2Cache(ssm_config) + cache.layers[0].conv = torch.randn(2, 64, 4) + + cache.batch_repeat_interleave(3) + + assert cache.layers[0].conv.shape[0] == 6 + + def test_batch_repeat_interleave_kda_tuple(self, ssm_config): + """batch_repeat_interleave handles KDA tuple conv states.""" + cache = Apriel2Cache(ssm_config) + cache.layers[0].conv = (torch.randn(2, 64, 4),) * 3 + + cache.batch_repeat_interleave(3) + + assert cache.layers[0].conv[0].shape[0] == 6 + + def test_reorder_cache_attention(self, attention_config): + """reorder_cache reorders batch dimension.""" + cache = Apriel2Cache(attention_config) + k = torch.arange(4).float().view(4, 1, 1, 1).expand(4, 4, 10, 16) + cache.update(k, k.clone(), layer_idx=0) + + beam_idx = torch.tensor([3, 2, 1, 0]) + cache.reorder_cache(beam_idx) + + # Check reordering + assert cache.layers[0].key[0, 0, 0, 0].item() == 3.0 + assert cache.layers[0].key[3, 0, 0, 0].item() == 0.0 + + def test_batch_select_indices(self, attention_config): + """batch_select_indices selects subset of batch.""" + cache = Apriel2Cache(attention_config) + cache.update(torch.randn(4, 4, 10, 16), torch.randn(4, 4, 10, 16), layer_idx=0) + + indices = torch.tensor([0, 2]) + cache.batch_select_indices(indices) + + assert cache.layers[0].key.shape[0] == 2 + + def test_reorder_cache_ssm_tuple(self, ssm_config): + """reorder_cache handles KDA tuple conv states.""" + cache = Apriel2Cache(ssm_config) + # Create distinguishable tensors for each batch position + conv0 = torch.full((1, 64, 4), 0.0) + conv1 = torch.full((1, 64, 4), 1.0) + conv2 = torch.full((1, 64, 4), 2.0) + cache.layers[0].conv = ( + torch.cat([conv0, conv1, conv2], dim=0), + torch.cat([conv0, conv1, conv2], dim=0), + torch.cat([conv0, conv1, conv2], dim=0), + ) + + beam_idx = torch.tensor([2, 1, 0]) + cache.reorder_cache(beam_idx) + + # Check reordering: batch[0] should now have value 2.0 + assert cache.layers[0].conv[0][0, 0, 0].item() == 2.0 + assert cache.layers[0].conv[0][2, 0, 0].item() == 0.0 + + def test_batch_select_indices_ssm_tuple(self, ssm_config): + """batch_select_indices handles KDA tuple conv states.""" + cache = Apriel2Cache(ssm_config) + cache.layers[0].conv = (torch.randn(4, 64, 4),) * 3 + + indices = torch.tensor([0, 2]) + cache.batch_select_indices(indices) + + assert cache.layers[0].conv[0].shape[0] == 2 + assert cache.layers[0].conv[1].shape[0] == 2 + assert cache.layers[0].conv[2].shape[0] == 2 + + +# ============================================================================= +# CROP OPERATION +# 
============================================================================= + + +class TestCropOperation: + """Test cache truncation.""" + + def test_crop_truncates_attention(self, attention_config): + """crop() truncates attention cache.""" + cache = Apriel2Cache(attention_config) + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0) + + cache.crop(5) + + assert cache.layers[0].key.shape[-2] == 5 + assert cache.get_seq_length(0) == 5 + + def test_crop_affects_all_layers(self, attention_config): + """crop() affects all layers.""" + cache = Apriel2Cache(attention_config) + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0) + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=1) + + cache.crop(3) + + assert cache.layers[0].key.shape[-2] == 3 + assert cache.layers[1].key.shape[-2] == 3 + + def test_crop_ignores_ssm(self, ssm_config): + """crop() doesn't affect SSM caches (they don't have seq dimension).""" + cache = Apriel2Cache(ssm_config) + cache.layers[0].conv = torch.randn(2, 64, 4) + + # Should not raise + cache.crop(5) + + # SSM state unchanged + assert cache.layers[0].conv.shape == (2, 64, 4) + + +# ============================================================================= +# CACHE PROPERTIES +# ============================================================================= + + +class TestCacheProperties: + """Test cache property methods.""" + + def test_is_initialized_attention(self, attention_config): + """is_initialized True after update.""" + cache = Apriel2Cache(attention_config) + assert not cache.is_initialized + + cache.update(torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16), layer_idx=0) + assert cache.is_initialized + + def test_is_initialized_ssm(self, ssm_config): + """is_initialized True after setting conv state.""" + cache = Apriel2Cache(ssm_config) + assert not cache.is_initialized + + cache.layers[0].conv = torch.randn(2, 64, 4) + assert cache.is_initialized + + def test_has_previous_state_ssm_only(self, ssm_config): + """has_previous_state checks SSM conv states.""" + cache = Apriel2Cache(ssm_config) + assert not cache.has_previous_state + + cache.layers[0].conv = torch.randn(2, 64, 4) + assert cache.has_previous_state + + def test_has_previous_state_ignores_attention(self, attention_config): + """has_previous_state ignores attention caches.""" + cache = Apriel2Cache(attention_config) + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0) + + # Attention-only cache returns False for has_previous_state + assert not cache.has_previous_state + + def test_reset_clears_ssm_states(self, ssm_config): + """reset() clears SSM conv and recurrent states.""" + cache = Apriel2Cache(ssm_config) + cache.layers[0].conv = torch.randn(2, 64, 4) + cache.layers[0].recurrent = torch.randn(2, 64, 16) + + cache.reset() + + assert cache.layers[0].conv is None + assert cache.layers[0].recurrent is None + + def test_max_batch_size_from_ssm_tuple(self, ssm_config): + """max_batch_size works with KDA tuple conv states.""" + cache = Apriel2Cache(ssm_config) + cache.layers[0].conv = (torch.randn(3, 64, 4),) * 3 + + assert cache.max_batch_size == 3 + + def test_max_batch_size(self, attention_config): + """max_batch_size returns batch dimension.""" + cache = Apriel2Cache(attention_config) + cache.update(torch.randn(3, 4, 10, 16), torch.randn(3, 4, 10, 16), layer_idx=0) + + assert cache.max_batch_size == 3 + + def test_len_returns_num_layers(self, attention_config): + """__len__ returns 
number of layers.""" + cache = Apriel2Cache(attention_config) + assert len(cache) == 2 + + +# ============================================================================= +# INDEXING +# ============================================================================= + + +class TestCacheIndexing: + """Test __getitem__ for HF compatibility.""" + + def test_getitem_returns_kv_tuple(self, attention_config): + """cache[idx] returns (key, value) tuple.""" + cache = Apriel2Cache(attention_config) + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0) + + k, v = cache[0] + assert k.shape == (2, 4, 10, 16) + assert v.shape == (2, 4, 10, 16) + + def test_getitem_empty_returns_empty_tensors(self, attention_config): + """cache[idx] on empty cache returns empty tensors.""" + cache = Apriel2Cache(attention_config) + + k, v = cache[0] + assert k.numel() == 0 + assert v.numel() == 0 diff --git a/fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py b/fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py new file mode 100644 index 000000000..7c38f75b7 --- /dev/null +++ b/fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py @@ -0,0 +1,592 @@ +"""Contract tests for Apriel2Cache against HuggingFace cache implementations. + +This module tests that Apriel2Cache components behave equivalently to their +HuggingFace counterparts. This ensures compatibility with HF's generation +infrastructure (mask creation, beam search, etc.). + +Mapping: + Apriel2 Component HuggingFace Equivalent + ----------------- ---------------------- + _AttentionCache (no window) -> DynamicLayer + _AttentionCache (window) -> DynamicSlidingWindowLayer + _SSMCache -> MambaCache (different interface, same concept) + +Apriel2-specific features (stochastic routing, multi-mixer layers) are tested +separately in test_cache_apriel2_specific.py since they have no HF equivalent. + +Fixtures used from conftest.py: + - batch_size, num_heads, head_dim: Tensor dimensions + - hf_dynamic_layer: HuggingFace DynamicLayer + - hf_sliding_layer: HuggingFace DynamicSlidingWindowLayer (parameterized by window_size) + - apriel_attention_cache: Apriel2 _AttentionCache (no window) + - apriel_sliding_cache: Apriel2 _AttentionCache (with window, parameterized) + - window_size: Parameterized window sizes [4, 8, 16, 32] + - attention_config, swa_config: Apriel2 configs +""" + +import pytest +import torch + +from fast_llm_external_models.apriel2.cache import _AttentionCache, _SSMCache, Apriel2Cache + + +# ============================================================================= +# SECTION 1: FULL ATTENTION - _AttentionCache vs DynamicLayer +# ============================================================================= + + +class TestFullAttentionContract: + """Test _AttentionCache (no window) matches HuggingFace DynamicLayer. + + DynamicLayer is the standard cache for full causal attention. + We test that our cache produces identical mask parameters. 
+ """ + + # ------------------------------------------------------------------------- + # get_seq_length: Must match exactly for generation to work + # ------------------------------------------------------------------------- + + @pytest.mark.parametrize("seq_len", [1, 5, 10, 50, 100]) + def test_get_seq_length_after_prefill( + self, hf_dynamic_layer, apriel_attention_cache, batch_size, num_heads, head_dim, seq_len + ): + """After prefill, cumulative_length matches HF get_seq_length.""" + key = torch.randn(batch_size, num_heads, seq_len, head_dim) + value = torch.randn(batch_size, num_heads, seq_len, head_dim) + + hf_dynamic_layer.update(key.clone(), value.clone()) + apriel_attention_cache.update(key.clone(), value.clone()) + + assert apriel_attention_cache.cumulative_length == hf_dynamic_layer.get_seq_length() + + @pytest.mark.parametrize("prefill_len", [1, 5, 10]) + @pytest.mark.parametrize("decode_steps", [1, 5, 10, 20]) + def test_get_seq_length_during_decode( + self, hf_dynamic_layer, apriel_attention_cache, batch_size, num_heads, head_dim, prefill_len, decode_steps + ): + """During decode, cumulative_length tracks total tokens seen.""" + # Prefill + key = torch.randn(batch_size, num_heads, prefill_len, head_dim) + value = torch.randn(batch_size, num_heads, prefill_len, head_dim) + hf_dynamic_layer.update(key.clone(), value.clone()) + apriel_attention_cache.update(key.clone(), value.clone()) + + # Decode + for step in range(decode_steps): + key = torch.randn(batch_size, num_heads, 1, head_dim) + value = torch.randn(batch_size, num_heads, 1, head_dim) + hf_dynamic_layer.update(key.clone(), value.clone()) + apriel_attention_cache.update(key.clone(), value.clone()) + + assert apriel_attention_cache.cumulative_length == hf_dynamic_layer.get_seq_length(), ( + f"Mismatch at decode step {step}" + ) + + # ------------------------------------------------------------------------- + # get_mask_sizes: Verify HF behavior for documentation + # ------------------------------------------------------------------------- + + @pytest.mark.parametrize("prefill_len", [1, 5, 10]) + @pytest.mark.parametrize("decode_steps", [0, 1, 5, 10]) + def test_hf_mask_sizes_kv_length( + self, hf_dynamic_layer, apriel_attention_cache, batch_size, num_heads, head_dim, prefill_len, decode_steps + ): + """Document HF's kv_length behavior and verify cumulative_length tracks correctly. + + For full attention, kv_length = cumulative_length + query_length. + This test verifies our cache tracks tokens identically to HF. 
+ """ + # Prefill + key = torch.randn(batch_size, num_heads, prefill_len, head_dim) + value = torch.randn(batch_size, num_heads, prefill_len, head_dim) + hf_dynamic_layer.update(key.clone(), value.clone()) + apriel_attention_cache.update(key.clone(), value.clone()) + + # Decode + for _ in range(decode_steps): + key = torch.randn(batch_size, num_heads, 1, head_dim) + value = torch.randn(batch_size, num_heads, 1, head_dim) + hf_dynamic_layer.update(key.clone(), value.clone()) + apriel_attention_cache.update(key.clone(), value.clone()) + + # Verify cumulative_length matches HF + assert apriel_attention_cache.cumulative_length == hf_dynamic_layer.get_seq_length() + + # Verify HF's kv_length follows the expected formula + cache_position = torch.arange(1) # Single token decode + hf_kv_len, hf_kv_offset = hf_dynamic_layer.get_mask_sizes(cache_position) + expected_kv_len = hf_dynamic_layer.get_seq_length() + cache_position.shape[0] + assert hf_kv_len == expected_kv_len + + def test_hf_kv_offset_always_zero(self, hf_dynamic_layer, batch_size, num_heads, head_dim): + """Document that HF DynamicLayer always returns kv_offset=0. + + For full attention, all cached KV pairs map to absolute positions + starting from 0, so kv_offset is always 0. + """ + # Add many tokens + for _ in range(20): + key = torch.randn(batch_size, num_heads, 5, head_dim) + value = torch.randn(batch_size, num_heads, 5, head_dim) + hf_dynamic_layer.update(key.clone(), value.clone()) + + cache_position = torch.arange(1) + _, hf_kv_offset = hf_dynamic_layer.get_mask_sizes(cache_position) + + assert hf_kv_offset == 0, "DynamicLayer always returns kv_offset=0" + + # ------------------------------------------------------------------------- + # update: Output shape and values must match + # ------------------------------------------------------------------------- + + @pytest.mark.parametrize("seq_len", [1, 5, 10]) + def test_update_returns_same_shape( + self, hf_dynamic_layer, apriel_attention_cache, batch_size, num_heads, head_dim, seq_len + ): + """update() returns tensors with matching shapes.""" + key = torch.randn(batch_size, num_heads, seq_len, head_dim) + value = torch.randn(batch_size, num_heads, seq_len, head_dim) + + hf_k, hf_v = hf_dynamic_layer.update(key.clone(), value.clone()) + apr_k, apr_v = apriel_attention_cache.update(key.clone(), value.clone()) + + assert hf_k.shape == apr_k.shape + assert hf_v.shape == apr_v.shape + + def test_update_concatenates_identically( + self, hf_dynamic_layer, apriel_attention_cache, batch_size, num_heads, head_dim + ): + """Multiple updates produce identical concatenated states.""" + # Use deterministic values for comparison + k1 = torch.arange(10).float().view(1, 1, 10, 1).expand(batch_size, num_heads, 10, head_dim) + v1 = k1.clone() + + hf_dynamic_layer.update(k1.clone(), v1.clone()) + apriel_attention_cache.update(k1.clone(), v1.clone()) + + k2 = torch.arange(10, 15).float().view(1, 1, 5, 1).expand(batch_size, num_heads, 5, head_dim) + v2 = k2.clone() + + hf_k, hf_v = hf_dynamic_layer.update(k2.clone(), v2.clone()) + apr_k, apr_v = apriel_attention_cache.update(k2.clone(), v2.clone()) + + torch.testing.assert_close(hf_k, apr_k) + torch.testing.assert_close(hf_v, apr_v) + + +# ============================================================================= +# SECTION 2: SLIDING WINDOW - _AttentionCache vs DynamicSlidingWindowLayer +# ============================================================================= + + +class TestSlidingWindowContract: + """Test _AttentionCache (with window) 
matches HuggingFace DynamicSlidingWindowLayer. + + DynamicSlidingWindowLayer is used for sliding window attention (e.g., Mistral). + Critical behaviors: + - cumulative_length tracks ALL tokens seen (not just cached) + - kv_offset increases once window is exceeded + - kv_length is capped at window size + + Uses fixtures from conftest.py: + - window_size: parameterized [4, 8, 16, 32] + - hf_sliding_layer: DynamicSlidingWindowLayer + - apriel_sliding_cache: _AttentionCache with window + """ + + # ------------------------------------------------------------------------- + # cumulative_length: Must track total tokens, not cached tokens + # ------------------------------------------------------------------------- + + @pytest.mark.parametrize("prefill_len", [1, 3, 5, 10, 20]) + def test_cumulative_length_matches_after_prefill( + self, hf_sliding_layer, apriel_sliding_cache, batch_size, num_heads, head_dim, prefill_len + ): + """cumulative_length matches HF get_seq_length after prefill.""" + key = torch.randn(batch_size, num_heads, prefill_len, head_dim) + value = torch.randn(batch_size, num_heads, prefill_len, head_dim) + + hf_sliding_layer.update(key.clone(), value.clone()) + apriel_sliding_cache.update(key.clone(), value.clone()) + + assert apriel_sliding_cache.cumulative_length == hf_sliding_layer.get_seq_length() + + def test_cumulative_length_continues_past_window( + self, hf_sliding_layer, apriel_sliding_cache, window_size, batch_size, num_heads, head_dim + ): + """cumulative_length keeps growing even after window is full.""" + total_tokens = window_size * 3 # Way past window + + for i in range(total_tokens): + key = torch.randn(batch_size, num_heads, 1, head_dim) + value = torch.randn(batch_size, num_heads, 1, head_dim) + + hf_sliding_layer.update(key.clone(), value.clone()) + apriel_sliding_cache.update(key.clone(), value.clone()) + + expected = i + 1 + assert apriel_sliding_cache.cumulative_length == expected + assert hf_sliding_layer.get_seq_length() == expected + + # ------------------------------------------------------------------------- + # get_mask_sizes: kv_offset must increase once window is exceeded + # ------------------------------------------------------------------------- + + def test_kv_offset_zero_before_window_full( + self, hf_sliding_layer, apriel_sliding_cache, window_size, batch_size, num_heads, head_dim + ): + """kv_offset is 0 while cumulative < window. + + Before the window is full, kv_offset should be 0 because all cached tokens + correspond to absolute positions starting from 0. + """ + # Add tokens up to window-1 + for i in range(window_size - 1): + key = torch.randn(batch_size, num_heads, 1, head_dim) + value = torch.randn(batch_size, num_heads, 1, head_dim) + + hf_sliding_layer.update(key.clone(), value.clone()) + apriel_sliding_cache.update(key.clone(), value.clone()) + + cache_position = torch.arange(1) + hf_kv_len, hf_kv_offset = hf_sliding_layer.get_mask_sizes(cache_position) + + # Verify HF returns 0 offset before window full + assert hf_kv_offset == 0, f"HF offset should be 0 at step {i}" + # Verify Apriel cache tracks cumulative correctly + assert apriel_sliding_cache.cumulative_length == i + 1 + + def test_kv_offset_increases_after_window_full( + self, hf_sliding_layer, apriel_sliding_cache, window_size, batch_size, num_heads, head_dim + ): + """kv_offset increases once cumulative >= window. + + Once the window is full, the cache discards oldest tokens. kv_offset tracks + which absolute position KV[0] corresponds to. 
+ """ + # Fill to exactly window + for _ in range(window_size): + key = torch.randn(batch_size, num_heads, 1, head_dim) + value = torch.randn(batch_size, num_heads, 1, head_dim) + hf_sliding_layer.update(key.clone(), value.clone()) + apriel_sliding_cache.update(key.clone(), value.clone()) + + cache_position = torch.arange(1) + hf_kv_len, hf_kv_offset = hf_sliding_layer.get_mask_sizes(cache_position) + + # At window boundary, offset should be 1 + assert hf_kv_offset == 1, "HF offset should be 1 at window boundary" + assert apriel_sliding_cache.cumulative_length == window_size + + # Add more tokens and verify offset keeps increasing with HF + for i in range(5): + key = torch.randn(batch_size, num_heads, 1, head_dim) + value = torch.randn(batch_size, num_heads, 1, head_dim) + hf_sliding_layer.update(key.clone(), value.clone()) + apriel_sliding_cache.update(key.clone(), value.clone()) + + hf_kv_len, hf_kv_offset = hf_sliding_layer.get_mask_sizes(cache_position) + + expected_offset = i + 2 + assert hf_kv_offset == expected_offset + assert apriel_sliding_cache.cumulative_length == window_size + i + 1 + + def test_kv_length_capped_at_window( + self, hf_sliding_layer, apriel_sliding_cache, window_size, batch_size, num_heads, head_dim + ): + """kv_length is capped at window size once exceeded. + + For a query of length 1 after the window is full, kv_length = window + (window-1 cached tokens + 1 query token). + """ + # Way past window + for _ in range(window_size * 2): + key = torch.randn(batch_size, num_heads, 1, head_dim) + value = torch.randn(batch_size, num_heads, 1, head_dim) + hf_sliding_layer.update(key.clone(), value.clone()) + apriel_sliding_cache.update(key.clone(), value.clone()) + + cache_position = torch.arange(1) + hf_kv_len, _ = hf_sliding_layer.get_mask_sizes(cache_position) + + # HF returns window (window-1 cached + 1 query) + assert hf_kv_len == window_size + # Verify our cache tracked cumulative correctly + assert apriel_sliding_cache.cumulative_length == window_size * 2 + + # ------------------------------------------------------------------------- + # Full sequence length tracking through generation + # ------------------------------------------------------------------------- + + @pytest.mark.parametrize("prefill_len", [1, 3, 5, 10, 20]) + def test_cumulative_length_tracks_all_tokens( + self, hf_sliding_layer, apriel_sliding_cache, window_size, batch_size, num_heads, head_dim, prefill_len + ): + """cumulative_length tracks total tokens seen through prefill + decode. + + This is the foundation for correct mask size computation. We verify that + our _AttentionCache tracks tokens identically to HuggingFace's DynamicSlidingWindowLayer. + The actual get_mask_sizes computation is tested in TestApriel2CacheIntegration. 
+ """ + # Prefill + key = torch.randn(batch_size, num_heads, prefill_len, head_dim) + value = torch.randn(batch_size, num_heads, prefill_len, head_dim) + hf_sliding_layer.update(key.clone(), value.clone()) + apriel_sliding_cache.update(key.clone(), value.clone()) + + assert apriel_sliding_cache.cumulative_length == hf_sliding_layer.get_seq_length() + + # Decode past window + for i in range(window_size + 10): + key = torch.randn(batch_size, num_heads, 1, head_dim) + value = torch.randn(batch_size, num_heads, 1, head_dim) + hf_sliding_layer.update(key.clone(), value.clone()) + apriel_sliding_cache.update(key.clone(), value.clone()) + + assert apriel_sliding_cache.cumulative_length == hf_sliding_layer.get_seq_length(), ( + f"cumulative_length mismatch at step {i}" + ) + + +# ============================================================================= +# SECTION 3: SSM CACHE - _SSMCache vs MambaCache concept +# ============================================================================= + + +class TestSSMCacheContract: + """Document _SSMCache interface and verify basic contract. + + Unlike attention caches which have HF equivalents (DynamicLayer, DynamicSlidingWindowLayer), + SSM caches have no direct HF counterpart with matching interface. HF's MambaCache uses + different methods (update_conv_state, update_ssm_state), so we can't do direct comparison. + + These tests document the interface contract: + 1. `conv` and `recurrent` attributes for storing states + 2. Both support None (lazy initialization) + 3. `conv` can be tuple (for KDA which has separate q/k/v conv states) + + Higher-level operations (reorder, batch_repeat, reset) are tested in + TestBeamSearchOperations in test_cache_apriel2_specific.py. + """ + + def test_conv_state_storage(self, ssm_cache): + """conv attribute stores conv states (batch, intermediate, kernel_size).""" + conv = torch.randn(2, 64, 4) + ssm_cache.conv = conv + torch.testing.assert_close(ssm_cache.conv, conv) + + def test_recurrent_state_storage(self, ssm_cache): + """recurrent attribute stores SSM states (batch, intermediate, state_size).""" + recurrent = torch.randn(2, 64, 16) + ssm_cache.recurrent = recurrent + torch.testing.assert_close(ssm_cache.recurrent, recurrent) + + def test_conv_state_tuple_for_kda(self, ssm_cache): + """conv can be tuple for KDA's separate q/k/v convolutions.""" + conv_tuple = (torch.randn(2, 64, 4), torch.randn(2, 64, 4), torch.randn(2, 64, 4)) + ssm_cache.conv = conv_tuple + assert isinstance(ssm_cache.conv, tuple) + assert len(ssm_cache.conv) == 3 + + def test_initial_states_none(self, ssm_cache): + """States are None initially (lazy initialization pattern).""" + assert ssm_cache.conv is None + assert ssm_cache.recurrent is None + + def test_states_independent(self, ssm_cache): + """conv and recurrent states are independent.""" + ssm_cache.conv = torch.randn(2, 64, 4) + assert ssm_cache.recurrent is None # recurrent unchanged + + ssm_cache.recurrent = torch.randn(2, 64, 16) + assert ssm_cache.conv is not None # conv unchanged + + +# ============================================================================= +# SECTION 4: APRIEL2CACHE INTEGRATION +# ============================================================================= + + +class TestApriel2CacheIntegration: + """Test Apriel2Cache correctly delegates to underlying caches. 
+ + Uses fixtures from conftest.py: + - attention_config: Pure attention config + - swa_config: Sliding window attention config (window=8) + """ + + def test_get_seq_length_matches_dynamic_layer(self, attention_config): + """Apriel2Cache.get_seq_length matches DynamicLayer for full attention.""" + from transformers.cache_utils import DynamicLayer + + cache = Apriel2Cache(attention_config) + hf_layer = DynamicLayer() + + key = torch.randn(2, 4, 10, 16) + value = torch.randn(2, 4, 10, 16) + + cache.update(key.clone(), value.clone(), layer_idx=0) + hf_layer.update(key.clone(), value.clone()) + + assert cache.get_seq_length(0) == hf_layer.get_seq_length() + + def test_get_mask_sizes_matches_dynamic_layer(self, attention_config): + """Apriel2Cache.get_mask_sizes matches DynamicLayer.""" + from transformers.cache_utils import DynamicLayer + + cache = Apriel2Cache(attention_config) + hf_layer = DynamicLayer() + + key = torch.randn(2, 4, 10, 16) + value = torch.randn(2, 4, 10, 16) + + cache.update(key.clone(), value.clone(), layer_idx=0) + hf_layer.update(key.clone(), value.clone()) + + cache_position = torch.arange(1) + hf_kv_len, hf_kv_offset = hf_layer.get_mask_sizes(cache_position) + apr_kv_len, apr_kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0) + + assert apr_kv_len == hf_kv_len + assert apr_kv_offset == hf_kv_offset + + def test_get_mask_sizes_matches_sliding_layer(self, swa_config): + """Apriel2Cache.get_mask_sizes matches DynamicSlidingWindowLayer.""" + from transformers.cache_utils import DynamicSlidingWindowLayer + + cache = Apriel2Cache(swa_config) + hf_layer = DynamicSlidingWindowLayer(sliding_window=8) + + # Fill past window + for _ in range(15): + key = torch.randn(2, 4, 1, 16) + value = torch.randn(2, 4, 1, 16) + cache.update(key.clone(), value.clone(), layer_idx=0) + hf_layer.update(key.clone(), value.clone()) + + cache_position = torch.arange(1) + hf_kv_len, hf_kv_offset = hf_layer.get_mask_sizes(cache_position) + apr_kv_len, apr_kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0) + + assert apr_kv_len == hf_kv_len + assert apr_kv_offset == hf_kv_offset + + def test_reset_clears_cumulative_length(self, attention_config): + """reset() clears cumulative_length (matches DynamicLayer.reset).""" + cache = Apriel2Cache(attention_config) + + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0) + assert cache.get_seq_length(0) == 10 + + cache.reset() + assert cache.get_seq_length(0) == 0 + + +# ============================================================================= +# SECTION 5: MASK CORRECTNESS (SEMANTIC TESTS) +# ============================================================================= + + +class TestMaskCorrectness: + """Test that mask parameters produce semantically correct masks. + + These tests verify the END RESULT: masks created with our parameters + allow the correct attention patterns. 
+ """ + + def test_full_attention_decode_can_attend_to_all(self): + """During decode, query can attend to all cached positions.""" + from transformers.masking_utils import sdpa_mask, causal_mask_function + + cache = _AttentionCache(window=None) + + # Prefill + decode + for _ in range(10): + cache.update(torch.randn(1, 1, 1, 16), torch.randn(1, 1, 1, 16)) + + # Mask for decode step + cache_position = torch.tensor([10]) # Position of new token + kv_length = cache.cumulative_length + 1 + kv_offset = 0 + + mask = sdpa_mask( + batch_size=1, + cache_position=cache_position, + kv_length=kv_length, + kv_offset=kv_offset, + mask_function=causal_mask_function, + ) + + if mask is not None: + # Query at position 10 should attend to positions 0-10 + query_mask = mask[0, 0, 0, :] + for kv_idx in range(kv_length): + assert query_mask[kv_idx].item() == True, f"Should attend to position {kv_idx}" + + @pytest.mark.parametrize("window_size", [4, 8, 16]) + def test_sliding_window_decode_respects_window(self, window_size): + """During decode, query only attends within sliding window.""" + from transformers.masking_utils import sdpa_mask, sliding_window_causal_mask_function + + cache = _AttentionCache(window=window_size) + + # Fill way past window + total_tokens = window_size * 2 + for _ in range(total_tokens): + cache.update(torch.randn(1, 1, 1, 16), torch.randn(1, 1, 1, 16)) + + # Mask for decode step + cache_position = torch.tensor([total_tokens]) + cumulative = cache.cumulative_length + kv_offset = max(cumulative - window_size + 1, 0) + kv_length = window_size - 1 + 1 # cached + query + + mask = sdpa_mask( + batch_size=1, + cache_position=cache_position, + kv_length=kv_length, + kv_offset=kv_offset, + mask_function=sliding_window_causal_mask_function(window_size), + ) + + if mask is not None: + query_mask = mask[0, 0, 0, :] + query_pos = cache_position[0].item() + + for kv_idx in range(kv_length): + abs_pos = kv_offset + kv_idx + in_window = abs_pos > query_pos - window_size + causal = abs_pos <= query_pos + expected = in_window and causal + + assert query_mask[kv_idx].item() == expected, ( + f"Position {abs_pos}: expected {expected}, got {query_mask[kv_idx].item()}" + ) + + def test_prefill_has_causal_pattern(self): + """During prefill, mask has proper causal (lower triangular) pattern.""" + from transformers.masking_utils import sdpa_mask, causal_mask_function + + cache = _AttentionCache(window=None) + cache.update(torch.randn(1, 1, 5, 16), torch.randn(1, 1, 5, 16)) + + cache_position = torch.arange(5) + kv_length = cache.cumulative_length + kv_offset = 0 + + mask = sdpa_mask( + batch_size=1, + cache_position=cache_position, + kv_length=kv_length, + kv_offset=kv_offset, + mask_function=causal_mask_function, + allow_is_causal_skip=False, # Force mask creation + ) + + if mask is not None: + # Check causal pattern + for q_idx in range(5): + for kv_idx in range(5): + expected = kv_idx <= q_idx + actual = mask[0, 0, q_idx, kv_idx].item() + assert actual == expected, f"q={q_idx}, kv={kv_idx}: expected {expected}" From 843a355a6c37ea8f74783fef987b658b8549af51 Mon Sep 17 00:00:00 2001 From: bigximik Date: Fri, 28 Nov 2025 14:09:39 +0000 Subject: [PATCH 02/25] fix qwen converted to correctly load qkv biases --- fast_llm/models/gpt/conversion/qwen2.py | 32 +++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/fast_llm/models/gpt/conversion/qwen2.py b/fast_llm/models/gpt/conversion/qwen2.py index a8bc33454..57c9614bd 100644 --- a/fast_llm/models/gpt/conversion/qwen2.py +++ 
b/fast_llm/models/gpt/conversion/qwen2.py @@ -2,6 +2,7 @@ from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.layers.attention.config import AttentionConfig +from fast_llm.layers.decoder.mlp.config import MLPConfig from fast_llm.models.gpt.conversion.config import Qwen2CheckpointFormat from fast_llm.models.gpt.conversion.llama import ( LlamaAttentionConverter, @@ -10,6 +11,7 @@ LlamaDecoderConverter, LlamaHeadConverter, LlamaHuggingfaceCheckpointHandler, + LlamaMLPConverter, ) from fast_llm.utils import Assert @@ -17,6 +19,22 @@ class Qwen2AttentionConverter(LlamaAttentionConverter): # TODO: Support sliding window with max_window_layers (need 2 kinds of block?) + @classmethod + def import_config(cls, config: dict) -> dict: + config["attention_bias"] = True + out = super().import_config(config) + out["query_layer"] = {"bias": {"enabled": True}} + out["key_layer"] = {"bias": {"enabled": True}} + out["value_layer"] = {"bias": {"enabled": True}} + out["dense_layer"] = {"bias": {"enabled": False}} + return out + + @classmethod + def export_config(cls, config: AttentionConfig) -> dict: + out = super().export_config(config) + del out["attention_bias"] + return out + @classmethod def _check_config(cls, config: AttentionConfig) -> None: Assert.is_(type(config), AttentionConfig) @@ -33,8 +51,22 @@ def _check_config(cls, config: AttentionConfig) -> None: Assert.incl(config.dense_layer.bias.enabled, (None, False)) +class Qwen2MLPConverter(LlamaMLPConverter): + @classmethod + def import_config(cls, config: dict) -> dict: + config["mlp_bias"] = False + return super().import_config(config) + + @classmethod + def export_config(cls, config: MLPConfig) -> dict: + out = super().export_config(config) + del out["mlp_bias"] + return out + + class Qwen2BlockConverter(LlamaBlockConverter): mixer_converter_class: typing.ClassVar[type[Qwen2AttentionConverter]] = Qwen2AttentionConverter + mlp_converter_class: typing.ClassVar[type[Qwen2MLPConverter]] = Qwen2MLPConverter class Qwen2DecoderConverter(LlamaDecoderConverter): From 33b6d31dd842022812814655ba3ef2a6558ad010 Mon Sep 17 00:00:00 2001 From: bigximik Date: Tue, 2 Dec 2025 12:09:03 +0000 Subject: [PATCH 03/25] fix converters --- fast_llm/models/gpt/conversion/qwen2.py | 37 +++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/fast_llm/models/gpt/conversion/qwen2.py b/fast_llm/models/gpt/conversion/qwen2.py index 57c9614bd..4ebf18c3a 100644 --- a/fast_llm/models/gpt/conversion/qwen2.py +++ b/fast_llm/models/gpt/conversion/qwen2.py @@ -1,10 +1,12 @@ import typing from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm.engine.checkpoint.external import WeightConverter from fast_llm.layers.attention.config import AttentionConfig from fast_llm.layers.decoder.mlp.config import MLPConfig from fast_llm.models.gpt.conversion.config import Qwen2CheckpointFormat from fast_llm.models.gpt.conversion.llama import ( + KeyValueWeightConverter, LlamaAttentionConverter, LlamaBaseModelConverter, LlamaBlockConverter, @@ -12,6 +14,8 @@ LlamaHeadConverter, LlamaHuggingfaceCheckpointHandler, LlamaMLPConverter, + QueryWeightConverter, + get_weight_and_bias_converters, ) from fast_llm.utils import Assert @@ -50,6 +54,39 @@ def _check_config(cls, config: AttentionConfig) -> None: Assert.is_(config.value_layer.bias.enabled, True) Assert.incl(config.dense_layer.bias.enabled, (None, False)) + @classmethod + def get_converters( + cls, + config: AttentionConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = 
False, + ) -> list[WeightConverter]: + return [ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.query", + f"{hf_prefix}.q_proj", + True, + QueryWeightConverter, + config, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.key_value", + (f"{hf_prefix}.k_proj", f"{hf_prefix}.v_proj"), + True, + KeyValueWeightConverter, + config, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.dense", + f"{hf_prefix}.o_proj", + False, + drop_on_export=drop_on_export, + ), + ] + class Qwen2MLPConverter(LlamaMLPConverter): @classmethod From 78229757422ccd58256cdd5afe2067884d3a80f5 Mon Sep 17 00:00:00 2001 From: Torsten Scholak Date: Sun, 14 Dec 2025 20:51:29 +0000 Subject: [PATCH 04/25] Add per-layer bias support, surgery improvements, and integration tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit adds comprehensive support for per-layer bias configurations in Apriel2 conversions and improves the surgery/conversion infrastructure. Key changes: **Per-layer bias configuration:** - Support weight-specific bias settings (query_layer.bias.enabled, etc.) - Bias inheritance for stochastic mixer submixers - Proper handling of Qwen-style bias pattern (QKV bias, no O bias) **Surgery and conversion improvements:** - Document monoidal structure in compose_configs and plan_surgery - Fix non-gated MLP handling (gate_proj only when gated=True) - Fix vision_encoder=None handling in converters - Change to relative imports in apriel2 modules for portability **Test infrastructure:** - Add requires_fastllm decorator for Fast-LLM dependent tests - Fix autouse fixture scoping (module-scoped for proper ordering) - Add comprehensive integration tests with parameterized inputs - Test all conversion stages: Qwen2 -> Apriel2 -> Supernet -> Roundtrip - Parameterized test inputs for batch size, padding, and generation length **Integration test structure:** - TestConfigPreservation: Verify config correctness at each stage - TestNumericalEquivalence: Verify logits and generation match - 24 tests covering 3 stages × 3 input variations × 2 checks 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- fast_llm/models/gpt/conversion/apriel2.py | 125 +++++-- .../apriel2/conversion/__init__.py | 2 +- .../apriel2/conversion/config.py | 142 +++++--- .../apriel2/conversion/converters.py | 195 ++++++++-- .../apriel2/conversion/qwen2/config.py | 14 +- .../apriel2/conversion/qwen2/plan.py | 20 +- .../apriel2/modeling_apriel2.py | 93 ++++- .../tests/test_apriel2/conftest.py | 152 +++++++- .../test_apriel2/test_compose_configs.py | 157 ++++++++ .../tests/test_apriel2/test_expr_plan.py | 202 +++++++++++ .../tests/test_apriel2/test_integration.py | 335 ++++++++++++++++++ .../tests/test_apriel2/test_modeling.py | 3 +- .../test_plan_composition_torture.py | 148 ++++++++ 13 files changed, 1452 insertions(+), 136 deletions(-) create mode 100644 fast_llm_external_models/tests/test_apriel2/test_integration.py diff --git a/fast_llm/models/gpt/conversion/apriel2.py b/fast_llm/models/gpt/conversion/apriel2.py index 7682196c8..eb5641aea 100644 --- a/fast_llm/models/gpt/conversion/apriel2.py +++ b/fast_llm/models/gpt/conversion/apriel2.py @@ -39,8 +39,20 @@ def import_config(cls, config: dict) -> dict: "head_groups": config["head_groups"], "head_size": config["head_size"], "rotary": rotary, - "add_linear_biases": config["add_linear_biases"], } + # Per-layer bias 
configuration mirroring Fast-LLM structure + # If per-layer configs exist, use them; otherwise fall back to add_linear_biases + if "query_layer" in config: + result["query_layer"] = config["query_layer"] + if "key_layer" in config: + result["key_layer"] = config["key_layer"] + if "value_layer" in config: + result["value_layer"] = config["value_layer"] + if "dense_layer" in config: + result["dense_layer"] = config["dense_layer"] + # add_linear_biases serves as default for layers without explicit config + if "add_linear_biases" in config: + result["add_linear_biases"] = config["add_linear_biases"] if "window_size" in config: result["window_size"] = config["window_size"] return result @@ -58,18 +70,37 @@ def export_config(cls, config: AttentionConfig) -> dict: else: raise NotImplementedError(f"Unsupported rotary type: {type(config.rotary).__name__}") - return { + result = { "type": "attention", "heads": config.heads, "head_groups": config.head_groups, "head_size": config.head_size, - "add_linear_biases": config.add_linear_biases, "rotary": { "type": rotary_type, "theta": config.rotary.theta, }, "window_size": config.window_size, } + # Export per-layer bias configuration + # Only include if explicitly set (not None) + if config.query_layer.bias.enabled is not None: + result["query_layer"] = {"bias": {"enabled": config.query_layer.bias.enabled}} + if config.key_layer.bias.enabled is not None: + result["key_layer"] = {"bias": {"enabled": config.key_layer.bias.enabled}} + if config.value_layer.bias.enabled is not None: + result["value_layer"] = {"bias": {"enabled": config.value_layer.bias.enabled}} + if config.dense_layer.bias.enabled is not None: + result["dense_layer"] = {"bias": {"enabled": config.dense_layer.bias.enabled}} + # add_linear_biases as fallback default + result["add_linear_biases"] = config.add_linear_biases + return result + + @classmethod + def _get_effective_bias(cls, layer_config, default: bool) -> bool: + """Get effective bias setting: use layer-specific if set, else default.""" + if layer_config.bias.enabled is not None: + return layer_config.bias.enabled + return default @classmethod def get_converters( @@ -79,11 +110,20 @@ def get_converters( hf_prefix: str, drop_on_export: bool = False, ) -> list[WeightConverter]: + # Determine effective bias for each projection + q_bias = cls._get_effective_bias(config.query_layer, config.add_linear_biases) + k_bias = cls._get_effective_bias(config.key_layer, config.add_linear_biases) + v_bias = cls._get_effective_bias(config.value_layer, config.add_linear_biases) + o_bias = cls._get_effective_bias(config.dense_layer, config.add_linear_biases) + # For key_value, both k and v must have same bias setting + # (they're combined in Fast-LLM's key_value layer) + kv_bias = k_bias and v_bias + return [ *get_weight_and_bias_converters( f"{fast_llm_prefix}.query", f"{hf_prefix}.q_proj", - config.add_linear_biases, + q_bias, QueryWeightConverter, config, drop_on_export=drop_on_export, @@ -91,7 +131,7 @@ def get_converters( *get_weight_and_bias_converters( f"{fast_llm_prefix}.key_value", (f"{hf_prefix}.k_proj", f"{hf_prefix}.v_proj"), - config.add_linear_biases, + kv_bias, KeyValueWeightConverter, config, drop_on_export=drop_on_export, @@ -99,7 +139,7 @@ def get_converters( *get_weight_and_bias_converters( f"{fast_llm_prefix}.dense", f"{hf_prefix}.o_proj", - config.add_linear_biases, + o_bias, drop_on_export=drop_on_export, ), ] @@ -524,6 +564,12 @@ def import_config(cls, config: dict, block_config: dict) -> dict: "gated": mlp_config["gated"], 
"add_linear_biases": mlp_config["add_linear_biases"], } + # Import per-layer MLP bias settings (layer_1, layer_2) + for layer_name in ("layer_1", "layer_2"): + if layer_name in mlp_config: + layer_cfg = mlp_config[layer_name] + if "bias" in layer_cfg: + mlp[layer_name] = {"bias": layer_cfg["bias"]} normalization = block_config["normalization"] @@ -578,6 +624,11 @@ def export_config(cls, config: DecoderBlockConfig) -> dict: "gated": config.mlp.gated, "add_linear_biases": config.mlp.add_linear_biases, } + # Export per-layer MLP bias settings (layer_1, layer_2) + if config.mlp.layer_1.bias.enabled is not None: + mlp["layer_1"] = {"bias": {"enabled": config.mlp.layer_1.bias.enabled}} + if config.mlp.layer_2.bias.enabled is not None: + mlp["layer_2"] = {"bias": {"enabled": config.mlp.layer_2.bias.enabled}} normalization = {"type": norm_type_str, "epsilon": config.normalization.epsilon} @@ -624,22 +675,52 @@ def get_converters( ) ) - converters.extend([ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_1", - (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), - config.mlp.add_linear_biases, - SplitWeightConverter, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_2", - f"{hf_prefix}.mlp.down_proj", - config.mlp.add_linear_biases, - MLPLayer2Converter, - drop_on_export=drop_on_export, - ), - ]) + # Per-layer MLP bias: use layer-specific setting if set, else default + def get_mlp_layer_bias(layer_config, default: bool) -> bool: + if layer_config.bias.enabled is not None: + return layer_config.bias.enabled + return default + + layer_1_bias = get_mlp_layer_bias(config.mlp.layer_1, config.mlp.add_linear_biases) + layer_2_bias = get_mlp_layer_bias(config.mlp.layer_2, config.mlp.add_linear_biases) + + if config.mlp.gated: + # Gated MLP: gate_proj + up_proj -> layer_1 (split), down_proj -> layer_2 + converters.extend([ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_1", + (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), + layer_1_bias, + SplitWeightConverter, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_2", + f"{hf_prefix}.mlp.down_proj", + layer_2_bias, + MLPLayer2Converter, + drop_on_export=drop_on_export, + ), + ]) + else: + # Non-gated MLP: up_proj -> layer_1, down_proj -> layer_2 + # Note: layer_2 still needs MLPLayer2Converter for the transpose + converters.extend([ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_1", + f"{hf_prefix}.mlp.up_proj", + layer_1_bias, + WeightConverter, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_2", + f"{hf_prefix}.mlp.down_proj", + layer_2_bias, + MLPLayer2Converter, + drop_on_export=drop_on_export, + ), + ]) converters.extend([ *LlamaNormalizationConverter.get_converters( diff --git a/fast_llm_external_models/apriel2/conversion/__init__.py b/fast_llm_external_models/apriel2/conversion/__init__.py index 983a632e0..60fc0ef0a 100644 --- a/fast_llm_external_models/apriel2/conversion/__init__.py +++ b/fast_llm_external_models/apriel2/conversion/__init__.py @@ -63,7 +63,7 @@ target_config = compose_configs(source_config, surgery_spec) # 2. Build plan for weight transformation - plan = plan_surgery(source_config, surgery_spec) + plan = plan_surgery(source_config, target_config) # 3. 
Execute plan to transform weights target_weights = execute(plan, source_weights, seed=42) diff --git a/fast_llm_external_models/apriel2/conversion/config.py b/fast_llm_external_models/apriel2/conversion/config.py index 48f8ff44b..f5b19e208 100644 --- a/fast_llm_external_models/apriel2/conversion/config.py +++ b/fast_llm_external_models/apriel2/conversion/config.py @@ -1,56 +1,59 @@ """Config composition for Apriel2 architecture transformations. This module handles STRUCTURAL composition of configs, independent of weight handling. -The `init` field in surgery specs is preserved as metadata for the plan builder but -does not affect how configs are composed. +The `init` field in surgery specs is metadata for plan_surgery(), not for composition. -Composition Cases -================= +Algebraic Structure +=================== + +The system has a precise algebraic structure with two interacting components: -compose_configs(base, overlay) handles four cases based on completeness: +**Surgery Specs (Monoid)** + Partial config dicts form a monoid under deep merge: + - Identity: {} (empty dict) + - Operation: compose_configs(partial1, partial2) = deep_merge(partial1, partial2) + - Associativity: (a ∘ b) ∘ c = a ∘ (b ∘ c) -1. **Complete + Partial** → Apply surgery semantics (inheritance, cross-type derivation) -2. **Partial + Partial** → Deep merge (monoid operation on surgery specs) -3. **Partial + Complete** → Overlay wins (complete config replaces partial) -4. **Complete + Complete** → Deep merge, then strip `init` fields +**Complete Configs (Monoid Action)** + Surgery specs ACT on complete configs: + - Action: compose_configs(complete, partial) → complete + - For additive surgeries: (s · t₁) · t₂ = s · (t₁ ∘ t₂) + - For replacement surgeries: action law intentionally fails (last-write-wins) -A config is "complete" if it has `hidden_size` and `decoder` (i.e., it's a full model -config, not a surgery spec). +This separation is fundamental: surgery specs compose declaratively (what fields to +merge), while the action on configs interprets those fields with inheritance semantics. -Surgery Semantics +Composition Cases ================= -When applying a surgery spec to a complete config: +compose_configs(base, overlay) dispatches based on completeness: + +1. **Complete + Partial** → Monoid action (inheritance, cross-type derivation) +2. **Partial + Partial** → Monoid operation (deep merge) +3. **Partial + Complete** → Overlay wins (complete replaces partial) +4. **Complete + Complete** → Deep merge, strip `init` fields -**Inheritance** - Unspecified parameters inherit from the source config. New blocks inherit - from the "default" block (first block in pattern, or the single fixed block). +A config is "complete" if it has `hidden_size` and `decoder`. 
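As a quick illustration of the four completeness-based dispatch cases documented above, here is a minimal sketch (not part of the patch) that assumes the module's own helpers (is_complete, apply_surgery, _strip_keys, copy) plus a hypothetical deep_merge helper; the real composition and inheritance logic is more involved.

# Illustrative sketch only: the completeness-based dispatch described above.
# Assumes is_complete/apply_surgery/_strip_keys from this module and a
# hypothetical deep_merge(base, overlay) helper.
def compose_configs_sketch(base: dict, overlay: dict | None) -> dict:
    if not overlay:
        return copy.deepcopy(base)            # identity element {}
    if is_complete(base) and not is_complete(overlay):
        return apply_surgery(base, overlay)   # case 1: monoid action, strips `init`
    if not is_complete(base) and not is_complete(overlay):
        return deep_merge(base, overlay)      # case 2: monoid operation, keeps `init`
    if not is_complete(base):
        return copy.deepcopy(overlay)         # case 3: complete overlay wins
    merged = deep_merge(base, overlay)        # case 4: complete + complete
    _strip_keys(merged, {"init"})
    return merged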
-**Cross-Type Derivation** - When changing mixer types, geometric parameters are derived where possible: - - attention → sliding_window: preserve heads, head_groups, head_size - - attention → gdn: heads → value_heads, head_groups → key_heads - - attention → mamba: derive d_inner, d_xb, dt_rank from hidden_size - - attention → kda: preserve heads, head_size → head_dim +Inheritance Semantics +===================== -**Stochastic Mixer Composition** - Two semantics based on whether surgery declares `type: stochastic`: - - Replacement: surgery declares type → only surgery's sub-mixers included - - Additive: surgery omits type → source sub-mixers preserved, surgery adds/modifies +When the monoid action applies a surgery to a complete config: - This distinction means the monoid action law holds for additive surgeries but - intentionally fails for replacement surgeries (they have "last-write-wins" semantics). +- Unspecified fields inherit from source +- New blocks inherit from the "default" block +- Cross-type derivation maps geometry (attention.heads → gdn.value_heads, etc.) +- Stochastic mixers: additive (no type decl) preserves source, replacement replaces The `init` Field ================ -The `init` field is metadata for the plan builder, NOT for config composition: -- `init: transfer` → plan builder creates weight transfer mappings -- `init: random` → plan builder creates random initialization +The `init` field is metadata for plan_surgery(), NOT for config composition: +- `init: transfer` → plan uses weight transfer/conversion +- `init: random` → plan uses random initialization -After surgery is applied to produce a complete config, ALL `init` fields are stripped. -This ensures configs are purely structural and plan creation is Markovian (depends only -on current config + surgery, not on history). +After composition produces a complete config, ALL `init` fields are stripped. +This ensures configs are purely structural and plan creation is Markovian. """ from __future__ import annotations @@ -65,14 +68,49 @@ def is_complete(config: dict) -> bool: def compose_configs(base: dict, overlay: dict | None) -> dict: - """Compose two configs. + """Compose two configs using monoid or monoid action semantics. + + This function implements two algebraic operations depending on argument types: + + 1. **Monoid Action** (complete + partial): Apply surgery to a complete config. + Unspecified fields inherit from base; `init` fields are stripped from result. + + 2. **Monoid Operation** (partial + partial): Merge two surgery specs. + Deep merge with overlay winning on conflicts; `init` fields preserved. Args: - base: Base config (complete or partial surgery spec). - overlay: Overlay config (complete or partial surgery spec). + base: Base config (complete) or surgery spec (partial). + overlay: Surgery spec to apply (partial) or config to merge. Returns: - Composed config. + - If base is complete: Complete config with surgery applied, `init` stripped. + - If both partial: Merged surgery spec with `init` preserved. + + Algebraic Properties: + Surgery specs form a monoid: (a ∘ b) ∘ c = a ∘ (b ∘ c), identity = {} + + For additive surgeries, the action law holds: + compose(compose(s, t1), t2) == compose(s, compose(t1, t2)) + + For replacement surgeries (declaring type:), action law intentionally fails. 
+ + Example: + # Apply surgery to complete config (monoid action) + source = {"hidden_size": 256, "decoder": {...}} # complete + surgery = {"decoder": {"block": {"mixer": {"type": "mamba"}}}} # partial + + target = compose_configs(source, surgery) + # target is complete with inherited fields, init stripped + + # Merge two surgery specs (monoid operation) + s1 = {"decoder": {"block": {"mixer": {"mixers": {"a": {...}}}}}} + s2 = {"decoder": {"block": {"mixer": {"mixers": {"b": {...}}}}}} + + merged = compose_configs(s1, s2) + # merged has both mixers a and b, init preserved + + # Use composed config with plan_surgery + plan = plan_surgery(source, target) """ if not overlay: return copy.deepcopy(base) @@ -134,20 +172,24 @@ def _strip_keys(config: Any, keys_to_strip: set[str]) -> None: def apply_surgery(source_config: dict, surgery_config: dict | None) -> dict: - """Apply surgery specification to a complete source config. + """Apply surgery spec to complete config (the monoid action). + + This is the internal implementation of the monoid action: surgery specs + acting on complete configs. Called by compose_configs when base is complete + and overlay is partial. - This handles: - - Top-level scalar overrides - - Decoder composition (fixed vs pattern) - - Stochastic mixer sub-mixer inheritance - - Cross-type derivation (attention → gdn, attention → mamba) + Implements inheritance semantics: + - Unspecified fields inherit from source + - Cross-type derivation maps geometry (attention → gdn, etc.) + - Stochastic sub-mixers inherit from source's main mixer + - All `init` fields stripped from result Args: - source_config: Complete Apriel2 config. - surgery_config: Partial surgery specification. + source_config: Complete Apriel2 config (the state being acted on). + surgery_config: Partial surgery spec (the monoid element acting). Returns: - Complete Apriel2 config with surgery applied. + Complete config with surgery applied, `init` fields stripped. """ if not surgery_config: return copy.deepcopy(source_config) @@ -392,6 +434,12 @@ def _compose_single_mixer(source: dict, surgery: dict, hidden_size: int) -> dict result[key] = surgery[key] elif key in source: result[key] = source[key] + # Copy per-layer bias settings (query_layer, key_layer, value_layer, dense_layer) + for key in ["query_layer", "key_layer", "value_layer", "dense_layer", "add_linear_biases"]: + if key in surgery: + result[key] = surgery[key] + elif key in source: + result[key] = copy.deepcopy(source[key]) # Preserve init if "init" in surgery: result["init"] = surgery["init"] diff --git a/fast_llm_external_models/apriel2/conversion/converters.py b/fast_llm_external_models/apriel2/conversion/converters.py index 6d1350c54..b54bb5a87 100644 --- a/fast_llm_external_models/apriel2/conversion/converters.py +++ b/fast_llm_external_models/apriel2/conversion/converters.py @@ -79,6 +79,21 @@ # This is the single source of truth for each mixer's weight schema. +def _get_attention_bias_enabled(config: dict, layer_name: str) -> bool: + """Get whether bias is enabled for an attention layer. + + Checks per-layer bias config (e.g., query_layer.bias.enabled). + Falls back to add_linear_biases if not set. 
+ """ + layer_cfg = config.get(layer_name, {}) + bias_cfg = layer_cfg.get("bias", {}) + enabled = bias_cfg.get("enabled") + if enabled is not None: + return enabled + # Fall back to add_linear_biases + return config.get("add_linear_biases", False) + + def _plan_attention_mixer( *, prefix: W, @@ -90,9 +105,13 @@ def _plan_attention_mixer( Weight schema: - q_proj.weight: (q_size, hidden_size) + - q_proj.bias: (q_size,) [if query_layer.bias.enabled] - k_proj.weight: (kv_size, hidden_size) + - k_proj.bias: (kv_size,) [if key_layer.bias.enabled] - v_proj.weight: (kv_size, hidden_size) + - v_proj.bias: (kv_size,) [if value_layer.bias.enabled] - o_proj.weight: (hidden_size, q_size) + - o_proj.bias: (hidden_size,) [if dense_layer.bias.enabled] Args: prefix: Target weight path prefix. @@ -100,12 +119,28 @@ def _plan_attention_mixer( hidden_size: Model hidden size. source_prefix: If provided, passthrough from source. If None, random init. """ + # Check per-layer bias configuration + q_bias = _get_attention_bias_enabled(config, "query_layer") + k_bias = _get_attention_bias_enabled(config, "key_layer") + v_bias = _get_attention_bias_enabled(config, "value_layer") + o_bias = _get_attention_bias_enabled(config, "dense_layer") + if source_prefix is not None: - # Passthrough - return ExprPlan(mappings={ + # Passthrough weights + mappings: dict[W, Expr] = { prefix / proj / "weight": Ref(key=source_prefix / proj / "weight") for proj in ["q_proj", "k_proj", "v_proj", "o_proj"] - }) + } + # Passthrough biases if enabled + if q_bias: + mappings[prefix / "q_proj" / "bias"] = Ref(key=source_prefix / "q_proj" / "bias") + if k_bias: + mappings[prefix / "k_proj" / "bias"] = Ref(key=source_prefix / "k_proj" / "bias") + if v_bias: + mappings[prefix / "v_proj" / "bias"] = Ref(key=source_prefix / "v_proj" / "bias") + if o_bias: + mappings[prefix / "o_proj" / "bias"] = Ref(key=source_prefix / "o_proj" / "bias") + return ExprPlan(mappings=mappings) # Random init heads = config["heads"] @@ -114,12 +149,22 @@ def _plan_attention_mixer( q_size = heads * head_size kv_size = head_groups * head_size - return ExprPlan(mappings={ + mappings = { prefix / "q_proj" / "weight": Init(shape=(q_size, hidden_size), init_type="kaiming"), prefix / "k_proj" / "weight": Init(shape=(kv_size, hidden_size), init_type="kaiming"), prefix / "v_proj" / "weight": Init(shape=(kv_size, hidden_size), init_type="kaiming"), prefix / "o_proj" / "weight": Init(shape=(hidden_size, q_size), init_type="kaiming"), - }) + } + # Random init biases if enabled + if q_bias: + mappings[prefix / "q_proj" / "bias"] = Init(shape=(q_size,), init_type="zeros") + if k_bias: + mappings[prefix / "k_proj" / "bias"] = Init(shape=(kv_size,), init_type="zeros") + if v_bias: + mappings[prefix / "v_proj" / "bias"] = Init(shape=(kv_size,), init_type="zeros") + if o_bias: + mappings[prefix / "o_proj" / "bias"] = Init(shape=(hidden_size,), init_type="zeros") + return ExprPlan(mappings=mappings) def _plan_mamba_mixer( @@ -786,7 +831,45 @@ def plan_surgery( source_config: dict, target_config: dict, ) -> ExprPlan: - """Build plan for Apriel2→Apriel2 surgery (MIL, DIL, KIL, stochastic mixers, etc.).""" + """Build a weight conversion plan between two Apriel2 configurations. + + This function creates an ExprPlan that maps source weight keys to expressions + defining how to compute target weights. The plan handles same-type passthrough, + cross-type conversions (MIL, DIL, KIL), and stochastic mixer routing. 
+ + Args: + source_config: Complete Apriel2 config dict describing the source architecture. + Must have all structural fields (hidden_size, decoder, etc.) fully specified. + target_config: Complete Apriel2 config dict describing the target architecture. + Must be fully specified with all inherited fields resolved. Use + compose_configs(source_config, surgery_spec) to produce this from a + partial surgery specification. + + Returns: + ExprPlan mapping target weight keys to expressions over source weights. + + Example: + # Apply a surgery that wraps attention in a stochastic mixer + surgery_spec = { + "decoder": {"block": {"mixer": { + "type": "stochastic", + "mixers": {"attention": {"type": "attention", "init": "transfer"}} + }}} + } + + # First compose to get complete target config with inherited fields + target_config = compose_configs(source_config, surgery_spec) + + # Then build the plan from two complete configs + plan = plan_surgery(source_config, target_config) + new_weights = execute(plan, source_weights, seed=0) + + Note: + Both arguments must be complete configs. The target_config determines the + full target architecture including all inherited fields (bias settings, + rotary config, etc.). Passing a partial surgery spec directly will result + in missing inherited fields and incorrect plans. + """ hidden_size = target_config.get("hidden_size", source_config.get("hidden_size")) assert hidden_size is not None, "hidden_size must be specified in source or target config" @@ -845,8 +928,8 @@ def _plan_non_decoder_weights(config: dict) -> ExprPlan: norm = W("model", "norm", "weight") mappings[norm] = Ref(key=norm) - if "vision_encoder" in config: - vision_config = config["vision_encoder"] + vision_config = config.get("vision_encoder") + if vision_config: vision = W("model", "vision_encoder") patch_emb = vision / "embeddings" / "patch_embeddings" / "weight" @@ -986,6 +1069,24 @@ def _plan_mixer( ) +def _get_mlp_bias_enabled(config: dict, layer_name: str) -> bool: + """Get whether bias is enabled for an MLP layer. + + Checks per-layer bias config (e.g., layer_1.bias.enabled, layer_2.bias.enabled). + Falls back to add_linear_biases if not set. + + Note: layer_1 corresponds to gate_proj and up_proj (gated MLP) or just up_proj (non-gated) + layer_2 corresponds to down_proj + """ + layer_cfg = config.get(layer_name, {}) + bias_cfg = layer_cfg.get("bias", {}) + enabled = bias_cfg.get("enabled") + if enabled is not None: + return enabled + # Fall back to add_linear_biases + return config.get("add_linear_biases", False) + + def _plan_mlp( target_layer_idx: int, source_layer_idx: int, @@ -1006,7 +1107,7 @@ def _plan_mlp_transfer( target_mlp: dict, hidden_size: int, ) -> ExprPlan: - """Passthrough for MLP weights.""" + """Passthrough for MLP weights and biases.""" source_mlp_path = W("model", "decoder", "blocks", source_layer_idx, "mlp") target_mlp_path = W("model", "decoder", "blocks", target_layer_idx, "mlp") @@ -1019,10 +1120,37 @@ def _plan_mlp_transfer( f"Use 'init: random' to initialize randomly." 
) - return ExprPlan(mappings={ + # Check per-layer bias configuration + layer_1_bias = _get_mlp_bias_enabled(target_mlp, "layer_1") + layer_2_bias = _get_mlp_bias_enabled(target_mlp, "layer_2") + + # Check if gated MLP (has gate_proj) or non-gated (only up_proj) + gated = target_mlp.get("gated", True) # Default to gated for backwards compatibility + + # Passthrough weights + # layer_1 = gate_proj + up_proj (gated) or just up_proj (non-gated) + # layer_2 = down_proj + if gated: + weight_projs = ["gate_proj", "up_proj", "down_proj"] + else: + weight_projs = ["up_proj", "down_proj"] + + mappings: dict[W, Expr] = { target_mlp_path / proj / "weight": Ref(key=source_mlp_path / proj / "weight") - for proj in ["gate_proj", "up_proj", "down_proj"] - }) + for proj in weight_projs + } + + # Passthrough biases if enabled + if layer_1_bias: + if gated: + mappings[target_mlp_path / "gate_proj" / "bias"] = Ref(key=source_mlp_path / "gate_proj" / "bias") + mappings[target_mlp_path / "up_proj" / "bias"] = Ref(key=source_mlp_path / "up_proj" / "bias") + + # layer_2 = down_proj + if layer_2_bias: + mappings[target_mlp_path / "down_proj" / "bias"] = Ref(key=source_mlp_path / "down_proj" / "bias") + + return ExprPlan(mappings=mappings) def _plan_random_mlp( @@ -1030,20 +1158,41 @@ def _plan_random_mlp( target_mlp: dict, hidden_size: int, ) -> ExprPlan: - """Random initialization for MLP.""" + """Random initialization for MLP weights and biases.""" target_mlp_path = W("model", "decoder", "blocks", target_layer_idx, "mlp") intermediate_size = target_mlp["intermediate_size"] - return ExprPlan(mappings={ - target_mlp_path / "gate_proj" / "weight": Init( - shape=(intermediate_size, hidden_size), init_type="kaiming" - ), - target_mlp_path / "up_proj" / "weight": Init( + + # Check per-layer bias configuration + layer_1_bias = _get_mlp_bias_enabled(target_mlp, "layer_1") + layer_2_bias = _get_mlp_bias_enabled(target_mlp, "layer_2") + + # Check if gated MLP (has gate_proj) or non-gated (only up_proj) + gated = target_mlp.get("gated", True) # Default to gated for backwards compatibility + + # Random init weights + mappings: dict[W, Expr] = {} + if gated: + mappings[target_mlp_path / "gate_proj" / "weight"] = Init( shape=(intermediate_size, hidden_size), init_type="kaiming" - ), - target_mlp_path / "down_proj" / "weight": Init( - shape=(hidden_size, intermediate_size), init_type="kaiming" - ), - }) + ) + mappings[target_mlp_path / "up_proj" / "weight"] = Init( + shape=(intermediate_size, hidden_size), init_type="kaiming" + ) + mappings[target_mlp_path / "down_proj" / "weight"] = Init( + shape=(hidden_size, intermediate_size), init_type="kaiming" + ) + + # Random init biases if enabled + if layer_1_bias: + if gated: + mappings[target_mlp_path / "gate_proj" / "bias"] = Init(shape=(intermediate_size,), init_type="zeros") + mappings[target_mlp_path / "up_proj" / "bias"] = Init(shape=(intermediate_size,), init_type="zeros") + + # layer_2 = down_proj + if layer_2_bias: + mappings[target_mlp_path / "down_proj" / "bias"] = Init(shape=(hidden_size,), init_type="zeros") + + return ExprPlan(mappings=mappings) def _plan_norms( diff --git a/fast_llm_external_models/apriel2/conversion/qwen2/config.py b/fast_llm_external_models/apriel2/conversion/qwen2/config.py index 36df744c0..70629fe0e 100644 --- a/fast_llm_external_models/apriel2/conversion/qwen2/config.py +++ b/fast_llm_external_models/apriel2/conversion/qwen2/config.py @@ -23,11 +23,7 @@ def convert_config(qwen2_config: dict) -> dict: num_key_value_heads = 
qwen2_config.get("num_key_value_heads", num_attention_heads) head_dim = hidden_size // num_attention_heads - # Qwen2 uses QKV bias but not O bias - # The add_linear_biases in Apriel2 attention config controls all biases uniformly, - # but we can set it to True and the o_proj bias will just be missing from weights - # (handled by strict=False loading or explicit handling in the plan) - + # Qwen2 uses QKV bias but not O bias - mirror Fast-LLM's per-layer config return { "model_type": "apriel2_text", "architectures": ["Apriel2ForCausalLM"], @@ -48,9 +44,11 @@ def convert_config(qwen2_config: dict) -> dict: "heads": num_attention_heads, "head_groups": num_key_value_heads, "head_size": head_dim, - # Qwen2 has QKV bias but not O bias - # We set True and handle O bias separately - "add_linear_biases": True, + # Per-layer bias config matching Fast-LLM structure + "query_layer": {"bias": {"enabled": True}}, + "key_layer": {"bias": {"enabled": True}}, + "value_layer": {"bias": {"enabled": True}}, + "dense_layer": {"bias": {"enabled": False}}, "rotary": { "type": "mistral_1d", "theta": qwen2_config.get("rope_theta", 1000000.0), diff --git a/fast_llm_external_models/apriel2/conversion/qwen2/plan.py b/fast_llm_external_models/apriel2/conversion/qwen2/plan.py index e5ae3e9d8..7752d37c9 100644 --- a/fast_llm_external_models/apriel2/conversion/qwen2/plan.py +++ b/fast_llm_external_models/apriel2/conversion/qwen2/plan.py @@ -3,7 +3,6 @@ from fast_llm_external_models.apriel2.conversion.expr import ( Expr, ExprPlan, - Init, Ref, W, ) @@ -23,15 +22,19 @@ def plan_qwen2_to_apriel2(qwen2_config: dict) -> ExprPlan: model.layers.{i}.input_layernorm.weight -> model.decoder.blocks.{i}.input_layernorm.weight model.layers.{i}.post_attention_layernorm.weight -> model.decoder.blocks.{i}.post_attention_layernorm.weight model.layers.{i}.self_attn.q_proj.weight -> model.decoder.blocks.{i}.mixer.q_proj.weight + model.layers.{i}.self_attn.q_proj.bias -> model.decoder.blocks.{i}.mixer.q_proj.bias model.layers.{i}.self_attn.k_proj.weight -> model.decoder.blocks.{i}.mixer.k_proj.weight + model.layers.{i}.self_attn.k_proj.bias -> model.decoder.blocks.{i}.mixer.k_proj.bias model.layers.{i}.self_attn.v_proj.weight -> model.decoder.blocks.{i}.mixer.v_proj.weight + model.layers.{i}.self_attn.v_proj.bias -> model.decoder.blocks.{i}.mixer.v_proj.bias model.layers.{i}.self_attn.o_proj.weight -> model.decoder.blocks.{i}.mixer.o_proj.weight model.layers.{i}.mlp.gate_proj.weight -> model.decoder.blocks.{i}.mlp.gate_proj.weight model.layers.{i}.mlp.up_proj.weight -> model.decoder.blocks.{i}.mlp.up_proj.weight model.layers.{i}.mlp.down_proj.weight -> model.decoder.blocks.{i}.mlp.down_proj.weight - Note: Qwen2 has QKV biases but no O bias. We skip the biases in the conversion - since Apriel2 is configured with add_linear_biases=False for uniform handling. + Note: Qwen2 has QKV biases but no O bias. The Apriel2 config uses per-layer + bias settings (query_layer.bias.enabled=True, dense_layer.bias.enabled=False) + to match this exactly - no workarounds needed. Args: qwen2_config: HuggingFace Qwen2Config as dict @@ -42,7 +45,6 @@ def plan_qwen2_to_apriel2(qwen2_config: dict) -> ExprPlan: mappings: dict[str, Expr] = {} num_layers = qwen2_config["num_hidden_layers"] - hidden_size = qwen2_config["hidden_size"] # Static mappings (embeddings and final norm) # Note: Qwen2 safetensor keys have "model." 
prefix @@ -66,8 +68,7 @@ def plan_qwen2_to_apriel2(qwen2_config: dict) -> ExprPlan: qwen_layer = W("model", "layers", layer) apriel_layer = W("model", "decoder", "blocks", layer) - # Attention projections (weights and biases) - # Qwen2 has QKV bias but no O bias + # Attention projection weights for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]: src = qwen_layer / "self_attn" / proj / "weight" tgt = apriel_layer / "mixer" / proj / "weight" @@ -79,12 +80,7 @@ def plan_qwen2_to_apriel2(qwen2_config: dict) -> ExprPlan: tgt = apriel_layer / "mixer" / proj / "bias" mappings[tgt] = Ref(key=src) - # O bias - Qwen2 doesn't have this, so initialize to zeros - # Shape is hidden_size (d_model) - mappings[apriel_layer / "mixer" / "o_proj" / "bias"] = Init( - shape=(hidden_size,), - init_type="zeros", - ) + # Note: o_proj has no bias in Qwen2, and Apriel2 config has dense_layer.bias.enabled=False # MLP projections for proj in ["gate_proj", "up_proj", "down_proj"]: diff --git a/fast_llm_external_models/apriel2/modeling_apriel2.py b/fast_llm_external_models/apriel2/modeling_apriel2.py index 4c263b4e2..878677653 100644 --- a/fast_llm_external_models/apriel2/modeling_apriel2.py +++ b/fast_llm_external_models/apriel2/modeling_apriel2.py @@ -24,8 +24,8 @@ is_torch_flex_attn_available, ) -from fast_llm_external_models.apriel2.cache import Apriel2Cache -from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config, Apriel2TextConfig +from .cache import Apriel2Cache +from .configuration_apriel2 import Apriel2Config, Apriel2TextConfig # GDN implementation - matches Fast-LLM's gdn.py exactly try: @@ -395,14 +395,30 @@ def __init__(self, d_model: int, mixer_config: dict, layer_idx: int, config): # cross_document_attention: if False, use cu_seqlens to isolate sequences (e.g., images) self.cross_document_attention = mixer_config.get("cross_document_attention", True) - # Whether to add biases to linear projections - add_bias = mixer_config.get("add_linear_biases", False) - - # Projections (Fast-LLM weight names: q_proj, k_proj, v_proj, o_proj) - self.q_proj = nn.Linear(d_model, self.num_heads * self.head_dim, bias=add_bias) - self.k_proj = nn.Linear(d_model, self.num_key_value_heads * self.head_dim, bias=add_bias) - self.v_proj = nn.Linear(d_model, self.num_key_value_heads * self.head_dim, bias=add_bias) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, d_model, bias=add_bias) + # Bias configuration mirroring Fast-LLM's structure: + # - add_linear_biases: bool (default for all projections) + # - query_layer: {"bias": {"enabled": bool}} (per-layer override) + # - key_layer: {"bias": {"enabled": bool}} + # - value_layer: {"bias": {"enabled": bool}} + # - dense_layer: {"bias": {"enabled": bool}} + default_bias = mixer_config.get("add_linear_biases", False) + + def get_layer_bias(layer_name: str) -> bool: + layer_cfg = mixer_config.get(layer_name, {}) + bias_cfg = layer_cfg.get("bias", {}) + enabled = bias_cfg.get("enabled") + return default_bias if enabled is None else enabled + + q_bias = get_layer_bias("query_layer") + k_bias = get_layer_bias("key_layer") + v_bias = get_layer_bias("value_layer") + o_bias = get_layer_bias("dense_layer") + + # Projections + self.q_proj = nn.Linear(d_model, self.num_heads * self.head_dim, bias=q_bias) + self.k_proj = nn.Linear(d_model, self.num_key_value_heads * self.head_dim, bias=k_bias) + self.v_proj = nn.Linear(d_model, self.num_key_value_heads * self.head_dim, bias=v_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, d_model, bias=o_bias) 
@classmethod def setup( @@ -1828,16 +1844,36 @@ def __init__( self.post_attention_layernorm = self._create_norm(norm_config, hidden_size, rms_norm_eps) def _create_mlp(self, mlp_config: dict, hidden_size: int): - """Create MLP based on config.""" + """Create MLP based on config. + + Supports per-layer bias configuration mirroring Fast-LLM: + - add_linear_biases: default bias setting for all layers + - layer_1.bias.enabled: override for up_proj/gate_proj + - layer_2.bias.enabled: override for down_proj + """ mlp_type = mlp_config.get("type", "mlp") if mlp_type == "mlp": intermediate_size = mlp_config["intermediate_size"] activation = mlp_config.get("activation", "silu") - gated = mlp_config["gated"] - bias = mlp_config.get("add_linear_biases", False) + gated = mlp_config.get("gated", False) + + # Per-layer bias configuration (mirrors Fast-LLM structure) + default_bias = mlp_config.get("add_linear_biases", False) + + def get_layer_bias(layer_name: str) -> bool: + layer_cfg = mlp_config.get(layer_name, {}) + bias_cfg = layer_cfg.get("bias", {}) + enabled = bias_cfg.get("enabled") + return default_bias if enabled is None else enabled + + layer_1_bias = get_layer_bias("layer_1") + layer_2_bias = get_layer_bias("layer_2") if gated: + # MistralMLP uses gate_proj, up_proj, down_proj (all bias controlled together) + # For now, we use the default bias setting for gated MLPs + # TODO: Add per-layer bias support to gated MLP mlp_cfg = SimpleNamespace( hidden_size=hidden_size, intermediate_size=intermediate_size, @@ -1845,7 +1881,13 @@ def _create_mlp(self, mlp_config: dict, hidden_size: int): ) return MistralMLP(mlp_cfg) else: - return SimpleMLP(hidden_size, intermediate_size, activation, bias) + return SimpleMLP( + hidden_size, + intermediate_size, + activation, + layer_1_bias=layer_1_bias, + layer_2_bias=layer_2_bias, + ) else: raise ValueError(f"Unknown MLP type: {mlp_type}") @@ -2179,6 +2221,8 @@ def forward( class Apriel2ForCausalLM(Apriel2PreTrainedModel, GenerationMixin): """Apriel2 model with a language modeling head (text-only).""" + _tied_weights_keys = ["lm_head.weight"] + def __init__(self, config: Apriel2TextConfig): super().__init__(config) self.model = Apriel2TextModel(config) @@ -2186,6 +2230,7 @@ def __init__(self, config: Apriel2TextConfig): self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing + # post_init() calls init_weights() which calls tie_weights() if config.tie_word_embeddings self.post_init() def get_input_embeddings(self): @@ -2583,14 +2628,26 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: class SimpleMLP(nn.Module): - """Non-gated MLP: up_proj -> activation -> down_proj.""" + """Non-gated MLP: up_proj -> activation -> down_proj. 
- def __init__(self, hidden_size: int, intermediate_size: int, activation: str = "silu", bias: bool = False): + Supports per-layer bias configuration mirroring Fast-LLM: + - layer_1_bias: bias for up_proj (layer_1 in Fast-LLM naming) + - layer_2_bias: bias for down_proj (layer_2 in Fast-LLM naming) + """ + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + activation: str = "silu", + layer_1_bias: bool = False, + layer_2_bias: bool = False, + ): super().__init__() from transformers.activations import ACT2FN - self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=bias) - self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=bias) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=layer_1_bias) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=layer_2_bias) self.act_fn = ACT2FN[activation] def forward(self, x): diff --git a/fast_llm_external_models/tests/test_apriel2/conftest.py b/fast_llm_external_models/tests/test_apriel2/conftest.py index 5c127d97e..cf190b50a 100644 --- a/fast_llm_external_models/tests/test_apriel2/conftest.py +++ b/fast_llm_external_models/tests/test_apriel2/conftest.py @@ -10,16 +10,40 @@ from fast_llm_external_models.apriel2.cache import _AttentionCache, _SSMCache +# Register custom marks +def pytest_configure(config): + config.addinivalue_line("markers", "slow: marks tests as slow (deselect with '-m \"not slow\"')") + + +def _can_import_fast_llm(): + """Check if Fast-LLM is available.""" + try: + from fast_llm.engine.checkpoint.convert import ConvertConfig + return True + except ImportError: + return False + + # Skip marker for tests that require CUDA for Mamba forward pass requires_cuda = pytest.mark.skipif( not torch.cuda.is_available(), reason="SSM mixers (Mamba) require CUDA for forward pass" ) +# Skip marker for tests that require Fast-LLM +requires_fastllm = pytest.mark.skipif( + not _can_import_fast_llm(), + reason="Fast-LLM not available" +) + -@pytest.fixture(autouse=True) +@pytest.fixture(scope="module", autouse=True) def set_default_device(): - """Set default device to CUDA for all tests (Mamba requires CUDA).""" + """Set default device to CUDA for all tests (Mamba requires CUDA). + + Module-scoped to ensure it runs before any module-scoped fixtures + that load models (e.g., qwen2_model_and_tokenizer). + """ if torch.cuda.is_available(): old_device = torch.get_default_device() torch.set_default_device("cuda") @@ -29,9 +53,12 @@ def set_default_device(): yield -@pytest.fixture(autouse=True) +@pytest.fixture(scope="module", autouse=True) def set_default_dtype(): - """Set default dtype to float32 for numerical comparison tests.""" + """Set default dtype to float32 for numerical comparison tests. + + Module-scoped to ensure it runs before any module-scoped fixtures. + """ old_dtype = torch.get_default_dtype() torch.set_default_dtype(torch.float32) yield @@ -763,6 +790,52 @@ def apriel2_config_comprehensive(): ) +@pytest.fixture +def apriel2_config_with_bias(): + """Apriel2 config with Qwen-style per-layer bias and non-gated MLP. + + This config exercises: + - Per-layer attention bias (QKV bias enabled, O bias disabled) + - Non-gated MLP with per-layer bias (layer_1 enabled, layer_2 disabled) + - Config structure parity with Fast-LLM's AffineLinearConfig + + Critical for testing bias inheritance through surgery operations. 
+ """ + from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config + + return Apriel2Config( + vocab_size=100, + hidden_size=64, + decoder={ + "type": "fixed", + "num_blocks": 2, + "block": { + "mixer": { + "type": "attention", + "heads": 4, + "head_groups": 2, + "head_size": 16, + "rotary": {"type": "mistral_1d", "theta": 10000.0}, + # Qwen-style: QKV bias enabled, O bias disabled + "query_layer": {"bias": {"enabled": True}}, + "key_layer": {"bias": {"enabled": True}}, + "value_layer": {"bias": {"enabled": True}}, + "dense_layer": {"bias": {"enabled": False}}, + }, + "mlp": { + "type": "mlp", + "intermediate_size": 256, + "gated": False, # Non-gated MLP (SimpleMLP) + # Per-layer MLP bias + "layer_1": {"bias": {"enabled": True}}, + "layer_2": {"bias": {"enabled": False}}, + }, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + ) + + @pytest.fixture def apriel2_cache(apriel2_config_tiny): """Create empty Apriel2Cache from tiny config.""" @@ -865,6 +938,77 @@ def additive_surgery_chain(): ] +@pytest.fixture +def bias_surgery_chain(): + """Surgery chain that exercises bias inheritance through surgery operations. + + Designed to be used with apriel2_config_with_bias as the source config. + Tests that per-layer bias settings (Qwen-style QKV bias, non-gated MLP bias) + are correctly inherited through: + - Stochastic wrapper creation + - Adding new sub-mixers that inherit from source + - Cross-type derivation (attention → sliding_window) + + Source config has: + - Attention: query/key/value bias enabled, dense bias disabled + - MLP: layer_1 bias enabled, layer_2 bias disabled (non-gated) + """ + return [ + # S1: Wrap in stochastic - bias should transfer to attention sub-mixer + { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + }, + }, + "mlp": {"init": "transfer"}, + "normalization": {"init": "transfer"}, + }, + }, + }, + # S2: Add sliding_window - should inherit bias from source attention + { + "decoder": { + "block": { + "mixer": { + "mixers": { + "sliding_window": { + "type": "attention", + "init": "transfer", + "window_size": 512, + }, + }, + }, + }, + }, + }, + # S3: Add new attention with DIFFERENT bias config (random init) + { + "decoder": { + "block": { + "mixer": { + "mixers": { + "full_bias_attn": { + "type": "attention", + "init": "random", + "heads": 4, + "head_groups": 2, + "head_size": 16, + "rotary": {"type": "mistral_1d", "theta": 10000.0}, + "add_linear_biases": True, # All biases enabled + }, + }, + }, + }, + }, + }, + ] + + @pytest.fixture def comprehensive_torture_chain(): """Comprehensive torture chain exercising ALL conversion paths. diff --git a/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py b/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py index 0bd6ac88d..4380b1fbd 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py +++ b/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py @@ -288,6 +288,163 @@ def test_init_random_still_inherits_config(self, source_config): assert mixer["window_size"] == 512 +class TestBiasConfigInheritance: + """Test per-layer bias inheritance through surgery composition. + + These tests verify that the per-layer bias configuration (mirroring Fast-LLM's + AffineLinearConfig) is correctly inherited through surgery operations: + - query_layer.bias.enabled, key_layer.bias.enabled, etc. 
for attention + - layer_1.bias.enabled, layer_2.bias.enabled for MLP + """ + + @pytest.fixture + def source_config_with_bias(self): + """Source config with Qwen-style bias (QKV enabled, O disabled).""" + return { + "model_type": "apriel2", + "architectures": ["Apriel2ForCausalLM"], + "hidden_size": 256, + "vocab_size": 1000, + "decoder": { + "type": "fixed", + "num_blocks": 4, + "block": { + "mixer": { + "type": "attention", + "heads": 8, + "head_groups": 4, + "head_size": 32, + "rotary": {"type": "mistral_1d", "theta": 10000.0}, + # Qwen-style per-layer bias + "query_layer": {"bias": {"enabled": True}}, + "key_layer": {"bias": {"enabled": True}}, + "value_layer": {"bias": {"enabled": True}}, + "dense_layer": {"bias": {"enabled": False}}, + }, + "mlp": { + "type": "mlp", + "intermediate_size": 512, + "gated": False, + # Per-layer MLP bias + "layer_1": {"bias": {"enabled": True}}, + "layer_2": {"bias": {"enabled": False}}, + }, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + } + + def test_same_type_inherits_attention_bias(self, source_config_with_bias): + """Same-type surgery inherits per-layer attention bias settings.""" + surgery = { + "decoder": { + "block": { + "mixer": { + "window_size": 512, # Add sliding window behavior + }, + }, + }, + } + result = compose_configs(source_config_with_bias, surgery) + + mixer = result["decoder"]["block"]["mixer"] + assert mixer["query_layer"]["bias"]["enabled"] is True + assert mixer["key_layer"]["bias"]["enabled"] is True + assert mixer["value_layer"]["bias"]["enabled"] is True + assert mixer["dense_layer"]["bias"]["enabled"] is False + + def test_same_type_inherits_mlp_bias(self, source_config_with_bias): + """Same-type surgery inherits per-layer MLP bias settings.""" + surgery = { + "decoder": { + "block": { + "mlp": { + "intermediate_size": 1024, # Change size + }, + }, + }, + } + result = compose_configs(source_config_with_bias, surgery) + + mlp = result["decoder"]["block"]["mlp"] + assert mlp["layer_1"]["bias"]["enabled"] is True + assert mlp["layer_2"]["bias"]["enabled"] is False + assert mlp["intermediate_size"] == 1024 + + def test_cross_type_attention_to_sliding_window_preserves_bias(self, source_config_with_bias): + """attention→sliding_window cross-type preserves per-layer bias.""" + surgery = { + "decoder": { + "block": { + "mixer": { + "type": "sliding_window", # Cross-type derivation + "window_size": 512, + }, + }, + }, + } + result = compose_configs(source_config_with_bias, surgery) + + mixer = result["decoder"]["block"]["mixer"] + assert mixer["type"] == "sliding_window" + # Bias settings preserved through cross-type + assert mixer["query_layer"]["bias"]["enabled"] is True + assert mixer["key_layer"]["bias"]["enabled"] is True + assert mixer["value_layer"]["bias"]["enabled"] is True + assert mixer["dense_layer"]["bias"]["enabled"] is False + + def test_stochastic_wrapper_inherits_bias(self, source_config_with_bias): + """Wrapping in stochastic inherits bias settings to all sub-mixers.""" + surgery = { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + "sliding_window": { + "type": "sliding_window", + "window_size": 512, + "init": "transfer", + }, + }, + }, + }, + }, + } + result = compose_configs(source_config_with_bias, surgery) + + mixers = result["decoder"]["block"]["mixer"]["mixers"] + + # Attention sub-mixer inherits bias + assert mixers["attention"]["query_layer"]["bias"]["enabled"] is True + assert 
mixers["attention"]["dense_layer"]["bias"]["enabled"] is False + + # Sliding window sub-mixer also inherits bias + assert mixers["sliding_window"]["query_layer"]["bias"]["enabled"] is True + assert mixers["sliding_window"]["dense_layer"]["bias"]["enabled"] is False + + def test_surgery_can_override_bias(self, source_config_with_bias): + """Surgery can explicitly override inherited bias settings.""" + surgery = { + "decoder": { + "block": { + "mixer": { + "dense_layer": {"bias": {"enabled": True}}, # Override O bias + }, + }, + }, + } + result = compose_configs(source_config_with_bias, surgery) + + mixer = result["decoder"]["block"]["mixer"] + # Q/K/V unchanged + assert mixer["query_layer"]["bias"]["enabled"] is True + # O bias overridden + assert mixer["dense_layer"]["bias"]["enabled"] is True + + class TestComposeConfigsRealYAML: """Test compose_configs with real YAML surgery files.""" diff --git a/fast_llm_external_models/tests/test_apriel2/test_expr_plan.py b/fast_llm_external_models/tests/test_apriel2/test_expr_plan.py index c487ab3a3..569ed88fd 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_expr_plan.py +++ b/fast_llm_external_models/tests/test_apriel2/test_expr_plan.py @@ -1711,3 +1711,205 @@ def test_conversion_plan_targets_match_model_state_dict(self, llava_pixtral_conf assert not missing_from_plan, f"Plan missing keys that model expects: {sorted(missing_from_plan)}" assert not extra_in_plan, f"Plan has extra keys model doesn't expect: {sorted(extra_in_plan)}" + + +class TestBiasPlanGeneration: + """Test that surgery plans correctly handle per-layer bias configurations. + + These tests verify that plan_surgery correctly includes/excludes bias + weight mappings based on the per-layer bias settings: + - query_layer.bias.enabled, key_layer.bias.enabled, etc. 
for attention + - layer_1.bias.enabled, layer_2.bias.enabled for MLP + """ + + @pytest.fixture + def source_config_with_bias(self): + """Source config with Qwen-style bias (QKV enabled, O disabled).""" + return { + "model_type": "apriel2", + "hidden_size": 256, + "vocab_size": 1000, + "decoder": { + "type": "fixed", + "num_blocks": 2, + "block": { + "mixer": { + "type": "attention", + "heads": 8, + "head_groups": 4, + "head_size": 32, + "rotary": {"type": "mistral_1d", "theta": 10000.0}, + # Qwen-style: QKV bias enabled, O bias disabled + "query_layer": {"bias": {"enabled": True}}, + "key_layer": {"bias": {"enabled": True}}, + "value_layer": {"bias": {"enabled": True}}, + "dense_layer": {"bias": {"enabled": False}}, + }, + "mlp": { + "type": "mlp", + "intermediate_size": 512, + "gated": False, + # Per-layer MLP bias: layer_1 enabled, layer_2 disabled + "layer_1": {"bias": {"enabled": True}}, + "layer_2": {"bias": {"enabled": False}}, + }, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + } + + def test_plan_includes_enabled_attention_biases(self, source_config_with_bias): + """Surgery plan includes bias mappings for enabled attention biases.""" + from fast_llm_external_models.apriel2.conversion.config import compose_configs + from fast_llm_external_models.apriel2.conversion.converters import plan_surgery + + target_config = compose_configs(source_config_with_bias, { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + }, + }, + "mlp": {"init": "transfer"}, + }, + }, + }) + + plan = plan_surgery(source_config_with_bias, target_config) + mapping_strs = [str(k) for k in plan.mappings.keys()] + + # Should have q_proj.bias, k_proj.bias, v_proj.bias mappings + q_bias = [m for m in mapping_strs if "q_proj.bias" in m] + k_bias = [m for m in mapping_strs if "k_proj.bias" in m] + v_bias = [m for m in mapping_strs if "v_proj.bias" in m] + + assert len(q_bias) > 0, "Should have q_proj.bias mappings" + assert len(k_bias) > 0, "Should have k_proj.bias mappings" + assert len(v_bias) > 0, "Should have v_proj.bias mappings" + + def test_plan_excludes_disabled_attention_biases(self, source_config_with_bias): + """Surgery plan excludes bias mappings for disabled attention biases.""" + from fast_llm_external_models.apriel2.conversion.config import compose_configs + from fast_llm_external_models.apriel2.conversion.converters import plan_surgery + + target_config = compose_configs(source_config_with_bias, { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + }, + }, + "mlp": {"init": "transfer"}, + }, + }, + }) + + plan = plan_surgery(source_config_with_bias, target_config) + mapping_strs = [str(k) for k in plan.mappings.keys()] + + # Should NOT have o_proj.bias mappings (disabled) + o_bias = [m for m in mapping_strs if "o_proj.bias" in m] + assert len(o_bias) == 0, f"Should not have o_proj.bias mappings, found: {o_bias}" + + def test_plan_includes_enabled_mlp_biases(self, source_config_with_bias): + """Surgery plan includes bias mappings for enabled MLP biases.""" + from fast_llm_external_models.apriel2.conversion.config import compose_configs + from fast_llm_external_models.apriel2.conversion.converters import plan_surgery + + target_config = compose_configs(source_config_with_bias, { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": 
{ + "attention": {"init": "transfer"}, + }, + }, + "mlp": {"init": "transfer"}, + }, + }, + }) + + plan = plan_surgery(source_config_with_bias, target_config) + mapping_strs = [str(k) for k in plan.mappings.keys()] + + # Should have up_proj.bias (layer_1) mappings + up_bias = [m for m in mapping_strs if "up_proj.bias" in m] + assert len(up_bias) > 0, "Should have up_proj.bias mappings" + + def test_plan_excludes_disabled_mlp_biases(self, source_config_with_bias): + """Surgery plan excludes bias mappings for disabled MLP biases.""" + from fast_llm_external_models.apriel2.conversion.config import compose_configs + from fast_llm_external_models.apriel2.conversion.converters import plan_surgery + + target_config = compose_configs(source_config_with_bias, { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + }, + }, + "mlp": {"init": "transfer"}, + }, + }, + }) + + plan = plan_surgery(source_config_with_bias, target_config) + mapping_strs = [str(k) for k in plan.mappings.keys()] + + # Should NOT have down_proj.bias (layer_2) mappings + down_bias = [m for m in mapping_strs if "down_proj.bias" in m] + assert len(down_bias) == 0, f"Should not have down_proj.bias mappings, found: {down_bias}" + + def test_plan_random_init_creates_init_expressions_for_bias(self, source_config_with_bias): + """Random init creates Init expressions for bias weights.""" + from fast_llm_external_models.apriel2.conversion.converters import plan_surgery + + # Surgery spec - pass directly to plan_surgery (NOT composed, to preserve init) + surgery = { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + "new_attention": { + "type": "attention", + "init": "random", # This triggers random init + "heads": 8, + "head_groups": 4, + "head_size": 32, + "rotary": {"type": "mistral_1d", "theta": 10000.0}, + "add_linear_biases": True, # All biases enabled + }, + }, + }, + }, + }, + } + + # Pass surgery spec directly - init fields are preserved + plan = plan_surgery(source_config_with_bias, surgery) + + # Check that new_attention biases use Init expressions + new_mixer_bias_keys = [ + k for k in plan.mappings.keys() + if "new_attention" in str(k) and "bias" in str(k) + ] + + assert len(new_mixer_bias_keys) > 0, "Should have bias mappings for new_attention" + + for key in new_mixer_bias_keys: + expr = plan.mappings[key] + assert isinstance(expr, Init), f"{key} should be Init, got {type(expr)}" diff --git a/fast_llm_external_models/tests/test_apriel2/test_integration.py b/fast_llm_external_models/tests/test_apriel2/test_integration.py new file mode 100644 index 000000000..c11302d22 --- /dev/null +++ b/fast_llm_external_models/tests/test_apriel2/test_integration.py @@ -0,0 +1,335 @@ +"""Integration tests for Qwen2 -> Apriel2 -> Fast-LLM conversion pipeline. + +Tests verify the full conversion chain: +1. Qwen2 -> Apriel2 (external module conversion) +2. Apriel2 + Surgery -> Supernet (stochastic mixer creation) +3. 
Supernet -> Fast-LLM -> Supernet (roundtrip through training format) + +Test Strategy: +- Use real HuggingFace model (Qwen2.5-0.5B) for meaningful validation +- Separate config preservation tests from numerical equivalence tests +- Parameterize both conversion stages AND input variations +- Single test implementation applied across all stages +""" + +import json +import tempfile +from pathlib import Path + +import pytest +import torch + +from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config +from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM +from fast_llm_external_models.apriel2.conversion import ( + compose, + compose_configs, + execute, + plan_surgery, +) +from fast_llm_external_models.apriel2.conversion.expr import W +from fast_llm_external_models.apriel2.conversion.qwen2.config import convert_config as convert_qwen2_config +from fast_llm_external_models.apriel2.conversion.qwen2.plan import plan_qwen2_to_apriel2 + +from .conftest import requires_fastllm + + +# ============================================================================= +# Test Input Variations +# ============================================================================= + +TEST_INPUTS = pytest.mark.parametrize( + "prompts,max_new_tokens", + [ + pytest.param(["Hello world"], 10, id="single_short"), + pytest.param(["Hi", "The quick brown fox jumps over the lazy dog"], 20, id="batch_varied"), + pytest.param(["Once upon a time"], 50, id="long_generation"), + ], +) + + +# ============================================================================= +# Conversion Fixtures +# ============================================================================= + + +@pytest.fixture(scope="module") +def qwen2_source(): + """Load Qwen2.5-0.5B as the source/reference model.""" + from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig + + model_name = "Qwen/Qwen2.5-0.5B" + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + model = AutoModelForCausalLM.from_pretrained( + model_name, torch_dtype=torch.float32, trust_remote_code=True + ) + config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + model.eval() + + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "left" + + return { + "model": model, + "tokenizer": tokenizer, + "config_dict": config.to_dict(), + "state_dict": model.state_dict(), + } + + +@pytest.fixture(scope="module") +def apriel2_converted(qwen2_source): + """Stage 1: Qwen2 -> Apriel2.""" + config_dict = convert_qwen2_config(qwen2_source["config_dict"]) + plan = plan_qwen2_to_apriel2(qwen2_source["config_dict"]) + weights = execute(plan, {W(k): v for k, v in qwen2_source["state_dict"].items()}, seed=42) + + config = Apriel2Config(**config_dict) + model = Apriel2ForCausalLM(config) + model.load_state_dict({str(k): v for k, v in weights.items()}, strict=False) + model.eval() + + return {"model": model, "config_dict": config_dict, "plan": plan, "name": "Apriel2"} + + +@pytest.fixture(scope="module") +def supernet_converted(qwen2_source, apriel2_converted): + """Stage 2: Apriel2 + Surgery -> Supernet.""" + surgery_spec = { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"type": "attention", "init": "transfer"}, + "sliding_window": { + "type": "attention", + "init": "transfer", + "window_size": 4096, + }, + }, + }, + }, + }, + } + + apriel_config = 
apriel2_converted["config_dict"] + supernet_config = compose_configs(apriel_config, surgery_spec) + + full_plan = compose( + apriel2_converted["plan"], + plan_surgery(apriel_config, supernet_config), + ) + + weights = execute(full_plan, {W(k): v for k, v in qwen2_source["state_dict"].items()}, seed=42) + + config = Apriel2Config(**supernet_config) + model = Apriel2ForCausalLM(config) + model.load_state_dict({str(k): v for k, v in weights.items()}, strict=False) + model.eval() + + return {"model": model, "config_dict": supernet_config, "name": "Supernet"} + + +@pytest.fixture(scope="module") +def roundtrip_converted(supernet_converted, qwen2_source): + """Stage 3: Supernet -> Fast-LLM -> Supernet.""" + from fast_llm.engine.checkpoint.config import ( + CheckpointLoadConfig, + CheckpointSaveConfig, + FastLLMCheckpointFormat, + ) + from fast_llm.engine.checkpoint.convert import ConvertConfig + from fast_llm.models.gpt.config import GPTModelConfig + from fast_llm.models.gpt.conversion.config import Apriel2TextCheckpointFormat + + with tempfile.TemporaryDirectory() as tmpdir: + tmpdir = Path(tmpdir) + supernet_path = tmpdir / "supernet" + fastllm_path = tmpdir / "fastllm" + roundtrip_path = tmpdir / "roundtrip" + + supernet_converted["model"].save_pretrained(supernet_path) + qwen2_source["tokenizer"].save_pretrained(supernet_path) + + ConvertConfig( + model=GPTModelConfig, + input=CheckpointLoadConfig(path=supernet_path, format=Apriel2TextCheckpointFormat), + output=CheckpointSaveConfig(path=fastllm_path, format=FastLLMCheckpointFormat), + ).run() + + ConvertConfig( + model=GPTModelConfig, + input=CheckpointLoadConfig(path=fastllm_path, format=FastLLMCheckpointFormat), + output=CheckpointSaveConfig(path=roundtrip_path, format=Apriel2TextCheckpointFormat), + ).run() + + model = Apriel2ForCausalLM.from_pretrained(roundtrip_path) + model.eval() + + with open(roundtrip_path / "config.json") as f: + config_dict = json.load(f) + + yield {"model": model, "config_dict": config_dict, "name": "Roundtrip"} + + +# ============================================================================= +# Parameterized Fixture: All Conversion Stages +# ============================================================================= + + +@pytest.fixture(params=["apriel2", "supernet", "roundtrip"]) +def converted_model(request, apriel2_converted, supernet_converted, roundtrip_converted): + """Parameterized fixture providing each conversion stage for testing. + + This allows a single test to run against all stages automatically. 
+ """ + if request.param == "roundtrip": + pytest.importorskip("fast_llm") + + return { + "apriel2": apriel2_converted, + "supernet": supernet_converted, + "roundtrip": roundtrip_converted, + }[request.param] + + +# ============================================================================= +# Config Preservation Tests +# ============================================================================= + + +@pytest.mark.slow +class TestConfigPreservation: + """Verify configs are correctly preserved through the conversion chain.""" + + def test_apriel2_structure(self, qwen2_source, apriel2_converted): + """Qwen2 -> Apriel2 preserves model dimensions.""" + qwen = qwen2_source["config_dict"] + apriel = apriel2_converted["config_dict"] + + assert apriel["hidden_size"] == qwen["hidden_size"] + assert apriel["vocab_size"] == qwen["vocab_size"] + assert apriel["decoder"]["num_blocks"] == qwen["num_hidden_layers"] + + def test_apriel2_bias_pattern(self, apriel2_converted): + """Qwen2 -> Apriel2 preserves Qwen-style bias (QKV yes, O no).""" + mixer = apriel2_converted["config_dict"]["decoder"]["block"]["mixer"] + + assert mixer["query_layer"]["bias"]["enabled"] is True + assert mixer["key_layer"]["bias"]["enabled"] is True + assert mixer["value_layer"]["bias"]["enabled"] is True + assert mixer["dense_layer"]["bias"]["enabled"] is False + + def test_supernet_structure(self, supernet_converted): + """Surgery creates correct stochastic mixer structure.""" + mixer = supernet_converted["config_dict"]["decoder"]["block"]["mixer"] + + assert mixer["type"] == "stochastic" + assert mixer["main_mixer_name"] == "attention" + assert set(mixer["mixers"].keys()) == {"attention", "sliding_window"} + + def test_supernet_bias_inheritance(self, supernet_converted): + """Submixers inherit bias settings from source.""" + mixer = supernet_converted["config_dict"]["decoder"]["block"]["mixer"] + + for name in ["attention", "sliding_window"]: + assert mixer["mixers"][name]["query_layer"]["bias"]["enabled"] is True + assert mixer["mixers"][name]["dense_layer"]["bias"]["enabled"] is False + + @requires_fastllm + def test_roundtrip_structure(self, roundtrip_converted): + """Fast-LLM roundtrip preserves stochastic mixer structure.""" + mixer = roundtrip_converted["config_dict"]["decoder"]["block"]["mixer"] + + assert mixer["type"] == "stochastic" + assert mixer["main_mixer_name"] == "attention" + assert set(mixer["mixers"].keys()) == {"attention", "sliding_window"} + + @requires_fastllm + def test_roundtrip_bias_preservation(self, roundtrip_converted): + """Fast-LLM roundtrip preserves per-layer bias settings.""" + mixer = roundtrip_converted["config_dict"]["decoder"]["block"]["mixer"] + + for name in ["attention", "sliding_window"]: + assert mixer["mixers"][name]["query_layer"]["bias"]["enabled"] is True + assert mixer["mixers"][name]["dense_layer"]["bias"]["enabled"] is False + + +# ============================================================================= +# Numerical Equivalence Tests +# ============================================================================= + + +@pytest.mark.slow +class TestNumericalEquivalence: + """Verify all conversion stages produce numerically identical outputs. + + Uses parameterized fixtures to test all stages with all input variations, + giving us 3 stages × 3 inputs = 9 test cases from a single test function. 
+ """ + + @TEST_INPUTS + def test_logits_match(self, qwen2_source, converted_model, prompts, max_new_tokens): + """Converted model produces identical logits to source.""" + tokenizer = qwen2_source["tokenizer"] + ref_model = qwen2_source["model"] + test_model = converted_model["model"] + stage = converted_model["name"] + + inputs = tokenizer(prompts, return_tensors="pt", padding=True) + ref_device = next(ref_model.parameters()).device + test_device = next(test_model.parameters()).device + + with torch.no_grad(): + ref_logits = ref_model( + input_ids=inputs.input_ids.to(ref_device), + attention_mask=inputs.attention_mask.to(ref_device), + ).logits.cpu() + + test_logits = test_model( + input_ids=inputs.input_ids.to(test_device), + attention_mask=inputs.attention_mask.to(test_device), + ).logits.cpu() + + max_diff = (ref_logits - test_logits).abs().max().item() + assert torch.allclose(ref_logits, test_logits, rtol=1e-4, atol=1e-4), ( + f"{stage} logits mismatch: max diff = {max_diff:.6f}" + ) + + @TEST_INPUTS + def test_generation_match(self, qwen2_source, converted_model, prompts, max_new_tokens): + """Converted model produces identical generation to source.""" + tokenizer = qwen2_source["tokenizer"] + ref_model = qwen2_source["model"] + test_model = converted_model["model"] + stage = converted_model["name"] + + inputs = tokenizer(prompts, return_tensors="pt", padding=True) + ref_device = next(ref_model.parameters()).device + test_device = next(test_model.parameters()).device + + with torch.no_grad(): + ref_gen = ref_model.generate( + input_ids=inputs.input_ids.to(ref_device), + attention_mask=inputs.attention_mask.to(ref_device), + max_new_tokens=max_new_tokens, + do_sample=False, + pad_token_id=tokenizer.pad_token_id, + ).cpu() + + test_gen = test_model.generate( + input_ids=inputs.input_ids.to(test_device), + attention_mask=inputs.attention_mask.to(test_device), + max_new_tokens=max_new_tokens, + do_sample=False, + pad_token_id=tokenizer.pad_token_id, + ).cpu() + + assert torch.equal(ref_gen, test_gen), ( + f"{stage} generation mismatch:\n" + f" Reference: {tokenizer.batch_decode(ref_gen, skip_special_tokens=True)}\n" + f" Test: {tokenizer.batch_decode(test_gen, skip_special_tokens=True)}" + ) diff --git a/fast_llm_external_models/tests/test_apriel2/test_modeling.py b/fast_llm_external_models/tests/test_apriel2/test_modeling.py index 5dbd36159..47c877d09 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_modeling.py +++ b/fast_llm_external_models/tests/test_apriel2/test_modeling.py @@ -12,7 +12,8 @@ class TestApriel2Modeling: "apriel2_config_tiny", "apriel2_config_stochastic", "apriel2_config_multi_mixer", - "apriel2_config_all_mixers" # Tests all 4 mixer types + "apriel2_config_all_mixers", # Tests all 4 mixer types + "apriel2_config_with_bias", # Tests per-layer bias and non-gated MLP ]) def test_model_end_to_end(self, config_name, request): """Test instantiation, forward pass, cache correctness, and generation. 
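For reference, the `apriel2_config_with_bias` fixture exercised above uses the Qwen-style per-layer bias layout that the rest of this series relies on: QKV projections carry a bias, the output projection does not, and the non-gated MLP enables a bias on layer_1 only. A minimal sketch of the relevant config fragments, assuming the fixture mirrors the bias config dicts added to conftest.py elsewhere in this series:

    mixer = {
        "type": "attention",
        "query_layer": {"bias": {"enabled": True}},
        "key_layer": {"bias": {"enabled": True}},
        "value_layer": {"bias": {"enabled": True}},
        "dense_layer": {"bias": {"enabled": False}},
    }
    mlp = {
        "type": "mlp",
        "gated": False,
        "layer_1": {"bias": {"enabled": True}},
        "layer_2": {"bias": {"enabled": False}},
    }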
diff --git a/fast_llm_external_models/tests/test_apriel2/test_plan_composition_torture.py b/fast_llm_external_models/tests/test_apriel2/test_plan_composition_torture.py index 3b4adc7f5..76a77ccb6 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_plan_composition_torture.py +++ b/fast_llm_external_models/tests/test_apriel2/test_plan_composition_torture.py @@ -1980,3 +1980,151 @@ def test_expand_surgery_chain_preserves_invariant(self): # After cycling and restore, we should be back to the same state assert current_config == config_after_original + + +class TestBiasSurgeryChain: + """Torture tests for per-layer bias inheritance through surgery operations. + + Uses apriel2_config_with_bias + bias_surgery_chain to test that: + - Qwen-style per-layer attention bias (QKV enabled, O disabled) survives surgery + - Non-gated MLP per-layer bias (layer_1 enabled, layer_2 disabled) survives surgery + - Bias settings are correctly inherited by new sub-mixers + - Bias is correctly tracked in surgery plans + """ + + @pytest.fixture + def bias_source_config(self, apriel2_config_with_bias): + """Convert Apriel2Config to dict for surgery operations.""" + return apriel2_config_with_bias.to_dict() + + def test_bias_survives_stochastic_wrapper(self, bias_source_config, bias_surgery_chain): + """Test that bias settings survive wrapping in stochastic mixer.""" + # Apply first surgery (wrap in stochastic) + result = compose_configs(bias_source_config, bias_surgery_chain[0]) + + # Check attention sub-mixer inherited bias settings + mixer = result["decoder"]["block"]["mixer"] + assert mixer["type"] == "stochastic" + + attn_mixer = mixer["mixers"]["attention"] + assert attn_mixer["query_layer"]["bias"]["enabled"] is True + assert attn_mixer["key_layer"]["bias"]["enabled"] is True + assert attn_mixer["value_layer"]["bias"]["enabled"] is True + assert attn_mixer["dense_layer"]["bias"]["enabled"] is False + + # Check MLP bias survived + mlp = result["decoder"]["block"]["mlp"] + assert mlp["layer_1"]["bias"]["enabled"] is True + assert mlp["layer_2"]["bias"]["enabled"] is False + + def test_new_submixer_inherits_bias(self, bias_source_config, bias_surgery_chain): + """Test that new sub-mixers inherit bias from source attention.""" + # Apply S1 + S2 (wrap in stochastic, add sliding_window) + config = bias_source_config + for surgery in bias_surgery_chain[:2]: + config = compose_configs(config, surgery) + + # sliding_window should inherit bias from source attention + mixer = config["decoder"]["block"]["mixer"] + sw_mixer = mixer["mixers"]["sliding_window"] + + assert sw_mixer["query_layer"]["bias"]["enabled"] is True + assert sw_mixer["key_layer"]["bias"]["enabled"] is True + assert sw_mixer["value_layer"]["bias"]["enabled"] is True + assert sw_mixer["dense_layer"]["bias"]["enabled"] is False + + def test_full_bias_chain_produces_valid_config(self, bias_source_config, bias_surgery_chain): + """Test that full bias surgery chain produces valid config.""" + config = bias_source_config + for surgery in bias_surgery_chain: + config = compose_configs(config, surgery) + + # Verify final config structure + mixer = config["decoder"]["block"]["mixer"] + assert mixer["type"] == "stochastic" + assert "attention" in mixer["mixers"] + assert "sliding_window" in mixer["mixers"] + assert "full_bias_attn" in mixer["mixers"] + + # attention and sliding_window inherit Qwen-style bias + for name in ["attention", "sliding_window"]: + sub = mixer["mixers"][name] + assert sub["query_layer"]["bias"]["enabled"] is True + assert 
sub["dense_layer"]["bias"]["enabled"] is False + + # full_bias_attn has add_linear_biases=True but per-layer settings inherited from + # source take precedence, so O bias is still disabled + full_bias = mixer["mixers"]["full_bias_attn"] + assert full_bias.get("add_linear_biases") is True + # Per-layer dense_layer.bias.enabled=False inherited from source takes precedence + assert full_bias["dense_layer"]["bias"]["enabled"] is False + + def test_bias_plan_has_correct_mappings(self, bias_source_config, bias_surgery_chain): + """Test that surgery plan correctly includes/excludes bias weight mappings.""" + # Compose config first to get full target config with inherited bias settings + target_config = compose_configs(bias_source_config, bias_surgery_chain[0]) + plan = plan_surgery(bias_source_config, target_config) + mapping_strs = [str(k) for k in plan.mappings.keys()] + + # Should have q_proj.bias (enabled) + q_bias = [m for m in mapping_strs if "q_proj.bias" in m] + assert len(q_bias) > 0, "Should have q_proj.bias mappings" + + # Should NOT have o_proj.bias (disabled) + o_bias = [m for m in mapping_strs if "o_proj.bias" in m] + assert len(o_bias) == 0, "Should not have o_proj.bias mappings" + + # Should have up_proj.bias (layer_1 enabled) + up_bias = [m for m in mapping_strs if "up_proj.bias" in m] + assert len(up_bias) > 0, "Should have up_proj.bias mappings" + + # Should NOT have down_proj.bias (layer_2 disabled) + down_bias = [m for m in mapping_strs if "down_proj.bias" in m] + assert len(down_bias) == 0, "Should not have down_proj.bias mappings" + + def test_bias_chain_produces_working_model(self, bias_source_config, bias_surgery_chain): + """Test that bias surgery chain produces a working model.""" + from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM + + # Apply full chain + config = bias_source_config + for surgery in bias_surgery_chain: + config = compose_configs(config, surgery) + + # Create model + apriel_config = Apriel2Config(**config) + model = Apriel2ForCausalLM(apriel_config) + model.eval() + + # Verify model structure has correct biases + block = model.model.decoder.blocks[0] + + # attention sub-mixer should have QKV bias, no O bias + attn = block.mixer.mixers["attention"] + assert attn.q_proj.bias is not None + assert attn.k_proj.bias is not None + assert attn.v_proj.bias is not None + assert attn.o_proj.bias is None + + # sliding_window should also inherit bias settings + sw = block.mixer.mixers["sliding_window"] + assert sw.q_proj.bias is not None + assert sw.o_proj.bias is None + + # full_bias_attn inherits per-layer bias from source (even with add_linear_biases=True, + # per-layer settings take precedence in same-type inheritance) + full_bias = block.mixer.mixers["full_bias_attn"] + assert full_bias.q_proj.bias is not None + # O bias is disabled because inherited per-layer dense_layer.bias.enabled=False + # takes precedence over add_linear_biases=True + assert full_bias.o_proj.bias is None + + # MLP should have layer_1 bias, no layer_2 bias + assert block.mlp.up_proj.bias is not None + assert block.mlp.down_proj.bias is None + + # Forward pass should work + input_ids = torch.randint(0, config["vocab_size"], (1, 10)) + with torch.no_grad(): + outputs = model(input_ids, use_cache=False) + assert outputs.logits.shape == (1, 10, config["vocab_size"]) From 7053d8cdf7ec941d84233fd4aa89be3ebf04b645 Mon Sep 17 00:00:00 2001 From: Torsten Scholak Date: Mon, 15 Dec 2025 19:30:40 +0000 Subject: [PATCH 05/25] Add conversation format support for SFT data 
preparation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Enable automatic loss masking span computation for chat/conversation datasets using HuggingFace's {% generation %}...{% endgeneration %} markers. This allows preparing SFT data (e.g., Tulu 3) with proper masking of non-assistant content. - Add ConversationSourceConfig with `type: conversation` for chat data - Add validate_chat_template() to verify tokenizer has generation markers - Add apply_chat_template_with_spans() for text + masking span extraction - Tokenizer must have built-in chat template with generation markers 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- fast_llm/data/preparator/gpt_memmap/config.py | 106 ++++++++++++++++- .../data/preparator/gpt_memmap/prepare.py | 20 +++- fast_llm/data/preprocessing/tokenizer.py | 108 ++++++++++++++++++ tests/data/test_preparator.py | 31 ++++- tests/data/test_tokenizer.py | 93 +++++++++++---- 5 files changed, 333 insertions(+), 25 deletions(-) diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 503b400c3..2aa0fbf31 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -15,11 +15,14 @@ from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator -@config_class() +@config_class(registry=True) class LanguageModelSourceConfig(Config): """ A schema holding the name of each relevant column in the dataset. Setting optional entries will enable the associated feature. + + This is the base class for source schemas. Use `type: text` (default) for + plain text datasets, or `type: conversation` for chat/conversation datasets. """ text: str = Field( @@ -48,6 +51,8 @@ def columns(self) -> list[str]: columns.append(self.loss_masking_spans) if self.has_preference_spans: columns.extend([self.chosen_span, self.rejected_span]) + if self.has_images: + columns.extend([self.images, self.image_positions]) return columns @functools.cached_property @@ -64,12 +69,111 @@ def has_images(self) -> bool: Assert.eq(self.images is None, self.image_positions is None) return self.images is not None + @functools.cached_property + def has_conversation(self) -> bool: + """Whether this is a conversation source schema.""" + return False + def _validate(self): super()._validate() if self.has_preference_spans and self.has_loss_masking_span: raise ValueError(f"Can not enable both loss masking and preference spans.") +@config_class(dynamic_type={LanguageModelSourceConfig: "text"}) +class TextSourceConfig(LanguageModelSourceConfig): + """ + Source schema for plain text datasets (default). + + The dataset should have a text column containing the document text. + Optionally, it can have additional columns for loss masking spans, + preference spans (for DPO), or images. + """ + + pass + + +@config_class(dynamic_type={LanguageModelSourceConfig: "conversation"}) +class ConversationSourceConfig(LanguageModelSourceConfig): + """ + Source schema for chat/conversation datasets (e.g., Tulu 3, ShareGPT, OpenAI format). + + The dataset should have a messages column containing a list of message dicts, + where each message has 'role' and 'content' keys. 
Common roles include: + - 'system': System prompt + - 'user': User input + - 'assistant': Model response (trained on by default) + - 'tool': Tool/function results + - 'ipython': Code execution results + + The conversation is formatted using the tokenizer's chat template, which must + contain {% generation %}...{% endgeneration %} markers to define which content + to train on. Loss masking spans are automatically computed from these markers. + + Example dataset format: + { + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello!"}, + {"role": "assistant", "content": "Hi there!"}, + ] + } + """ + + # Override text field - not used directly for conversation format + text: None | str = Field( + default=None, + desc="Not used for conversation format. Text is generated from messages.", + hint=FieldHint.optional, + ) + + # Conversation-specific fields + messages: str = Field( + default="messages", + desc="Field containing the conversation messages list. Each message should have 'role' and 'content' keys.", + hint=FieldHint.core, + ) + + add_generation_prompt: bool = Field( + default=False, + desc="Whether to add a generation prompt at the end of the conversation. " + "Typically False for training data.", + hint=FieldHint.optional, + ) + + @functools.cached_property + def columns(self) -> list[str]: + # For conversation format, we read the messages column, not text + columns = [self.messages] + # Images can still be used with conversation format + if self.has_images: + columns.extend([self.images, self.image_positions]) + return columns + + @functools.cached_property + def has_conversation(self) -> bool: + return True + + @functools.cached_property + def has_loss_masking_span(self) -> bool: + # Conversation format always generates loss masking spans + return True + + def _validate(self): + # Skip parent validation that checks text field + Config._validate(self) + if self.has_preference_spans: + raise ValueError("Preference spans are not supported with conversation format.") + if self.has_images: + # Images with conversation format would require computing image positions in the + # chat-template-formatted text, which is complex and format-dependent. + # For VLM training with conversations, preprocess the data to plain text format first. + raise ValueError( + "Images are not yet supported with conversation format. " + "For multimodal conversation data, preprocess to plain text format with image positions." 
+ ) + + @config_class() class GPTHuggingfaceDatasetConfig(Config): path: str | pathlib.Path = Field( diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 2ea81d8a6..f349b1979 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -132,6 +132,10 @@ def run(self) -> None: # Load tokenizer self._tokenizer = self._config.tokenizer.get_tokenizer() + # Validate chat template for conversation format + if self._source_schema.has_conversation: + self._tokenizer.validate_chat_template() + # Decide the datatype based on the tokenizer vocabulary size self._data_type = ( get_unsigned_integer_type(self._tokenizer.vocab_size) @@ -216,9 +220,21 @@ def _preprocessing_config(self) -> LanguageModelPreprocessingConfig: ) def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample: - text = sample[self._source_schema.text] all_spans = [] - if self._source_schema.has_loss_masking_span: + + if self._source_schema.has_conversation: + # Conversation format: apply chat template and compute loss masking spans + messages = sample[self._source_schema.messages] + text, loss_masking_spans = self._tokenizer.apply_chat_template_with_spans( + messages, + add_generation_prompt=self._source_schema.add_generation_prompt, + ) + all_spans.extend([(SpanType.loss_masking, span) for span in loss_masking_spans]) + else: + # Plain text format + text = sample[self._source_schema.text] + + if self._source_schema.has_loss_masking_span and not self._source_schema.has_conversation: # Spans are typically stored in the (begin, last) format. We convert to (begin, end) range format. loss_masking_spans = _sort_spans( (SpanType.loss_masking, (begin, last + 1)) diff --git a/fast_llm/data/preprocessing/tokenizer.py b/fast_llm/data/preprocessing/tokenizer.py index abfb5b3d2..924dc64b2 100644 --- a/fast_llm/data/preprocessing/tokenizer.py +++ b/fast_llm/data/preprocessing/tokenizer.py @@ -213,3 +213,111 @@ def _remove_delimiters( @property def eod(self): return self.eod_id + + @staticmethod + def _has_generation_markers(template: str | None) -> bool: + """Check if a template has generation markers.""" + return template is not None and "{% generation %}" in template + + def validate_chat_template(self) -> None: + """ + Validate the tokenizer's chat template has generation markers. + + Raises: + ValueError: If the tokenizer lacks a chat template or generation markers. + """ + template = self.tokenizer.chat_template + + if template is None: + raise ValueError( + "Tokenizer does not have a chat template. " + "Conversation format requires a tokenizer with a built-in chat template " + "containing {% generation %}...{% endgeneration %} markers." + ) + + if not self._has_generation_markers(template): + raise ValueError( + "Tokenizer's chat template does not contain {% generation %}...{% endgeneration %} markers. " + "These markers are required to determine which tokens to train on. " + "Please use a tokenizer with generation markers in its chat template." + ) + + def apply_chat_template_with_spans( + self, + messages: list[dict[str, str]], + *, + add_generation_prompt: bool = False, + ) -> tuple[str, list[tuple[int, int]]]: + """ + Apply the tokenizer's chat template to messages and compute loss masking spans. + + This method converts a list of messages (OpenAI/Tulu format) into formatted + text and computes character-level spans that should be MASKED (not trained on). 
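+
+        A sketch of typical use during data preparation (message contents are
+        illustrative):
+
+            tokenizer.validate_chat_template()
+            text, spans = tokenizer.apply_chat_template_with_spans(
+                [
+                    {"role": "user", "content": "Hello!"},
+                    {"role": "assistant", "content": "Hi there!"},
+                ]
+            )
+            # `spans` are (start, end) character ranges of `text` to mask,
+            # i.e. content outside {% generation %}...{% endgeneration %} blocks.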
+ + Note: Call validate_chat_template() once before using this method to ensure + the tokenizer has a valid chat template with generation markers. + + Args: + messages: List of message dicts with 'role' and 'content' keys. + add_generation_prompt: Whether to add a generation prompt at the end. + + Returns: + Tuple of (formatted_text, loss_masking_spans) where loss_masking_spans + is a list of (start, end) character positions to MASK (not train on). + """ + if not messages: + return "", [] + + return self._apply_chat_template(messages, add_generation_prompt) + + def _apply_chat_template( + self, + messages: list[dict[str, str]], + add_generation_prompt: bool, + ) -> tuple[str, list[tuple[int, int]]]: + """Use HF's return_assistant_tokens_mask for precise token-level masking.""" + # Get tokens and assistant mask + result = self.tokenizer.apply_chat_template( + messages, + tokenize=True, + return_assistant_tokens_mask=True, + return_dict=True, + add_generation_prompt=add_generation_prompt, + ) + + tokens = result["input_ids"] + train_mask = result["assistant_masks"] + + # Get text for output + full_text = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=add_generation_prompt, + ) + + # Convert token mask to character spans using detokenization + # We need spans for tokens where train_mask=0 (should be masked/not trained on) + loss_masking_spans = [] + current_span_start = None + + # Track character positions by decoding incrementally + char_positions = [0] + for i in range(len(tokens)): + decoded = self.tokenizer.decode(tokens[: i + 1]) + char_positions.append(len(decoded)) + + for i, is_train in enumerate(train_mask): + if not is_train: # This token should be masked + if current_span_start is None: + current_span_start = char_positions[i] + else: # This token should be trained on + if current_span_start is not None: + loss_masking_spans.append((current_span_start, char_positions[i])) + current_span_start = None + + # Close any open span + if current_span_start is not None: + loss_masking_spans.append((current_span_start, char_positions[-1])) + + return full_text, loss_masking_spans + diff --git a/tests/data/test_preparator.py b/tests/data/test_preparator.py index f4f6fab82..ccef94d03 100644 --- a/tests/data/test_preparator.py +++ b/tests/data/test_preparator.py @@ -6,7 +6,11 @@ from fast_llm.data.dataset.config import BlendedDatasetConfig, MemmapDatasetConfig, SamplingParameters from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig from fast_llm.data.dataset.memmap import MemmapDataset -from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig +from fast_llm.data.preparator.gpt_memmap.config import ( + ConversationSourceConfig, + GPTMemmapDatasetPreparatorConfig, + LanguageModelSourceConfig, +) from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.preprocessing.tokenizer import TokenizerConfig from fast_llm.utils import Assert @@ -198,3 +202,28 @@ def test_dataset_preparator_from_hub(): tokenizer.detokenize(dataset.get_document(index).tokens.tokens), f"<|endoftext|>{hf_dataset[index]["answer"]}<|endoftext|>", ) + + +# ============================================================================= +# Conversation Format Tests +# ============================================================================= + + +def test_conversation_source_config(): + """Test conversation source schema configuration.""" + config = LanguageModelSourceConfig.from_dict({"type": 
"conversation"}) + Assert.custom(isinstance, config, ConversationSourceConfig) + Assert.eq(config.messages, "messages") + Assert.eq(config.has_conversation, True) + Assert.eq(config.has_loss_masking_span, True) + Assert.eq(config.columns, ["messages"]) + + +def test_conversation_config_validation(): + """Test conversation config validation errors.""" + with pytest.raises(ValueError, match="Images are not yet supported"): + LanguageModelSourceConfig.from_dict({ + "type": "conversation", + "images": "images", + "image_positions": "positions", + }) diff --git a/tests/data/test_tokenizer.py b/tests/data/test_tokenizer.py index c7fdef9ca..b7e1d3e9b 100644 --- a/tests/data/test_tokenizer.py +++ b/tests/data/test_tokenizer.py @@ -1,42 +1,93 @@ import pytest -from fast_llm.data.preprocessing.tokenizer import Tokenizer, TokenizerConfig +from fast_llm.data.preprocessing.tokenizer import TokenizerConfig from fast_llm.utils import Assert from tests.utils.dataset import download_santacoder_tokenizer from tests.utils.global_variables import TOKENIZER_PATH +TEXT = "hello world" + -@pytest.fixture(scope="session") -def common_tokenizer() -> Tokenizer: +@pytest.fixture(scope="module") +def tokenizer(): download_santacoder_tokenizer() return TokenizerConfig(path=TOKENIZER_PATH).get_tokenizer() -TEXT = "hello world" - - @pytest.mark.parametrize("extra_tokens", (False, True)) @pytest.mark.parametrize( ("spans", "expected_token_spans", "expected_tokens"), ( - ([], [], [7196, 5297]), # No span - ([(1, 3)], [(1, 2)], [71, 325, 303, 5297]), # Simple span - ([(2, 2)], [(1, 1)], [284, 47443, 5297]), # Empty span - ([(0, 11)], [(0, 2)], [7196, 5297]), # Full span - ([(1, 4), (6, 7)], [(1, 2), (4, 5)], [71, 1498, 78, 207, 86, 2231]), # Two spans - ([(1, 6), (4, 7)], [(1, 4), (2, 5)], [71, 1498, 78, 207, 86, 2231]), # Overlapping spans - ([(1, 7), (4, 6)], [(1, 5), (2, 4)], [71, 1498, 78, 207, 86, 2231]), # Nested spans - ([(1, 5), (5, 7)], [(1, 3), (3, 4)], [71, 325, 303, 365, 2231]), # Consecutive spans - ([(2, 4), (2, 4)], [(1, 2), (1, 2)], [284, 683, 78, 5297]), # Duplicate spans - ([(2, 3), (5, 8), (9, 11)], [(1, 2), (3, 4), (5, 6)], [284, 75, 303, 48485, 81, 1382]), # Three spans + ([], [], [7196, 5297]), + ([(1, 3)], [(1, 2)], [71, 325, 303, 5297]), + ([(2, 2)], [(1, 1)], [284, 47443, 5297]), + ([(0, 11)], [(0, 2)], [7196, 5297]), + ([(1, 4), (6, 7)], [(1, 2), (4, 5)], [71, 1498, 78, 207, 86, 2231]), + ([(1, 6), (4, 7)], [(1, 4), (2, 5)], [71, 1498, 78, 207, 86, 2231]), + ([(1, 7), (4, 6)], [(1, 5), (2, 4)], [71, 1498, 78, 207, 86, 2231]), + ([(1, 5), (5, 7)], [(1, 3), (3, 4)], [71, 325, 303, 365, 2231]), + ([(2, 4), (2, 4)], [(1, 2), (1, 2)], [284, 683, 78, 5297]), + ([(2, 3), (5, 8), (9, 11)], [(1, 2), (3, 4), (5, 6)], [284, 75, 303, 48485, 81, 1382]), ), ) -def test_tokenize_with_spans(common_tokenizer, spans, expected_token_spans, expected_tokens, extra_tokens): - tokens, token_spans = common_tokenizer.tokenize_with_spans( - TEXT, begin=extra_tokens, end=extra_tokens, text_spans=spans - ) +def test_tokenize_with_spans(tokenizer, spans, expected_token_spans, expected_tokens, extra_tokens): + tokens, token_spans = tokenizer.tokenize_with_spans(TEXT, begin=extra_tokens, end=extra_tokens, text_spans=spans) if extra_tokens: - expected_tokens = [common_tokenizer.bod_id, *expected_tokens, common_tokenizer.eod_id] + expected_tokens = [tokenizer.bod_id, *expected_tokens, tokenizer.eod_id] expected_token_spans = [(begin + 1, end + 1) for begin, end in expected_token_spans] Assert.eq(tokens.tolist(), expected_tokens) 
Assert.eq(token_spans, expected_token_spans) + + +def test_validate_chat_template_no_template(tokenizer): + """Tokenizer without chat template raises.""" + with pytest.raises(ValueError, match="does not have a chat template"): + tokenizer.validate_chat_template() + + +def test_validate_chat_template_no_markers(tokenizer): + """Tokenizer with chat template but no markers raises.""" + tokenizer.tokenizer.chat_template = "{{ messages }}" + with pytest.raises(ValueError, match="does not contain.*generation"): + tokenizer.validate_chat_template() + + +def test_validate_chat_template_with_markers(tokenizer): + """Tokenizer with generation markers validates.""" + tokenizer.tokenizer.chat_template = "{% generation %}{{ m }}{% endgeneration %}" + tokenizer.validate_chat_template() + + +CHAT_TEMPLATE = ( + "{% for message in messages %}" + "{% if message.role == 'assistant' %}" + "{% generation %}{{ message.content }}{% endgeneration %}" + "{% else %}" + "<{{ message.role }}>{{ message.content }}" + "{% endif %}" + "{% endfor %}" +) + + +@pytest.mark.parametrize( + ("messages", "expected_text", "expected_spans"), + ( + ([], "", []), + ( + [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello"}], + "HiHello", + [(0, 26), (31, 43)], + ), + ( + [{"role": "user", "content": "A"}, {"role": "assistant", "content": "B"}, {"role": "user", "content": "C"}, {"role": "assistant", "content": "D"}], + "ABCD", + [(0, 25), (26, 63), (64, 76)], + ), + ), +) +def test_apply_chat_template_with_spans(tokenizer, messages, expected_text, expected_spans): + """Chat template produces correct text and masking spans.""" + tokenizer.tokenizer.chat_template = CHAT_TEMPLATE + text, spans = tokenizer.apply_chat_template_with_spans(messages) + Assert.eq(text, expected_text) + Assert.eq(spans, expected_spans) From 53d657069e46f6a830aef86ce852b2be79fa203a Mon Sep 17 00:00:00 2001 From: Torsten Scholak Date: Mon, 15 Dec 2025 20:25:48 +0000 Subject: [PATCH 06/25] Cleanup: remove private method indirection, revert test changes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Inline _apply_chat_template into apply_chat_template_with_spans - Revert unnecessary test refactoring in test_tokenizer.py - Remove trivial config tests from test_preparator.py 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- fast_llm/data/preprocessing/tokenizer.py | 9 ---- tests/data/test_preparator.py | 31 +----------- tests/data/test_tokenizer.py | 61 +++++++++++++----------- 3 files changed, 33 insertions(+), 68 deletions(-) diff --git a/fast_llm/data/preprocessing/tokenizer.py b/fast_llm/data/preprocessing/tokenizer.py index 924dc64b2..372d8cd90 100644 --- a/fast_llm/data/preprocessing/tokenizer.py +++ b/fast_llm/data/preprocessing/tokenizer.py @@ -267,15 +267,6 @@ def apply_chat_template_with_spans( """ if not messages: return "", [] - - return self._apply_chat_template(messages, add_generation_prompt) - - def _apply_chat_template( - self, - messages: list[dict[str, str]], - add_generation_prompt: bool, - ) -> tuple[str, list[tuple[int, int]]]: - """Use HF's return_assistant_tokens_mask for precise token-level masking.""" # Get tokens and assistant mask result = self.tokenizer.apply_chat_template( messages, diff --git a/tests/data/test_preparator.py b/tests/data/test_preparator.py index ccef94d03..f4f6fab82 100644 --- a/tests/data/test_preparator.py +++ b/tests/data/test_preparator.py @@ -6,11 +6,7 @@ from fast_llm.data.dataset.config import 
BlendedDatasetConfig, MemmapDatasetConfig, SamplingParameters from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig from fast_llm.data.dataset.memmap import MemmapDataset -from fast_llm.data.preparator.gpt_memmap.config import ( - ConversationSourceConfig, - GPTMemmapDatasetPreparatorConfig, - LanguageModelSourceConfig, -) +from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.preprocessing.tokenizer import TokenizerConfig from fast_llm.utils import Assert @@ -202,28 +198,3 @@ def test_dataset_preparator_from_hub(): tokenizer.detokenize(dataset.get_document(index).tokens.tokens), f"<|endoftext|>{hf_dataset[index]["answer"]}<|endoftext|>", ) - - -# ============================================================================= -# Conversation Format Tests -# ============================================================================= - - -def test_conversation_source_config(): - """Test conversation source schema configuration.""" - config = LanguageModelSourceConfig.from_dict({"type": "conversation"}) - Assert.custom(isinstance, config, ConversationSourceConfig) - Assert.eq(config.messages, "messages") - Assert.eq(config.has_conversation, True) - Assert.eq(config.has_loss_masking_span, True) - Assert.eq(config.columns, ["messages"]) - - -def test_conversation_config_validation(): - """Test conversation config validation errors.""" - with pytest.raises(ValueError, match="Images are not yet supported"): - LanguageModelSourceConfig.from_dict({ - "type": "conversation", - "images": "images", - "image_positions": "positions", - }) diff --git a/tests/data/test_tokenizer.py b/tests/data/test_tokenizer.py index b7e1d3e9b..4b8f45d8d 100644 --- a/tests/data/test_tokenizer.py +++ b/tests/data/test_tokenizer.py @@ -1,61 +1,64 @@ import pytest -from fast_llm.data.preprocessing.tokenizer import TokenizerConfig +from fast_llm.data.preprocessing.tokenizer import Tokenizer, TokenizerConfig from fast_llm.utils import Assert from tests.utils.dataset import download_santacoder_tokenizer from tests.utils.global_variables import TOKENIZER_PATH -TEXT = "hello world" - -@pytest.fixture(scope="module") -def tokenizer(): +@pytest.fixture(scope="session") +def common_tokenizer() -> Tokenizer: download_santacoder_tokenizer() return TokenizerConfig(path=TOKENIZER_PATH).get_tokenizer() +TEXT = "hello world" + + @pytest.mark.parametrize("extra_tokens", (False, True)) @pytest.mark.parametrize( ("spans", "expected_token_spans", "expected_tokens"), ( - ([], [], [7196, 5297]), - ([(1, 3)], [(1, 2)], [71, 325, 303, 5297]), - ([(2, 2)], [(1, 1)], [284, 47443, 5297]), - ([(0, 11)], [(0, 2)], [7196, 5297]), - ([(1, 4), (6, 7)], [(1, 2), (4, 5)], [71, 1498, 78, 207, 86, 2231]), - ([(1, 6), (4, 7)], [(1, 4), (2, 5)], [71, 1498, 78, 207, 86, 2231]), - ([(1, 7), (4, 6)], [(1, 5), (2, 4)], [71, 1498, 78, 207, 86, 2231]), - ([(1, 5), (5, 7)], [(1, 3), (3, 4)], [71, 325, 303, 365, 2231]), - ([(2, 4), (2, 4)], [(1, 2), (1, 2)], [284, 683, 78, 5297]), - ([(2, 3), (5, 8), (9, 11)], [(1, 2), (3, 4), (5, 6)], [284, 75, 303, 48485, 81, 1382]), + ([], [], [7196, 5297]), # No span + ([(1, 3)], [(1, 2)], [71, 325, 303, 5297]), # Simple span + ([(2, 2)], [(1, 1)], [284, 47443, 5297]), # Empty span + ([(0, 11)], [(0, 2)], [7196, 5297]), # Full span + ([(1, 4), (6, 7)], [(1, 2), (4, 5)], [71, 1498, 78, 207, 86, 2231]), # Two spans + ([(1, 6), (4, 7)], [(1, 4), (2, 5)], [71, 1498, 78, 207, 86, 
2231]), # Overlapping spans + ([(1, 7), (4, 6)], [(1, 5), (2, 4)], [71, 1498, 78, 207, 86, 2231]), # Nested spans + ([(1, 5), (5, 7)], [(1, 3), (3, 4)], [71, 325, 303, 365, 2231]), # Consecutive spans + ([(2, 4), (2, 4)], [(1, 2), (1, 2)], [284, 683, 78, 5297]), # Duplicate spans + ([(2, 3), (5, 8), (9, 11)], [(1, 2), (3, 4), (5, 6)], [284, 75, 303, 48485, 81, 1382]), # Three spans ), ) -def test_tokenize_with_spans(tokenizer, spans, expected_token_spans, expected_tokens, extra_tokens): - tokens, token_spans = tokenizer.tokenize_with_spans(TEXT, begin=extra_tokens, end=extra_tokens, text_spans=spans) +def test_tokenize_with_spans(common_tokenizer, spans, expected_token_spans, expected_tokens, extra_tokens): + tokens, token_spans = common_tokenizer.tokenize_with_spans( + TEXT, begin=extra_tokens, end=extra_tokens, text_spans=spans + ) if extra_tokens: - expected_tokens = [tokenizer.bod_id, *expected_tokens, tokenizer.eod_id] + expected_tokens = [common_tokenizer.bod_id, *expected_tokens, common_tokenizer.eod_id] expected_token_spans = [(begin + 1, end + 1) for begin, end in expected_token_spans] Assert.eq(tokens.tolist(), expected_tokens) Assert.eq(token_spans, expected_token_spans) -def test_validate_chat_template_no_template(tokenizer): +def test_validate_chat_template_no_template(common_tokenizer): """Tokenizer without chat template raises.""" with pytest.raises(ValueError, match="does not have a chat template"): - tokenizer.validate_chat_template() + common_tokenizer.validate_chat_template() -def test_validate_chat_template_no_markers(tokenizer): +def test_validate_chat_template_no_markers(common_tokenizer): """Tokenizer with chat template but no markers raises.""" - tokenizer.tokenizer.chat_template = "{{ messages }}" + common_tokenizer.tokenizer.chat_template = "{{ messages }}" with pytest.raises(ValueError, match="does not contain.*generation"): - tokenizer.validate_chat_template() + common_tokenizer.validate_chat_template() -def test_validate_chat_template_with_markers(tokenizer): +def test_validate_chat_template_with_markers(common_tokenizer): """Tokenizer with generation markers validates.""" - tokenizer.tokenizer.chat_template = "{% generation %}{{ m }}{% endgeneration %}" - tokenizer.validate_chat_template() + common_tokenizer.tokenizer.chat_template = "{% generation %}{{ m }}{% endgeneration %}" + common_tokenizer.validate_chat_template() CHAT_TEMPLATE = ( @@ -85,9 +88,9 @@ def test_validate_chat_template_with_markers(tokenizer): ), ), ) -def test_apply_chat_template_with_spans(tokenizer, messages, expected_text, expected_spans): +def test_apply_chat_template_with_spans(common_tokenizer, messages, expected_text, expected_spans): """Chat template produces correct text and masking spans.""" - tokenizer.tokenizer.chat_template = CHAT_TEMPLATE - text, spans = tokenizer.apply_chat_template_with_spans(messages) + common_tokenizer.tokenizer.chat_template = CHAT_TEMPLATE + text, spans = common_tokenizer.apply_chat_template_with_spans(messages) Assert.eq(text, expected_text) Assert.eq(spans, expected_spans) From d053d47d41abb23cdb729f01b02faad6fde7433f Mon Sep 17 00:00:00 2001 From: Torsten Scholak Date: Tue, 16 Dec 2025 12:28:39 +0000 Subject: [PATCH 07/25] Refactor test organization: rename modules and remove duplication MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Rename test_plan_composition_torture.py → test_conversion_e2e.py (reflects actual purpose: end-to-end integration tests) - Rename test_algebraic_properties.py → 
test_plan_execution.py (clearer: tests plan execution and algebraic composition laws) - Remove stale NOTE comments referencing deleted tests - Fix fixture naming collision: attention_config → attention_config_dict in TestMarkovianProperty to avoid shadowing conftest fixtures - Consolidate shared fixtures in conftest.py Test organization now follows clear separation: - test_compose_configs.py: Config dict composition (structure/completeness) - test_plan_execution.py: Plan execution (weight transfer/correctness) - test_conversion_e2e.py: Full pipeline integration tests 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- .../tests/test_apriel2/conftest.py | 142 +++++ .../test_apriel2/test_compose_configs.py | 261 +++----- ...tion_torture.py => test_conversion_e2e.py} | 342 +--------- .../tests/test_apriel2/test_plan_execution.py | 597 ++++++++++++++++++ 4 files changed, 833 insertions(+), 509 deletions(-) rename fast_llm_external_models/tests/test_apriel2/{test_plan_composition_torture.py => test_conversion_e2e.py} (84%) create mode 100644 fast_llm_external_models/tests/test_apriel2/test_plan_execution.py diff --git a/fast_llm_external_models/tests/test_apriel2/conftest.py b/fast_llm_external_models/tests/test_apriel2/conftest.py index cf190b50a..320813747 100644 --- a/fast_llm_external_models/tests/test_apriel2/conftest.py +++ b/fast_llm_external_models/tests/test_apriel2/conftest.py @@ -1680,6 +1680,148 @@ def torture_surgery_chain(): ] +# ============================================================================= +# Shared Config Dict Fixtures (for compose_configs / plan_surgery tests) +# ============================================================================= + + +@pytest.fixture +def base_config_dict(): + """Complete Apriel2 config dict without biases (Llama-style). + + Use this as the base config for testing compose_configs and plan_surgery. + Returns a dict (not Apriel2Config) for direct use with compose_configs. + """ + return { + "model_type": "apriel2", + "hidden_size": 256, + "vocab_size": 1000, + "bos_token_id": 1, + "eos_token_id": 2, + "tie_word_embeddings": False, + "decoder": { + "type": "fixed", + "num_blocks": 2, + "block": { + "mixer": { + "type": "attention", + "heads": 8, + "head_groups": 4, + "head_size": 32, + "rotary": {"type": "mistral_1d", "theta": 10000.0}, + }, + "mlp": {"type": "mlp", "intermediate_size": 512, "gated": True}, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + } + + +@pytest.fixture +def base_config_with_bias_dict(): + """Complete Apriel2 config dict with Qwen-style biases. + + - QKV bias enabled, O bias disabled + - Gated MLP (no per-layer bias control in this style) + + Use this for testing bias inheritance through surgery operations. + Returns a dict (not Apriel2Config) for direct use with compose_configs. 
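+
+    A sketch of the intended usage in tests (`surgery_spec` stands in for any
+    partial surgery dict):
+
+        target = compose_configs(base_config_with_bias_dict, surgery_spec)
+        plan = plan_surgery(base_config_with_bias_dict, target)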
+ """ + return { + "model_type": "apriel2", + "hidden_size": 256, + "vocab_size": 1000, + "bos_token_id": 1, + "eos_token_id": 2, + "tie_word_embeddings": False, + "decoder": { + "type": "fixed", + "num_blocks": 2, + "block": { + "mixer": { + "type": "attention", + "heads": 8, + "head_groups": 4, + "head_size": 32, + "rotary": {"type": "mistral_1d", "theta": 10000.0}, + "query_layer": {"bias": {"enabled": True}}, + "key_layer": {"bias": {"enabled": True}}, + "value_layer": {"bias": {"enabled": True}}, + "dense_layer": {"bias": {"enabled": False}}, + }, + "mlp": {"type": "mlp", "intermediate_size": 512, "gated": True}, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + } + + +def make_weights_for_config(config: dict) -> dict: + """Create random weights matching a config's expected schema. + + This is a helper function (not a fixture) for creating test weights. + Use it in tests that need weights for plan execution. + + Args: + config: Complete Apriel2 config dict + + Returns: + Dict mapping weight key strings to torch tensors + """ + from fast_llm_external_models.apriel2.conversion import W + + hidden = config["hidden_size"] + vocab = config["vocab_size"] + decoder = config["decoder"] + num_blocks = decoder["num_blocks"] + block = decoder["block"] + mixer = block["mixer"] + mlp = block["mlp"] + + heads = mixer["heads"] + head_groups = mixer["head_groups"] + head_size = mixer["head_size"] + inter = mlp["intermediate_size"] + + # Check bias settings + has_q_bias = mixer.get("query_layer", {}).get("bias", {}).get("enabled", False) + has_k_bias = mixer.get("key_layer", {}).get("bias", {}).get("enabled", False) + has_v_bias = mixer.get("value_layer", {}).get("bias", {}).get("enabled", False) + + weights = {} + weights["model.embed_tokens.weight"] = torch.randn(vocab, hidden) + + for i in range(num_blocks): + p = f"model.decoder.blocks.{i}" + + # Attention + weights[f"{p}.mixer.q_proj.weight"] = torch.randn(heads * head_size, hidden) + weights[f"{p}.mixer.k_proj.weight"] = torch.randn(head_groups * head_size, hidden) + weights[f"{p}.mixer.v_proj.weight"] = torch.randn(head_groups * head_size, hidden) + weights[f"{p}.mixer.o_proj.weight"] = torch.randn(hidden, heads * head_size) + + if has_q_bias: + weights[f"{p}.mixer.q_proj.bias"] = torch.randn(heads * head_size) + if has_k_bias: + weights[f"{p}.mixer.k_proj.bias"] = torch.randn(head_groups * head_size) + if has_v_bias: + weights[f"{p}.mixer.v_proj.bias"] = torch.randn(head_groups * head_size) + + # MLP + weights[f"{p}.mlp.up_proj.weight"] = torch.randn(inter, hidden) + weights[f"{p}.mlp.gate_proj.weight"] = torch.randn(inter, hidden) + weights[f"{p}.mlp.down_proj.weight"] = torch.randn(hidden, inter) + + # Norms + weights[f"{p}.input_layernorm.weight"] = torch.randn(hidden) + weights[f"{p}.post_attention_layernorm.weight"] = torch.randn(hidden) + + weights["model.norm.weight"] = torch.randn(hidden) + weights["lm_head.weight"] = torch.randn(vocab, hidden) + + return {W(k): v for k, v in weights.items()} + + # ============================================================================= # Cache Test Fixtures - Tensor Dimensions # ============================================================================= diff --git a/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py b/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py index 4380b1fbd..b1ee15d54 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py +++ 
b/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py @@ -75,14 +75,10 @@ def source_config(self): }, } - def test_identity_empty_surgery(self, source_config): - """Law 1: compose_configs(config, {}) == config""" - result = compose_configs(source_config, {}) - assert result == source_config - - def test_identity_none_surgery(self, source_config): - """Law 1: compose_configs(config, None) == config""" - result = compose_configs(source_config, None) + @pytest.mark.parametrize("empty_surgery", [{}, None]) + def test_identity(self, source_config, empty_surgery): + """Law 1: compose_configs(config, empty) == config for empty in [{}, None]""" + result = compose_configs(source_config, empty_surgery) assert result == source_config def test_override_explicit_values(self, source_config): @@ -114,7 +110,7 @@ def test_same_type_inheritance(self, source_config): assert mixer["head_size"] == 32 # Inherited assert mixer["rope_theta"] == 10000.0 # Inherited assert mixer["window_size"] == 512 # Added - assert "init" not in mixer # Stripped by apply_surgery + # init is preserved for plan_surgery to see (stripped only at final output) def test_cross_type_attention_to_gdn(self, source_config): """Law 5: attention→gdn derives GDN dims from attention geometry.""" @@ -239,8 +235,14 @@ def test_null_deletion(self, source_config): assert "vision_encoder" not in result - def test_init_stripped_from_result(self, source_config): - """Verify `init` keys are stripped from final result.""" + def test_init_preserved_for_plan_surgery(self, source_config): + """Verify `init` keys are preserved so plan_surgery can see them. + + The `init` field controls weight initialization (transfer vs random). + It's preserved through composition and only stripped at final output. 
+ """ + from fast_llm_external_models.apriel2.conversion.config import strip_init_fields + surgery = { "decoder": { "block": { @@ -252,20 +254,20 @@ def test_init_stripped_from_result(self, source_config): "gdn": {"type": "gdn", "init": "random", "convolution_layer": {"kernel_size": 4}}, }, }, - "mlp": {"init": "transfer"}, - "normalization": {"init": "transfer"}, }, }, } result = compose_configs(source_config, surgery) - def check_no_init(d, path=""): - assert "init" not in d, f"Found 'init' key at {path}" - for k, v in d.items(): - if isinstance(v, dict): - check_no_init(v, f"{path}.{k}") + # init is preserved in composed config + mixers = result["decoder"]["block"]["mixer"]["mixers"] + assert mixers["attention"].get("init") == "transfer" + assert mixers["gdn"].get("init") == "random" - check_no_init(result) + # strip_init_fields removes them for final output + stripped = strip_init_fields(result) + assert "init" not in stripped["decoder"]["block"]["mixer"]["mixers"]["attention"] + assert "init" not in stripped["decoder"]["block"]["mixer"]["mixers"]["gdn"] def test_init_random_still_inherits_config(self, source_config): """init: random is for weights only - config params still inherited.""" @@ -287,6 +289,49 @@ def test_init_random_still_inherits_config(self, source_config): assert mixer["head_groups"] == 4 assert mixer["window_size"] == 512 + # ========================================================================= + # Monoid Laws: compose_configs forms a monoid action on configs + # ========================================================================= + + def test_surgery_monoid_associativity(self): + """MONOID: merge(merge(A, B), C) == merge(A, merge(B, C)) for partial configs.""" + surgery_a = {"decoder": {"block": {"mixer": {"type": "stochastic", "main_mixer_name": "attention"}}}} + surgery_b = {"decoder": {"block": {"mixer": {"mixers": {"sliding_window": {"window_size": 512}}}}}} + surgery_c = {"decoder": {"block": {"mixer": {"mixers": {"gdn": {"type": "gdn"}}}}}} + + # Left-associated: (A ∘ B) ∘ C + ab_c = compose_configs(compose_configs(surgery_a, surgery_b), surgery_c) + # Right-associated: A ∘ (B ∘ C) + a_bc = compose_configs(surgery_a, compose_configs(surgery_b, surgery_c)) + + assert ab_c == a_bc, "Surgery monoid should be associative" + + @pytest.mark.parametrize("num_surgeries", [2, 3]) + def test_monoid_action_compatibility(self, source_config, num_surgeries): + """MONOID ACTION: apply(apply(c, A), B) == apply(c, merge(A, B)) + + This is the key law: applying surgeries sequentially equals merging first. + Parameterized to test with 2 and 3 surgeries. + """ + surgeries = [ + {"decoder": {"block": {"mixer": {"type": "stochastic", "main_mixer_name": "attention", "mixers": {"attention": {}}}}}}, + {"decoder": {"block": {"mixer": {"mixers": {"sliding_window": {"window_size": 512}}}}}}, + {"decoder": {"block": {"mixer": {"mixers": {"gdn": {"type": "gdn"}}}}}}, + ][:num_surgeries] + + # Sequential: ((c ⊳ A) ⊳ B) ⊳ ... + result_sequential = source_config + for s in surgeries: + result_sequential = compose_configs(result_sequential, s) + + # Merged: c ⊳ (A ∘ B ∘ ...) + merged = surgeries[0] + for s in surgeries[1:]: + merged = compose_configs(merged, s) + result_merged = compose_configs(source_config, merged) + + assert result_sequential == result_merged, f"Monoid action compatibility failed for {num_surgeries} surgeries" + class TestBiasConfigInheritance: """Test per-layer bias inheritance through surgery composition. 
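In plain terms, the compatibility law exercised above says that applying surgeries one at a time must agree with merging them first and applying the result once. A minimal sketch, assuming `compose_configs` as imported in this test module and a complete `source_config` dict like the fixture above:

    surgery_a = {
        "decoder": {"block": {"mixer": {
            "type": "stochastic",
            "main_mixer_name": "attention",
            "mixers": {"attention": {}},
        }}},
    }
    surgery_b = {
        "decoder": {"block": {"mixer": {
            "mixers": {"sliding_window": {"window_size": 512}},
        }}},
    }

    sequential = compose_configs(compose_configs(source_config, surgery_a), surgery_b)
    merged = compose_configs(source_config, compose_configs(surgery_a, surgery_b))
    assert sequential == merged  # monoid action compatibility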
@@ -555,160 +600,12 @@ def test_build_plan_returns_complete_config(self, llava_pixtral_checkpoint): mixer = config.decoder["block"]["mixer"] assert mixer["type"] == "stochastic" - # Each sub-mixer should have complete config (no init keys) + # Each sub-mixer should have complete config + # (init is preserved for plan_surgery, stripped only at final output) for name, sub_mixer in mixer["mixers"].items(): - assert "init" not in sub_mixer, f"Sub-mixer {name} still has 'init' key" assert "type" in sub_mixer -class TestMonoidLaws: - """Test the algebraic laws of compose_configs. - - Surgery specs form a MONOID under deep-merge: - - Identity: {} - - Operation: deep merge (overlay wins) - - Associativity: merge(merge(A, B), C) == merge(A, merge(B, C)) - - compose_configs is a MONOID ACTION on configs: - - Identity action: apply(config, {}) == config - - Compatibility: apply(apply(c, A), B) == apply(c, merge(A, B)) - """ - - @pytest.fixture - def complete_config(self): - """A complete Apriel2 config.""" - return { - "model_type": "apriel2", - "architectures": ["Apriel2ForConditionalGeneration"], - "hidden_size": 256, - "vocab_size": 1000, - "bos_token_id": 1, - "eos_token_id": 2, - "tie_word_embeddings": False, - "image_token_index": 100, - "decoder": { - "type": "fixed", - "num_blocks": 4, - "block": { - "mixer": { - "type": "attention", - "heads": 8, - "head_groups": 4, - "head_size": 32, - "rope_theta": 10000.0, - }, - "mlp": {"type": "mlp", "intermediate_size": 512}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - } - - @pytest.fixture - def surgery_a(self): - """First surgery: wrap in stochastic with attention.""" - return { - "decoder": { - "block": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": { - "attention": {"init": "transfer"}, - }, - }, - }, - }, - } - - @pytest.fixture - def surgery_b(self): - """Second surgery: add sliding window mixer.""" - return { - "decoder": { - "block": { - "mixer": { - "mixers": { - "sliding_window": {"init": "transfer", "window_size": 512}, - }, - }, - }, - }, - } - - def test_identity_action(self, complete_config): - """apply(config, {}) == config""" - result = compose_configs(complete_config, {}) - assert result == complete_config - - def test_surgery_monoid_associativity(self, surgery_a, surgery_b): - """merge(merge(A, B), C) == merge(A, merge(B, C)) for partial configs.""" - surgery_c = { - "decoder": { - "block": { - "mixer": { - "mixers": { - "gdn": {"type": "gdn", "init": "transfer", "convolution_layer": {"kernel_size": 4}}, - }, - }, - }, - }, - } - - # Left-associated: (A ∘ B) ∘ C - ab = compose_configs(surgery_a, surgery_b) - ab_c = compose_configs(ab, surgery_c) - - # Right-associated: A ∘ (B ∘ C) - bc = compose_configs(surgery_b, surgery_c) - a_bc = compose_configs(surgery_a, bc) - - assert ab_c == a_bc, "Surgery monoid should be associative" - - def test_monoid_action_compatibility(self, complete_config, surgery_a, surgery_b): - """apply(apply(c, A), B) == apply(c, merge(A, B)) - - This is the key law: applying surgeries sequentially should equal - merging the surgeries first, then applying once. 
- """ - # Sequential application: (c ⊳ A) ⊳ B - result_sequential = compose_configs(compose_configs(complete_config, surgery_a), surgery_b) - - # Merged application: c ⊳ (A ∘ B) - merged_surgery = compose_configs(surgery_a, surgery_b) - result_merged = compose_configs(complete_config, merged_surgery) - - # These should be equivalent - assert result_sequential == result_merged, "Monoid action should satisfy compatibility law" - - def test_three_way_compatibility(self, complete_config, surgery_a, surgery_b): - """Test with three surgeries for stronger confidence.""" - surgery_c = { - "decoder": { - "block": { - "mixer": { - "mixers": { - "gdn": {"type": "gdn", "init": "transfer", "convolution_layer": {"kernel_size": 4}}, - }, - }, - }, - }, - } - - # Sequential: ((c ⊳ A) ⊳ B) ⊳ C - seq = compose_configs( - compose_configs(compose_configs(complete_config, surgery_a), surgery_b), - surgery_c - ) - - # Merged: c ⊳ ((A ∘ B) ∘ C) - merged = compose_configs( - complete_config, - compose_configs(compose_configs(surgery_a, surgery_b), surgery_c) - ) - - assert seq == merged, "Three-way monoid action should satisfy compatibility" - - class TestCompositionTortureTest: """Comprehensive stress test for config composition. @@ -807,19 +704,29 @@ def test_final_config_structure(self, complete_config, additive_surgery_chain): assert mixer["mixers"]["sliding_window"]["window_size"] == 512 assert mixer["mixers"]["gdn"]["value_heads"] == 16 - def test_no_init_keys_in_result(self, complete_config, additive_surgery_chain): - """Verify no 'init' keys leak through.""" + def test_init_keys_preserved_for_planning(self, complete_config, additive_surgery_chain): + """Verify 'init' keys are preserved for plan_surgery to see. - def check_no_init(d, path=""): - if isinstance(d, dict): - assert "init" not in d, f"Found 'init' key at {path}" - for k, v in d.items(): - check_no_init(v, f"{path}.{k}") + The `init` field is metadata for weight initialization. It's preserved + through composition and only stripped when saving final output. + """ + from fast_llm_external_models.apriel2.conversion.config import strip_init_fields result = complete_config for i, surgery in enumerate(additive_surgery_chain): result = compose_configs(result, surgery) - check_no_init(result, f"step_{i+1}") + + # init should be in the composed config + mixer = result["decoder"]["block"]["mixer"] + if "mixers" in mixer: + has_init = any("init" in m for m in mixer["mixers"].values()) + assert has_init, "init should be preserved in composed config" + + # strip_init_fields removes them + stripped = strip_init_fields(result) + mixer = stripped["decoder"]["block"]["mixer"] + if "mixers" in mixer: + assert all("init" not in m for m in mixer["mixers"].values()) def test_full_torture_chain(self, complete_config, torture_surgery_chain): """Test the full 10-step torture chain produces valid configs.""" diff --git a/fast_llm_external_models/tests/test_apriel2/test_plan_composition_torture.py b/fast_llm_external_models/tests/test_apriel2/test_conversion_e2e.py similarity index 84% rename from fast_llm_external_models/tests/test_apriel2/test_plan_composition_torture.py rename to fast_llm_external_models/tests/test_apriel2/test_conversion_e2e.py index 76a77ccb6..09fb9fa13 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_plan_composition_torture.py +++ b/fast_llm_external_models/tests/test_apriel2/test_conversion_e2e.py @@ -1,6 +1,6 @@ -"""End-to-end torture test for plan composition. +"""test_conversion_e2e.py - End-to-end conversion integration tests. 
-This tests the FULL pipeline at every step of a surgery chain: +Tests the FULL pipeline at every step of a surgery chain: 1. Config composition produces valid configs 2. Plan building works for each surgery 3. Plan execution produces valid weights @@ -1083,66 +1083,6 @@ def mamba_config(self): }, } - def test_config_composition_identical_regardless_of_init_mode(self, base_config): - """Config composition produces same structure with init: transfer vs init: random.""" - # Surgery with init: transfer - surgery_transfer = { - "decoder": { - "block": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": { - "attention": {"type": "attention", "init": "transfer"}, - "swa": { - "type": "attention", - "init": "transfer", - "sliding_window": 512, - }, - }, - }, - }, - }, - } - - # Surgery with init: random - surgery_random = { - "decoder": { - "block": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": { - "attention": {"type": "attention", "init": "random"}, - "swa": { - "type": "attention", - "init": "random", - "sliding_window": 512, - }, - }, - }, - }, - }, - } - - # Compose configs - result_transfer = compose_configs(base_config, surgery_transfer) - result_random = compose_configs(base_config, surgery_random) - - # Both should produce identical structure (init is stripped) - assert result_transfer == result_random, ( - "Config composition should produce identical structure regardless of init mode" - ) - - # Verify the structure is correct - mixer = result_transfer["decoder"]["block"]["mixer"] - assert mixer["type"] == "stochastic" - assert "attention" in mixer["mixers"] - assert "swa" in mixer["mixers"] - # init should be stripped - assert "init" not in mixer["mixers"]["attention"] - assert "init" not in mixer["mixers"]["swa"] - def test_plan_surgery_random_succeeds_for_any_type_pair(self, mamba_config): """plan_surgery with init: random should succeed even for mamba -> attention.""" # This surgery changes mamba to attention with random init @@ -1313,8 +1253,8 @@ class TestMarkovianProperty: """ @pytest.fixture - def attention_config(self): - """Base config with attention.""" + def attention_config_dict(self): + """Base config dict with attention mixer for compose_configs tests.""" return { "model_type": "apriel2", "hidden_size": 256, @@ -1335,43 +1275,7 @@ def attention_config(self): }, } - @pytest.fixture - def stochastic_config(self): - """Config with stochastic mixer.""" - return { - "model_type": "apriel2", - "hidden_size": 256, - "vocab_size": 1000, - "decoder": { - "type": "fixed", - "num_blocks": 2, - "block": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": { - "attention": { - "type": "attention", - "heads": 8, - "head_groups": 4, - "head_size": 32, - }, - "swa": { - "type": "sliding_window", - "heads": 8, - "head_groups": 4, - "head_size": 32, - "window_size": 512, - }, - }, - }, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - } - - def test_different_paths_same_config_same_plan(self, attention_config): + def test_different_paths_same_config_same_plan(self, attention_config_dict): """Two different paths to the same config produce identical plans. 
Path A: attention -> stochastic{att, swa} @@ -1398,7 +1302,7 @@ def test_different_paths_same_config_same_plan(self, attention_config): }, }, } - config_a = compose_configs(attention_config, surgery_a) + config_a = compose_configs(attention_config_dict, surgery_a) # Path B: First add attention only, then add swa surgery_b1 = { @@ -1414,7 +1318,7 @@ def test_different_paths_same_config_same_plan(self, attention_config): }, }, } - intermediate_config = compose_configs(attention_config, surgery_b1) + intermediate_config = compose_configs(attention_config_dict, surgery_b1) surgery_b2 = { "decoder": { @@ -1469,7 +1373,7 @@ def test_different_paths_same_config_same_plan(self, attention_config): keys_b = set(str(k) for k in plan_from_b.mappings.keys()) assert keys_a == keys_b, "Plans from same config via different paths should be identical" - def test_init_in_source_config_does_not_affect_plan(self, attention_config): + def test_init_in_source_config_does_not_affect_plan(self, attention_config_dict): """Manually injecting init into source config doesn't change the plan. This tests that plan_surgery reads init from surgery, not source. @@ -1479,8 +1383,8 @@ def test_init_in_source_config_does_not_affect_plan(self, attention_config): import copy # Create two copies of the config - config_with_init = copy.deepcopy(attention_config) - config_without_init = copy.deepcopy(attention_config) + config_with_init = copy.deepcopy(attention_config_dict) + config_without_init = copy.deepcopy(attention_config_dict) # Manually inject init into one (bypassing compose_configs) config_with_init["decoder"]["block"]["mixer"]["init"] = "random" @@ -1510,232 +1414,6 @@ def test_init_in_source_config_does_not_affect_plan(self, attention_config): # Plans should be identical - source's init field is ignored assert keys_with == keys_without, "Plan should not depend on init in source config" - def test_associativity_of_surgery_composition(self, attention_config): - """Verify associativity: (A ∘ B) ∘ C == A ∘ (B ∘ C) for surgery specs. - - This tests that composing surgeries is associative, which is - equivalent to Markovianity for plan creation. - """ - surgery_a = { - "decoder": { - "block": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": { - "attention": {"type": "attention", "init": "transfer"}, - }, - }, - }, - }, - } - - surgery_b = { - "decoder": { - "block": { - "mixer": { - "mixers": { - "swa": { - "type": "sliding_window", - "init": "transfer", - "window_size": 512, - }, - }, - }, - }, - }, - } - - surgery_c = { - "decoder": { - "block": { - "mixer": { - "mixers": { - "gdn": { - "type": "gdn", - "init": "random", - "value_heads": 8, - "key_heads": 4, - "key_head_dim": 32, - "value_head_dim": 32, - "convolution_layer": {"kernel_size": 4}, - }, - }, - }, - }, - }, - } - - # Left association: ((attention_config ∘ A) ∘ B) ∘ C - left_1 = compose_configs(attention_config, surgery_a) - left_2 = compose_configs(left_1, surgery_b) - left_result = compose_configs(left_2, surgery_c) - - # Right association: (attention_config ∘ A) ∘ (B ∘ C) - # Note: B ∘ C is partial ∘ partial = deep merge of surgery specs - bc_merged = compose_configs(surgery_b, surgery_c) - right_1 = compose_configs(attention_config, surgery_a) - right_result = compose_configs(right_1, bc_merged) - - assert left_result == right_result, "Surgery composition should be associative" - - def test_complete_configs_have_no_init_fields(self, attention_config): - """Verify that compose_configs strips init from complete configs. 
- - This is the key invariant that enables Markovianity: - - Complete configs (states) have no init fields - - Surgery specs (transitions) have init fields - - Plans read init from surgery, not state - """ - surgery_with_init = { - "decoder": { - "block": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": { - "attention": {"type": "attention", "init": "transfer"}, - "swa": {"type": "sliding_window", "init": "random", "window_size": 512}, - }, - }, - }, - }, - } - - result = compose_configs(attention_config, surgery_with_init) - - # Recursively check for init fields - def has_init(obj): - if isinstance(obj, dict): - if "init" in obj: - return True - return any(has_init(v) for v in obj.values()) - if isinstance(obj, list): - return any(has_init(v) for v in obj) - return False - - assert not has_init(result), "Complete configs should have no init fields" - - def test_monoid_action_law_additive_surgeries(self): - """Monoid action law HOLDS for additive surgeries. - - Additive surgeries (no type: declaration) support: - apply(apply(s, t1), t2) == apply(s, t1 ∘ t2) - - This is because additive operations commute nicely: - "add {a}" then "add {b}" == "add {a, b}" - """ - # Start with stochastic (additive surgery target) - s = { - "model_type": "apriel2", - "hidden_size": 256, - "vocab_size": 1000, - "decoder": { - "type": "fixed", - "num_blocks": 2, - "block": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": { - "attention": {"type": "attention", "heads": 8, "head_groups": 4, "head_size": 32}, - }, - }, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - } - - # Additive surgeries (no type: declaration) - t1 = {"decoder": {"block": {"mixer": {"mixers": {"swa": {"type": "sliding_window", "window_size": 512}}}}}} - t2 = {"decoder": {"block": {"mixer": {"mixers": {"mamba": {"type": "mamba", "d_inner": 512}}}}}} - - # Path A: Sequential - s_prime = compose_configs(s, t1) - s_double_prime_A = compose_configs(s_prime, t2) - - # Path B: Composed - t1_t2 = compose_configs(t1, t2) - s_double_prime_B = compose_configs(s, t1_t2) - - assert s_double_prime_A == s_double_prime_B, "Monoid action law should hold for additive surgeries" - - def test_monoid_action_law_replacement_surgeries_fails(self): - """Monoid action law FAILS for replacement surgeries (by design). - - Replacement surgeries (type: stochastic declared) have: - apply(apply(s, t1), t2) != apply(s, t1 ∘ t2) - - This is FUNDAMENTAL, not a bug: - - Sequential: "set to {a}" then "set to {b}" → {b} (second wins) - - Composed: merge({a}, {b}) = {a,b}, then apply → {a,b} - - These are genuinely different semantics. The failure documents - the distinction between declarative composition (merge) and - operational composition (function application). 
- """ - s = { - "model_type": "apriel2", - "hidden_size": 256, - "vocab_size": 1000, - "decoder": { - "type": "fixed", - "num_blocks": 2, - "block": { - "mixer": {"type": "attention", "heads": 8, "head_groups": 4, "head_size": 32}, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - } - - # Replacement surgeries (both declare type: stochastic) - t1 = { - "decoder": { - "block": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": {"attention": {"type": "attention"}}, - } - } - } - } - t2 = { - "decoder": { - "block": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "swa", - "mixers": {"swa": {"type": "sliding_window", "window_size": 512}}, - } - } - } - } - - # Path A: Sequential (second replacement wins) - s_prime = compose_configs(s, t1) - s_double_prime_A = compose_configs(s_prime, t2) - - # Path B: Composed (declarations merged) - t1_t2 = compose_configs(t1, t2) - s_double_prime_B = compose_configs(s, t1_t2) - - # They should be DIFFERENT (law fails) - assert s_double_prime_A != s_double_prime_B, ( - "Monoid action law should FAIL for replacement surgeries" - ) - - # Verify the specific difference: - # Sequential: only swa (second replacement wins) - # Composed: both attention and swa (merged declarations) - mixers_A = set(s_double_prime_A["decoder"]["block"]["mixer"]["mixers"].keys()) - mixers_B = set(s_double_prime_B["decoder"]["block"]["mixer"]["mixers"].keys()) - - assert mixers_A == {"swa"}, "Sequential: second replacement wins" - assert mixers_B == {"attention", "swa"}, "Composed: declarations merged" - class TestCyclingSurgeryGeneration: """Tests for the cycling surgery generation functions. diff --git a/fast_llm_external_models/tests/test_apriel2/test_plan_execution.py b/fast_llm_external_models/tests/test_apriel2/test_plan_execution.py new file mode 100644 index 000000000..9a98ec13b --- /dev/null +++ b/fast_llm_external_models/tests/test_apriel2/test_plan_execution.py @@ -0,0 +1,597 @@ +"""test_plan_execution.py - Plan execution and algebraic composition laws. + +This module provides rigorous, parameterized tests for the mathematical properties +that the conversion system must satisfy. Each test class corresponds to one +algebraic structure, and each test method verifies one specific law. + +Conceptual Types +================ + +The conversion system operates on three conceptual types (all ``dict`` at runtime): + +- **S (State)**: Complete config without ``init`` fields +- **P (Partial Surgery)**: Incomplete config, may have ``init`` fields +- **T (Transition Spec)**: Complete config WITH ``init`` fields + +Algebraic Structures +==================== + +1. **Partial Surgeries (P)** form a **Monoid** under deep merge:: + + compose_configs : P × P → P + Identity: {} + Associativity: (p1 ∘ p2) ∘ p3 = p1 ∘ (p2 ∘ p3) + +2. **Surgeries act on States** to produce Transition Specs:: + + compose_configs : S × P → T + compose_configs : T × P → T + + Action law (additive surgeries): (s · p1) · p2 = s · (p1 ∘ p2) + +3. **Plans** form a **Category** with composition:: + + compose : Plan(A→B) × Plan(B→C) → Plan(A→C) + Associativity: (P1 ∘ P2) ∘ P3 = P1 ∘ (P2 ∘ P3) + +4. **plan_surgery is a Functor** from config pairs to plans:: + + plan_surgery : S × T → Plan + Functoriality: compose(plan(S,T1), plan(T1,T2)) ≡ plan(S,T2) + + This is semantic equivalence: both produce identical weights when executed. 
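+
+   As a rough sketch of how this law is exercised below (using the names
+   imported in this module; configs and weights abbreviated)::
+
+       t1 = compose_configs(s0, p1)                    # S × P → T
+       t2 = compose_configs(t1, p2)                    # T × P → T
+       composed = compose(plan_surgery(s0, t1), plan_surgery(t1, t2))
+       direct = plan_surgery(s0, t2)
+       # execute(composed, weights, seed=42) and execute(direct, weights, seed=42)
+       # must agree tensor for tensor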
+ +Important Behaviors Tested +========================== + +- **init stripping**: Between surgery iterations, T → S conversion via + ``strip_init_fields()`` ensures ``init: random`` applies only to the surgery + that introduces a component. + +- **Bias inheritance**: Per-layer bias settings propagate through surgery chains. + +- **Plan composition**: Composed plans produce identical weights to direct plans. + +Design Principles +================= + +- Each law gets ONE parameterized test, not multiple similar tests +- Fixtures provide diverse configs (with/without biases) +- Corner cases are covered via parameterization, not test proliferation +- Tests document the laws they verify in their docstrings +""" + +import pytest +import torch +from functools import reduce + +from fast_llm_external_models.apriel2.conversion import ( + compose, + compose_configs, + execute, + plan_surgery, + ExprPlan, + W, + Ref, + Concat, + Slice, + Init, +) + +# Import shared helper from conftest +from fast_llm_external_models.tests.test_apriel2.conftest import make_weights_for_config + + +# ============================================================================= +# Fixtures: Use shared fixtures from conftest.py where possible +# ============================================================================= +# - base_config_dict: Complete config without biases (Llama-style) +# - base_config_with_bias_dict: Complete config with QKV biases +# - additive_surgery_chain: [wrap_stochastic, add_sliding_window, add_gdn] +# ============================================================================= + + +# ============================================================================= +# Test: Plan Composition Associativity +# ============================================================================= + + +class TestPlanCompositionAssociativity: + """ + LAW: Plan composition is associative. + + (P₁ ∘ P₂) ∘ P₃ = P₁ ∘ (P₂ ∘ P₃) + + where ∘ denotes compose(P1, P2). + + This must hold for the AST structure, not just semantic equivalence. 
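+
+    A minimal illustration with plain Ref chains (the parameterized test below
+    also covers Concat, Slice, and Init expressions)::
+
+        p1 = ExprPlan(mappings={W("b"): Ref(key=W("a"))})
+        p2 = ExprPlan(mappings={W("c"): Ref(key=W("b"))})
+        p3 = ExprPlan(mappings={W("d"): Ref(key=W("c"))})
+        assert compose(compose(p1, p2), p3).mappings == compose(p1, compose(p2, p3)).mappings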
+ """ + + @pytest.mark.parametrize("expr_type", ["ref_chain", "with_concat", "with_slice", "with_init"]) + def test_associativity(self, expr_type): + """Plan composition is associative for various expression types.""" + # Build three plans that can be composed + if expr_type == "ref_chain": + p1 = ExprPlan(mappings={W("b"): Ref(key=W("a"))}) + p2 = ExprPlan(mappings={W("c"): Ref(key=W("b"))}) + p3 = ExprPlan(mappings={W("d"): Ref(key=W("c"))}) + elif expr_type == "with_concat": + p1 = ExprPlan(mappings={W("x"): Ref(key=W("a")), W("y"): Ref(key=W("b"))}) + p2 = ExprPlan(mappings={W("xy"): Concat(exprs=(Ref(key=W("x")), Ref(key=W("y"))), dim=0)}) + p3 = ExprPlan(mappings={W("final"): Ref(key=W("xy"))}) + elif expr_type == "with_slice": + p1 = ExprPlan(mappings={W("full"): Ref(key=W("src"))}) + p2 = ExprPlan(mappings={W("part"): Slice(expr=Ref(key=W("full")), slices=((0, 5, None),))}) + p3 = ExprPlan(mappings={W("out"): Ref(key=W("part"))}) + elif expr_type == "with_init": + p1 = ExprPlan(mappings={W("x"): Ref(key=W("a"))}) + p2 = ExprPlan(mappings={W("y"): Concat(exprs=(Ref(key=W("x")), Init(shape=(5,), init_type="zeros")), dim=0)}) + p3 = ExprPlan(mappings={W("z"): Ref(key=W("y"))}) + + left = compose(compose(p1, p2), p3) + right = compose(p1, compose(p2, p3)) + + assert left.mappings == right.mappings, f"Associativity failed for {expr_type}" + + +# ============================================================================= +# Test: Functoriality of plan_surgery (THE CRITICAL PROPERTY) +# ============================================================================= + + +class TestPlanSurgeryFunctoriality: + """ + LAW: plan_surgery is functorial with respect to config composition. + + For a surgery chain P₁, P₂, ..., Pₙ applied to base state S₀:: + + T₁ = compose_configs(S₀, P₁) # S × P → T + T₂ = compose_configs(T₁, P₂) # T × P → T (no stripping!) + ... + Tₙ = compose_configs(Tₙ₋₁, Pₙ) + + Plan functoriality says:: + + compose(plan(S₀,T₁), plan(T₁,T₂), ...) ≡ plan(S₀, Tₙ) + + where ≡ denotes semantic equivalence (identical weights when executed). + + NOTE: This tests T × P composition WITHOUT stripping between steps. + This differs from build_plan which strips (T → S) between iterations. + Both patterns are valid: + + - Without stripping: init fields accumulate, testing plan composition purity + - With stripping: init consumed per-step, testing real usage (see + test_build_plan_strips_init_between_iterations) + + The functoriality law holds in both cases because plan composition + correctly substitutes Ref expressions with their definitions. + """ + + @pytest.mark.parametrize("chain_length", [1, 2, 3]) + @pytest.mark.parametrize("use_bias", [True, False]) + def test_functoriality( + self, + chain_length, + use_bias, + base_config_dict, + base_config_with_bias_dict, + additive_surgery_chain, + ): + """ + Composed incremental plans produce same weights as direct plan. + + Parameterized over: + - chain_length: Number of surgeries (1, 2, or 3) + - use_bias: Whether base config has biases + """ + base_config = base_config_with_bias_dict if use_bias else base_config_dict + surgeries = additive_surgery_chain[:chain_length] + + # Build config chain: C₀ → C₁ → ... 
→ Cₙ + configs = [base_config] + for s in surgeries: + configs.append(compose_configs(configs[-1], s)) + + # Build incremental plans: Pₖ = plan_surgery(Cₖ₋₁, Cₖ) + plans = [plan_surgery(configs[i], configs[i+1]) for i in range(len(surgeries))] + + # Compose all incremental plans + composed_plan = reduce(compose, plans) + + # Build direct plan: plan_surgery(C₀, Cₙ) + direct_plan = plan_surgery(configs[0], configs[-1]) + + # Execute both on same weights + weights = make_weights_for_config(base_config) + composed_weights = execute(composed_plan, weights, seed=42) + direct_weights = execute(direct_plan, weights, seed=42) + + # Verify semantic equivalence + assert set(composed_weights.keys()) == set(direct_weights.keys()), \ + f"Key sets differ for chain_length={chain_length}, use_bias={use_bias}" + + for key in composed_weights: + assert torch.allclose(composed_weights[key], direct_weights[key], atol=1e-6), \ + f"Weight mismatch for {key} with chain_length={chain_length}, use_bias={use_bias}" + + @pytest.mark.parametrize("split_point", [1, 2]) + def test_arbitrary_grouping( + self, + split_point, + base_config_with_bias_dict, + additive_surgery_chain, + ): + """ + Any grouping of surgery chain produces same result. + + For surgeries [S₁, S₂, S₃], tests that: + - compose(P₁, compose(P₂, P₃)) + - compose(compose(P₁, P₂), P₃) + - plan_surgery(C₀, C₃) + + all produce identical weights. + """ + surgeries = additive_surgery_chain + + # Build config chain + configs = [base_config_with_bias_dict] + for s in surgeries: + configs.append(compose_configs(configs[-1], s)) + + # Build incremental plans + plans = [plan_surgery(configs[i], configs[i+1]) for i in range(3)] + + # Different groupings + left_grouped = compose(compose(plans[0], plans[1]), plans[2]) + right_grouped = compose(plans[0], compose(plans[1], plans[2])) + direct = plan_surgery(configs[0], configs[-1]) + + # Execute all + weights = make_weights_for_config(base_config_with_bias_dict) + results = { + "left": execute(left_grouped, weights, seed=42), + "right": execute(right_grouped, weights, seed=42), + "direct": execute(direct, weights, seed=42), + } + + # All must match + keys = set(results["left"].keys()) + assert keys == set(results["right"].keys()) == set(results["direct"].keys()) + + for key in keys: + assert torch.allclose(results["left"][key], results["right"][key], atol=1e-6) + assert torch.allclose(results["left"][key], results["direct"][key], atol=1e-6) + + +# ============================================================================= +# Test: Bias Inheritance Preservation (Regression for the specific bug) +# ============================================================================= + + +class TestBiasInheritancePreservation: + """ + PROPERTY: Per-layer bias settings must be preserved through surgery chains. + + When a surgery spec does not mention bias settings, they must be inherited + from the source config. This is the specific failure mode of the build_plan + bug: passing partial surgery specs to plan_surgery lost inherited fields. + + This test verifies the SYMPTOM (missing biases) rather than the LAW + (functoriality). It's kept as a focused regression test. 
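+
+    Concretely (a sketch in terms of the shared fixtures): if the source config
+    enables query/key/value biases and the surgery spec never mentions them,
+    the composed target must still map them::
+
+        target = compose_configs(base_config_with_bias_dict, additive_surgery_chain[0])
+        plan = plan_surgery(base_config_with_bias_dict, target)
+        # plan.target_keys() includes ...mixer.mixers.attention.q_proj.bias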
+ """ + + @pytest.mark.parametrize("num_surgeries", [1, 2, 3]) + def test_qkv_biases_preserved_through_chain( + self, + num_surgeries, + base_config_with_bias_dict, + additive_surgery_chain, + ): + """QKV biases (enabled in source) appear in plan after N surgeries.""" + surgeries = additive_surgery_chain[:num_surgeries] + + # Build config and plan chain + configs = [base_config_with_bias_dict] + for s in surgeries: + configs.append(compose_configs(configs[-1], s)) + + plans = [plan_surgery(configs[i], configs[i+1]) for i in range(num_surgeries)] + final_plan = reduce(compose, plans) if len(plans) > 1 else plans[0] + + # Check bias keys present + target_keys = {str(k) for k in final_plan.target_keys()} + + assert any("q_proj.bias" in k for k in target_keys), \ + f"q_proj.bias missing after {num_surgeries} surgeries" + assert any("k_proj.bias" in k for k in target_keys), \ + f"k_proj.bias missing after {num_surgeries} surgeries" + assert any("v_proj.bias" in k for k in target_keys), \ + f"v_proj.bias missing after {num_surgeries} surgeries" + # O bias should NOT be present (disabled in source) + assert not any("o_proj.bias" in k for k in target_keys), \ + f"o_proj.bias should not be present (disabled in source)" + + def test_bias_values_preserved( + self, + base_config_with_bias_dict, + additive_surgery_chain, + ): + """Bias tensor values are correctly transferred, not just keys.""" + surgery = additive_surgery_chain[0] # wrap_stochastic + c1 = compose_configs(base_config_with_bias_dict, surgery) + plan = plan_surgery(base_config_with_bias_dict, c1) + + weights = make_weights_for_config(base_config_with_bias_dict) + result = execute(plan, weights, seed=42) + + # Verify values match (not just that keys exist) + for i in range(base_config_with_bias_dict["decoder"]["num_blocks"]): + src_key = W(f"model.decoder.blocks.{i}.mixer.q_proj.bias") + dst_key = W(f"model.decoder.blocks.{i}.mixer.mixers.attention.q_proj.bias") + + assert dst_key in result, f"Missing {dst_key}" + assert torch.allclose(weights[src_key], result[dst_key]), \ + f"Bias values differ for block {i}" + + +# ============================================================================= +# Test: build_plan Integration (Regression test for convert.py) +# ============================================================================= + + +class TestBuildPlanIntegration: + """ + REGRESSION: build_plan must compose configs before calling plan_surgery. + + The bug was: + plan_surgery(current_config, surgery_config) # WRONG: partial + + Should be: + target = compose_configs(current_config, surgery_config) + plan_surgery(current_config, target) # CORRECT: complete + + This test verifies the fix in convert.py's build_plan function. 
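+
+    Rough call shape of the fixed entry point (surgeries are partial specs;
+    composition happens inside build_plan)::
+
+        plan, final_config = build_plan(source_config, surgeries, source_format="apriel2")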
+ """ + + @pytest.mark.parametrize("num_surgeries", [1, 2]) + def test_build_plan_preserves_inherited_fields( + self, + num_surgeries, + base_config_with_bias_dict, + additive_surgery_chain, + ): + """build_plan produces plans with inherited bias mappings.""" + from fast_llm_external_models.apriel2.convert import build_plan + + surgeries = additive_surgery_chain[:num_surgeries] + + plan, final_config = build_plan( + base_config_with_bias_dict, + surgeries, + source_format="apriel2", + ) + + # Verify inherited biases in config + if num_surgeries >= 1: + attn = final_config["decoder"]["block"]["mixer"]["mixers"]["attention"] + assert attn.get("query_layer", {}).get("bias", {}).get("enabled") is True + + # Verify bias mappings in plan + target_keys = {str(k) for k in plan.target_keys()} + assert any("q_proj.bias" in k for k in target_keys), \ + f"build_plan with {num_surgeries} surgeries missing q_proj.bias" + + +# ============================================================================= +# Test: init Field Preservation (Critical for random initialization) +# ============================================================================= + + +class TestInitFieldPreservation: + """ + PROPERTY: The `init` field must be visible to plan_surgery. + + The `init` field controls weight initialization mode: + - `init: transfer` → use weight transfer/conversion + - `init: random` → use random initialization + + compose_configs must preserve `init` so plan_surgery can see it. + Stripping happens only at final output (when saving to disk). + """ + + def test_init_random_produces_init_expression(self, base_config_with_bias_dict): + """Surgery with init: random produces Init expressions in plan.""" + surgery = { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + "gdn": {"type": "gdn", "init": "random", "convolution_layer": {"kernel_size": 4}}, + }, + }, + }, + }, + } + + target = compose_configs(base_config_with_bias_dict, surgery) + plan = plan_surgery(base_config_with_bias_dict, target) + + # Check that GDN weights use Init expressions (random init) + target_keys = {str(k) for k in plan.target_keys()} + gdn_keys = [k for k in target_keys if "gdn" in k.lower()] + + assert len(gdn_keys) > 0, "No GDN keys in plan" + + # Verify at least one GDN weight uses Init (random initialization) + has_init_expr = False + for key in plan.target_keys(): + if "gdn" in str(key).lower(): + expr = plan.mappings[key] + if isinstance(expr, Init): + has_init_expr = True + break + # Also check inside Concat/other composite expressions + if hasattr(expr, 'exprs'): + for sub in expr.exprs: + if isinstance(sub, Init): + has_init_expr = True + break + + assert has_init_expr, "init: random should produce Init expressions for GDN weights" + + def test_init_transfer_produces_ref_expression(self, base_config_with_bias_dict): + """Surgery with init: transfer produces Ref expressions (weight transfer).""" + surgery = { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + }, + }, + }, + }, + } + + target = compose_configs(base_config_with_bias_dict, surgery) + plan = plan_surgery(base_config_with_bias_dict, target) + + # Check that attention weights use Ref expressions (transfer) + has_ref = False + for key in plan.target_keys(): + if "attention" in str(key) and "q_proj.weight" in str(key): + expr = plan.mappings[key] + if isinstance(expr, 
Ref): + has_ref = True + break + + assert has_ref, "init: transfer should produce Ref expressions for attention weights" + + def test_build_plan_respects_init_random(self, base_config_with_bias_dict): + """build_plan correctly uses init: random for weight initialization.""" + from fast_llm_external_models.apriel2.convert import build_plan + + # Mamba requires many config fields for random init + surgery = { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + "mamba": { + "type": "mamba", + "init": "random", + "d_inner": 512, + "d_state": 16, + "dt_rank": 16, + "d_xb": 64, + "d_conv": 4, + "repeat_kv_before_conv": False, + "conv_bias": True, + "dt_proj_bias": True, + "dt_min": 0.001, + "dt_max": 0.1, + "dt_init_floor": 1e-4, + }, + }, + }, + }, + }, + } + + plan, final_config = build_plan( + base_config_with_bias_dict, + [surgery], + source_format="apriel2", + ) + + # Verify mamba weights use Init (random init) + has_mamba_init = False + for key in plan.target_keys(): + key_str = str(key) + if "mamba" in key_str: + expr = plan.mappings[key] + if isinstance(expr, Init): + has_mamba_init = True + break + + assert has_mamba_init, "build_plan should use Init for init: random components" + + def test_build_plan_strips_init_between_iterations(self, base_config_with_bias_dict): + """build_plan strips init between iterations (T → S conversion). + + This tests that the intermediate state between surgeries has no init fields. + The composed plan will show Init expressions because plan composition + substitutes Ref → Init, but the semantics are correct: GDN is initialized + once (in surgery 1), not re-randomized in surgery 2. + """ + from fast_llm_external_models.apriel2.conversion import ( + compose_configs, strip_init_fields, plan_surgery, compose + ) + + # Surgery 1: Add GDN with random init + surgery1 = { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + "gdn": { + "type": "gdn", + "init": "random", + "convolution_layer": {"kernel_size": 4}, + }, + }, + }, + }, + }, + } + + # Surgery 2: Add sliding window (doesn't mention GDN) + surgery2 = { + "decoder": { + "block": { + "mixer": { + "mixers": { + "sliding_window": {"init": "transfer", "window_size": 512}, + }, + }, + }, + }, + } + + # Simulate build_plan's iteration loop + s0 = base_config_with_bias_dict + + # Iteration 1 + t1 = compose_configs(s0, surgery1) + assert t1["decoder"]["block"]["mixer"]["mixers"]["gdn"].get("init") == "random" + s1 = strip_init_fields(t1) + assert s1["decoder"]["block"]["mixer"]["mixers"]["gdn"].get("init") is None + + # Iteration 2: s1 has no init for GDN + t2 = compose_configs(s1, surgery2) + assert t2["decoder"]["block"]["mixer"]["mixers"]["gdn"].get("init") is None, \ + "GDN should have no init in T2 (wasn't in surgery2, stripped from s1)" + + # plan_surgery(s1, t2) should use Ref for GDN (transfer, not random) + plan2 = plan_surgery(s1, t2) + gdn_uses_ref = False + for key in plan2.target_keys(): + if "gdn" in str(key): + expr = plan2.mappings[key] + if isinstance(expr, Ref): + gdn_uses_ref = True + break + + assert gdn_uses_ref, "plan_surgery(s1, t2) should use Ref for GDN (transfer from s1)" From b6dd6dc563db08de16ba76c89f335cdb8a014818 Mon Sep 17 00:00:00 2001 From: Torsten Scholak Date: Tue, 16 Dec 2025 21:47:34 +0000 Subject: [PATCH 08/25] =?UTF-8?q?Fix=20O(n=C2=B2)=20tokenization=20and=20a?= 
=?UTF-8?q?dd=20Qwen2=20training=20examples?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace apply_chat_template_with_spans with tokenize_chat (O(n) token-level) - Add _mask_to_spans helper to convert boolean mask to loss masking spans - Fix chat template docs: entire assistant turn must be in {% generation %} - Add parameterized tests with exact expected tokens and trainable indices - Add prepare_tulu3.yaml and train_supernet_qwen2.yaml examples - Document performance tuning (~8k tokens/s, ~61GB memory, ~25h for 1B tokens) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- .../data/preparator/gpt_memmap/prepare.py | 44 ++-- fast_llm/data/preprocessing/tokenizer.py | 77 +++---- .../apriel2/examples/prepare_tulu3.yaml | 103 ++++++++++ .../examples/train_supernet_qwen2.yaml | 193 ++++++++++++++++++ tests/data/test_tokenizer.py | 89 ++++++-- 5 files changed, 427 insertions(+), 79 deletions(-) create mode 100644 fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml create mode 100644 fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index f349b1979..a9beca42f 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -220,21 +220,25 @@ def _preprocessing_config(self) -> LanguageModelPreprocessingConfig: ) def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample: - all_spans = [] - if self._source_schema.has_conversation: - # Conversation format: apply chat template and compute loss masking spans - messages = sample[self._source_schema.messages] - text, loss_masking_spans = self._tokenizer.apply_chat_template_with_spans( - messages, - add_generation_prompt=self._source_schema.add_generation_prompt, + tokens, train_mask = self._tokenizer.tokenize_chat( + sample[self._source_schema.messages], + self._source_schema.add_generation_prompt, + data_type=self._data_type, + ) + return LanguageModelSample( + TokenSample(tokens, [len(tokens)]), + RangeSample(_mask_to_spans(train_mask), len(tokens)), + None, + None, + None, ) - all_spans.extend([(SpanType.loss_masking, span) for span in loss_masking_spans]) - else: - # Plain text format - text = sample[self._source_schema.text] - if self._source_schema.has_loss_masking_span and not self._source_schema.has_conversation: + # Text format: use the text-spans pipeline + text = sample[self._source_schema.text] + all_spans = [] + + if self._source_schema.has_loss_masking_span: # Spans are typically stored in the (begin, last) format. We convert to (begin, end) range format. 
loss_masking_spans = _sort_spans( (SpanType.loss_masking, (begin, last + 1)) @@ -495,3 +499,19 @@ def _get_nearest_split(cumsum: np.ndarray, value: float) -> int: if left == len(cumsum): return left.item() return left.item() + 1 if (value - cumsum[left]) / (cumsum[left + 1] - cumsum[left]) > 0.5 else left.item() + + +def _mask_to_spans(mask: list[bool]) -> list[tuple[int, int]]: + """Convert a boolean train mask to loss masking spans (where mask[i] == False).""" + spans = [] + start = None + for i, value in enumerate(mask): + if not value: + if start is None: + start = i + elif start is not None: + spans.append((start, i)) + start = None + if start is not None: + spans.append((start, len(mask))) + return spans diff --git a/fast_llm/data/preprocessing/tokenizer.py b/fast_llm/data/preprocessing/tokenizer.py index 372d8cd90..f3b5a51a8 100644 --- a/fast_llm/data/preprocessing/tokenizer.py +++ b/fast_llm/data/preprocessing/tokenizer.py @@ -242,32 +242,17 @@ def validate_chat_template(self) -> None: "Please use a tokenizer with generation markers in its chat template." ) - def apply_chat_template_with_spans( + def tokenize_chat( self, messages: list[dict[str, str]], - *, add_generation_prompt: bool = False, - ) -> tuple[str, list[tuple[int, int]]]: - """ - Apply the tokenizer's chat template to messages and compute loss masking spans. - - This method converts a list of messages (OpenAI/Tulu format) into formatted - text and computes character-level spans that should be MASKED (not trained on). - - Note: Call validate_chat_template() once before using this method to ensure - the tokenizer has a valid chat template with generation markers. - - Args: - messages: List of message dicts with 'role' and 'content' keys. - add_generation_prompt: Whether to add a generation prompt at the end. + begin: bool = True, + end: bool = True, + data_type: DataType = DataType.int64, + ) -> tuple["torch.Tensor", list[bool]]: + """Apply chat template and return (tokens, train_mask) where train_mask[i]=True means train on token i.""" + import torch - Returns: - Tuple of (formatted_text, loss_masking_spans) where loss_masking_spans - is a list of (start, end) character positions to MASK (not train on). 
- """ - if not messages: - return "", [] - # Get tokens and assistant mask result = self.tokenizer.apply_chat_template( messages, tokenize=True, @@ -275,40 +260,24 @@ def apply_chat_template_with_spans( return_dict=True, add_generation_prompt=add_generation_prompt, ) - tokens = result["input_ids"] train_mask = result["assistant_masks"] - # Get text for output - full_text = self.tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=add_generation_prompt, - ) + # Prepend BOS / append EOS if needed (avoid O(n) insert) + prepend_bos = begin and (not tokens or tokens[0] != self.bod_id) + append_eos = end and (not tokens or tokens[-1] != self.eod_id) + tokens = [self.bod_id] * prepend_bos + list(tokens) + [self.eod_id] * append_eos + train_mask = [False] * prepend_bos + [bool(m) for m in train_mask] + [False] * append_eos - # Convert token mask to character spans using detokenization - # We need spans for tokens where train_mask=0 (should be masked/not trained on) - loss_masking_spans = [] - current_span_start = None - - # Track character positions by decoding incrementally - char_positions = [0] - for i in range(len(tokens)): - decoded = self.tokenizer.decode(tokens[: i + 1]) - char_positions.append(len(decoded)) - - for i, is_train in enumerate(train_mask): - if not is_train: # This token should be masked - if current_span_start is None: - current_span_start = char_positions[i] - else: # This token should be trained on - if current_span_start is not None: - loss_masking_spans.append((current_span_start, char_positions[i])) - current_span_start = None - - # Close any open span - if current_span_start is not None: - loss_masking_spans.append((current_span_start, char_positions[-1])) - - return full_text, loss_masking_spans + if self._config.max_vocab_size is not None: + tokens = ( + torch.tensor( + tokens, + dtype=torch.int64 if len(self.tokenizer) > torch.iinfo(data_type.torch).max else data_type.torch, + ) + % self._config.max_vocab_size + ).to(data_type.torch) + else: + tokens = torch.tensor(tokens, dtype=data_type.torch) + return tokens, train_mask diff --git a/fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml b/fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml new file mode 100644 index 000000000..ba85c1aed --- /dev/null +++ b/fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml @@ -0,0 +1,103 @@ +# Dataset preparation config for Tulu 3 SFT mixture with Qwen2 tokenizer +# +# This config converts the Tulu 3 SFT dataset (conversation format) to +# Fast-LLM's memmap format, with automatic loss masking span computation +# to train only on assistant responses. +# +# ============================================================================= +# TOKENIZER SETUP (one-time) +# ============================================================================= +# +# The tokenizer must have a chat template with {% generation %} markers. +# Qwen2's default template doesn't have these, so we need to patch it. +# +# IMPORTANT: The entire assistant turn (opening tag + content + closing tag) +# must be inside the {% generation %} block. This ensures the model learns to +# produce the full assistant response including special tokens like <|im_end|>. 
+# Reference: https://huggingface.co/HuggingFaceTB/SmolLM3-3B/blob/main/chat_template.jinja +# +# Run this Python script to create a patched tokenizer: +# +# from transformers import AutoTokenizer +# +# tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") +# +# # Patch chat template: wrap ENTIRE assistant turn in generation markers +# tokenizer.chat_template = '''{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system +# You are a helpful assistant.<|im_end|> +# ' }}{% endif %}{% if message['role'] == 'assistant' %}{% generation %}{{ '<|im_start|>assistant +# ' + message['content'] + '<|im_end|> +# ' }}{% endgeneration %}{% else %}{{ '<|im_start|>' + message['role'] + ' +# ' + message['content'] + '<|im_end|> +# ' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant +# ' }}{% endif %}''' +# +# tokenizer.save_pretrained("/path/to/qwen2-instruct-with-markers") +# +# ============================================================================= +# DATA PREPARATION +# ============================================================================= +# +# Small dataset (for testing): +# +# fast-llm prepare gpt_memmap \ +# -c fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml \ +# dataset.split=train[:1000] \ +# output_path=/path/to/tulu3-prepared-small +# +# Full dataset (~939K samples, ~6 minutes): +# +# fast-llm prepare gpt_memmap \ +# -c fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml +# +# ============================================================================= +# VERIFICATION +# ============================================================================= +# +# To verify the prepared dataset has loss masking spans: +# +# import pathlib +# from fast_llm.data.dataset.memmap import MemmapDataset +# from fast_llm.data.sample.language_model import LanguageModelSample +# from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig +# +# dataset = MemmapDataset[LanguageModelSample]( +# 'tulu3', +# pathlib.Path('/path/to/tulu3-prepared/shard_0_0.fast_llm_dataset'), +# LanguageModelPreprocessingConfig(use_loss_masking_spans=True) +# ) +# +# doc = dataset.get_document(0) +# print(f'Tokens: {len(doc.tokens.tokens)}') +# print(f'Loss masking spans: {doc.loss_masking_spans.ranges}') +# +# ============================================================================= + +# Dataset configuration +dataset: + # Tulu 3 SFT mixture from AllenAI + path: allenai/tulu-3-sft-mixture + split: train + + # Source schema for conversation format + source_schema: + # Use conversation type (vs default "text" type) + type: conversation + + # Column containing the messages list + messages: messages + +# Tokenizer configuration +# IMPORTANT: Must use a tokenizer with {% generation %} markers in its chat template. +# See instructions above to create a patched Qwen2 tokenizer. 
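+# A quick sanity check that the patched template exposes assistant masks
+# (a sketch; adjust the path to your patched tokenizer):
+#
+#   from transformers import AutoTokenizer
+#   tok = AutoTokenizer.from_pretrained("/path/to/qwen2-instruct-with-markers")
+#   out = tok.apply_chat_template(
+#       [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello"}],
+#       tokenize=True, return_assistant_tokens_mask=True, return_dict=True)
+#   assert any(out["assistant_masks"]), "no {% generation %} markers in template"
+#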
+tokenizer: + path: /path/to/qwen2-instruct-with-markers + # Qwen2 doesn't have a BOS token by default, use <|endoftext|> as BOS + bos_token: "<|endoftext|>" + +# Output configuration +output_path: /path/to/tulu3-prepared + +# Processing configuration +num_workers: 8 +documents_per_shard: 100000 diff --git a/fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml b/fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml new file mode 100644 index 000000000..5b190955f --- /dev/null +++ b/fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml @@ -0,0 +1,193 @@ +# Training config for Qwen2-based Apriel2 stochastic supernet on Tulu 3 SFT data +# +# This config trains a stochastic supernet where each layer can sample from +# multiple mixer types (attention, sliding window, gated delta net, KDA). +# Only the mixer weights are trained; all other weights are frozen. +# Activation-level distillation from a teacher model guides the training. +# +# ============================================================================= +# PREREQUISITES +# ============================================================================= +# +# 1. TOKENIZER SETUP +# +# Qwen2's default chat template doesn't have generation markers needed for +# loss masking. Create a patched tokenizer following the SmolLM3 pattern: +# https://huggingface.co/HuggingFaceTB/SmolLM3-3B/blob/main/chat_template.jinja +# +# IMPORTANT: The ENTIRE assistant turn (opening tag + content + closing tag) +# must be inside {% generation %}...{% endgeneration %} markers. +# +# from transformers import AutoTokenizer +# tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") +# # Wrap entire assistant turn in generation markers (NOT just content!) +# tokenizer.chat_template = '''{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system +# You are a helpful assistant.<|im_end|> +# ' }}{% endif %}{% if message['role'] == 'assistant' %}{% generation %}{{ '<|im_start|>assistant +# ' + message['content'] + '<|im_end|> +# ' }}{% endgeneration %}{% else %}{{ '<|im_start|>' + message['role'] + ' +# ' + message['content'] + '<|im_end|> +# ' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant +# ' }}{% endif %}''' +# tokenizer.save_pretrained("/path/to/qwen2-instruct-with-markers") +# +# 2. PREPARE TULU 3 DATASET +# +# Small dataset (for testing): +# +# fast-llm prepare gpt_memmap \ +# -c fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml \ +# tokenizer.path=/path/to/qwen2-instruct-with-markers \ +# dataset.split=train[:1000] \ +# output_path=/path/to/tulu3-prepared-small +# +# Full dataset (~939K samples, ~6 minutes): +# +# fast-llm prepare gpt_memmap \ +# -c fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml \ +# tokenizer.path=/path/to/qwen2-instruct-with-markers \ +# output_path=/path/to/tulu3-prepared +# +# 3. CONVERT QWEN2 TO APRIEL2 SUPERNET (student model) +# +# This creates a stochastic supernet with multiple mixer types per layer: +# +# python fast_llm_external_models/apriel2/convert.py \ +# Qwen/Qwen2.5-0.5B-Instruct \ +# /path/to/qwen2-supernet \ +# --surgery fast_llm_external_models/apriel2/examples/stochastic_supernet.yaml +# +# 4. CONVERT QWEN2 TO APRIEL2 (teacher model) +# +# The teacher is the original model without surgery, used for distillation: +# +# python fast_llm_external_models/apriel2/convert.py \ +# Qwen/Qwen2.5-0.5B-Instruct \ +# /path/to/qwen2-teacher +# +# 5. 
RUN TRAINING +# +# Update paths below and run: +# +# fast-llm train gpt \ +# -c fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml +# +# For long runs, use nohup: +# +# nohup fast-llm train gpt \ +# -c fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml \ +# > training.log 2>&1 & +# tail -f training.log +# +# ============================================================================= +# PERFORMANCE TUNING +# ============================================================================= +# +# Default config uses seq=4096, micro_batch=2, batch=16 which gives: +# - ~8k tokens/s/gpu throughput +# - ~61GB GPU memory usage +# - ~25 hours for 1B tokens on single GPU +# +# Adjust batch settings based on your GPU memory: +# - Reduce micro_batch_size if OOM +# - Increase micro_batch_size/batch_size if memory available +# +# ============================================================================= +# OUTPUT +# ============================================================================= +# +# Checkpoints: /path/to/qwen2-supernet-trained/checkpoints/{iteration}/ +# Exports: /path/to/qwen2-supernet-trained/export/apriel2_text/{iteration}/ +# +# ============================================================================= + +# Load pretrained model (Qwen2 converted to Apriel2 supernet) +pretrained: + path: /path/to/qwen2-supernet + format: apriel2_text + model_weights: true + load_config: model + +# Model config +model: + base_model: + # Freeze all components except the mixer + decoder: + block: + mlp: + lr_scale: 0.0 # Freeze MLP + normalization: + lr_scale: 0.0 # Freeze layer norms + # Activation-level distillation from teacher + distillation_model: teacher + activation_distillation_factor: 0.8 + embeddings: + lr_scale: 0.0 # Freeze word embeddings + head: + lr_scale: 0.0 # Freeze output head + cross_entropy_implementation: torch + multi_stage: + zero_stage: 2 + distributed: + compute_dtype: bf16 + seed: 42 + +# Teacher model for activation-level distillation +reference_models: + teacher: + model: + type: gpt + pretrained: + path: /path/to/qwen2-teacher + format: apriel2_text + model_weights: true + load_config: model + +# Batch configuration (tuned for ~61GB GPU memory, ~8k tokens/s) +batch: + sequence_length: 4096 + micro_batch_size: 2 + batch_size: 16 + +# Data configuration (prepared Tulu 3 dataset) +data: + datasets: + training: + type: file + path: /path/to/tulu3-prepared/fast_llm_config.yaml + +# Optimizer configuration +optimizer: + learning_rate: + base: 1.0e-05 + decay_style: cosine + warmup_iterations: 100 + decay_iterations: 10000 + minimum: 1.0e-06 + weight_decay: 0.1 + beta_1: 0.9 + beta_2: 0.95 + +# Training configuration +# At seq=4096, batch=16: ~65k tokens/iter, ~280 iters/hour +# 10000 iters ≈ 650M tokens ≈ 35 hours +training: + train_iters: 10000 + num_workers: 4 + logs: + interval: 10 + checkpoint: + interval: 280 # ~hourly + export: + interval: 280 # ~hourly (useful for development/testing during training) + format: apriel2_text + test_iters: 0 + evaluators: {} + # Weights & Biases configuration (optional, uncomment to enable) + # wandb: + # entity_name: your-entity + # project_name: your-project + +# Experiment directory +run: + experiment_dir: /path/to/qwen2-supernet-trained diff --git a/tests/data/test_tokenizer.py b/tests/data/test_tokenizer.py index 4b8f45d8d..97f16c6d6 100644 --- a/tests/data/test_tokenizer.py +++ b/tests/data/test_tokenizer.py @@ -61,10 +61,13 @@ def test_validate_chat_template_with_markers(common_tokenizer): 
common_tokenizer.validate_chat_template() +# Realistic chat template following HF conventions (e.g., SmolLM3): +# The generation block includes the full assistant turn: opening tag, content, and closing tag. +# This ensures the model learns to emit the closing tag. CHAT_TEMPLATE = ( "{% for message in messages %}" "{% if message.role == 'assistant' %}" - "{% generation %}{{ message.content }}{% endgeneration %}" + "{% generation %}{{ message.content }}{% endgeneration %}" "{% else %}" "<{{ message.role }}>{{ message.content }}" "{% endif %}" @@ -73,24 +76,84 @@ def test_validate_chat_template_with_markers(common_tokenizer): @pytest.mark.parametrize( - ("messages", "expected_text", "expected_spans"), + ("messages", "expected_tokens", "expected_trainable_indices"), ( - ([], "", []), + # Single turn: full assistant turn (Hello) is trainable ( [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello"}], - "HiHello", - [(0, 26), (31, 43)], + [49152, 27, 789, 29, 16946, 750, 789, 2293, 17822, 29, 7371, 750, 17822, 29, 49152], + [7, 8, 9, 10, 11, 12, 13], ), + # Multi-turn: both assistant turns are fully trainable ( - [{"role": "user", "content": "A"}, {"role": "assistant", "content": "B"}, {"role": "user", "content": "C"}, {"role": "assistant", "content": "D"}], - "ABCD", - [(0, 25), (26, 63), (64, 76)], + [ + {"role": "user", "content": "A"}, + {"role": "assistant", "content": "B"}, + {"role": "user", "content": "C"}, + {"role": "assistant", "content": "D"}, + ], + [49152, 27, 789, 29, 32, 750, 789, 2293, 17822, 29, 33, 750, 17822, 2293, 789, 29, 34, 750, 789, 2293, 17822, 29, 35, 750, 17822, 29, 49152], + [7, 8, 9, 10, 11, 12, 13, 19, 20, 21, 22, 23, 24, 25], + ), + # System + user + assistant: full assistant turn trainable + ( + [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello"}, + ], + [49152, 27, 3144, 29, 5815, 1139, 44569, 6928, 3144, 2293, 789, 29, 16946, 750, 789, 2293, 17822, 29, 7371, 750, 17822, 29, 49152], + [15, 16, 17, 18, 19, 20, 21], + ), + # User only: no trainable tokens + ( + [{"role": "user", "content": "Hi"}], + [49152, 27, 789, 29, 16946, 750, 789, 29, 49152], + [], + ), + # Long multi-turn (85 tokens, 3 assistant responses with tags, tests span machinery) + ( + [ + {"role": "system", "content": "You are a helpful assistant that answers questions."}, + {"role": "user", "content": "What is the capital of France?"}, + {"role": "assistant", "content": "The capital of France is Paris."}, + {"role": "user", "content": "What about Germany?"}, + {"role": "assistant", "content": "The capital of Germany is Berlin."}, + {"role": "user", "content": "And Italy?"}, + {"role": "assistant", "content": "The capital of Italy is Rome."}, + ], + [49152, 27, 3144, 29, 5815, 1139, 373, 44569, 2424, 11886, 954, 15737, 14516, 6928, 3144, 2293, 789, 29, 13938, 438, 331, 25016, 457, 12409, 562, 35838, 789, 2293, 17822, 29, 2111, 25016, 457, 12409, 562, 438, 4235, 280, 6928, 17822, 2293, 789, 29, 13938, 5028, 759, 42226, 35838, 789, 2293, 17822, 29, 2111, 25016, 457, 759, 42226, 438, 29784, 3556, 6928, 17822, 2293, 789, 29, 1996, 4413, 3326, 35838, 789, 2293, 17822, 29, 2111, 25016, 457, 4413, 3326, 438, 613, 1361, 6928, 17822, 29, 49152], + list(range(27, 41)) + list(range(49, 63)) + list(range(70, 84)), ), ), ) -def test_apply_chat_template_with_spans(common_tokenizer, messages, expected_text, expected_spans): - """Chat template produces correct text and masking spans.""" +def 
test_tokenize_chat(common_tokenizer, messages, expected_tokens, expected_trainable_indices): common_tokenizer.tokenizer.chat_template = CHAT_TEMPLATE - text, spans = common_tokenizer.apply_chat_template_with_spans(messages) - Assert.eq(text, expected_text) - Assert.eq(spans, expected_spans) + tokens, train_mask = common_tokenizer.tokenize_chat(messages) + Assert.eq(tokens.tolist(), expected_tokens) + Assert.eq([i for i, m in enumerate(train_mask) if m], expected_trainable_indices) + + +@pytest.mark.parametrize( + ("train_mask", "expected_loss_spans"), + ( + # All masked (no trainable tokens) + ([False, False, False], [(0, 3)]), + # All trainable (no spans) + ([True, True, True], []), + # Single trainable at start + ([True, False, False], [(1, 3)]), + # Single trainable at end + ([False, False, True], [(0, 2)]), + # Single trainable in middle + ([False, True, False], [(0, 1), (2, 3)]), + # Multiple trainable regions (simulates multi-turn conversation) + ([False, False, True, True, False, False, True, True, True, False], [(0, 2), (4, 6), (9, 10)]), + # Alternating + ([False, True, False, True, False], [(0, 1), (2, 3), (4, 5)]), + ), +) +def test_mask_to_spans(train_mask, expected_loss_spans): + from fast_llm.data.preparator.gpt_memmap.prepare import _mask_to_spans + + Assert.eq(_mask_to_spans(train_mask), expected_loss_spans) From f61a6d1088d1124afce4e0cd05ef4396134d7b77 Mon Sep 17 00:00:00 2001 From: Torsten Scholak Date: Tue, 16 Dec 2025 21:53:16 +0000 Subject: [PATCH 09/25] Improve Apriel2 conversion config composition and documentation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Refactor config.py with clearer algebraic structure documentation - Document State (S), Partial Surgery (P), and Transition Spec (T) types - Clarify monoid structure and action laws for config composition - Update activation_distillation_factor from 0.1 to 0.8 in small example 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- .../apriel2/conversion/__init__.py | 159 +++++++----- .../apriel2/conversion/config.py | 233 ++++++++++++------ .../apriel2/conversion/converters.py | 73 ++++-- fast_llm_external_models/apriel2/convert.py | 17 +- .../examples/train_supernet_small.yaml | 2 +- 5 files changed, 323 insertions(+), 161 deletions(-) diff --git a/fast_llm_external_models/apriel2/conversion/__init__.py b/fast_llm_external_models/apriel2/conversion/__init__.py index 60fc0ef0a..c6bad6626 100644 --- a/fast_llm_external_models/apriel2/conversion/__init__.py +++ b/fast_llm_external_models/apriel2/conversion/__init__.py @@ -1,86 +1,122 @@ """Weight conversion system for Apriel2 models. -Architecture Overview -===================== +Overview +======== -This package implements a declarative weight transformation system with two -orthogonal concerns: +This package implements a declarative weight transformation system. The core +abstraction separates config composition (structural) from plan execution (weights). -1. **Config Composition** - Structural transformations of model configs -2. 
**Plan Building & Execution** - Weight transformations between configs +Conceptual Types +================ -These concerns are intentionally separated: -- Config composition determines WHAT the target architecture looks like -- Plan building determines HOW weights are transformed to match -- The `init` field bridges them: it's config metadata consumed by the plan builder +All configs are ``dict``, but we distinguish three conceptual types: -Key Design Decisions -==================== +**State (S)** - A complete model config without ``init`` fields. + What you load from disk or save after conversion. -**Declarative Plans** - Plans are DATA (JSON-serializable expressions), not functions. This enables: - - Inspection and debugging of transformations - - Serialization for distributed execution - - Composition via substitution rather than function composition - -**Separation of Config and Weights** - The `init` field in surgery specs controls weight handling (transfer vs random) - but does NOT affect config composition. Config composition is purely structural. - After composition, `init` fields are stripped from complete configs. - -**Composition Semantics** - Surgery specs use declarative (merge) composition, not operational (function) - composition. For "additive" surgeries (modifying existing structure), the - monoid action law holds. For "replacement" surgeries (defining complete new - structure), sequential application differs from composed application by design. - -**Cross-Type Derivation** - When converting between mixer types (e.g., attention → mamba), geometric - parameters are derived where possible: - - attention.heads → mamba dimensions (MIL conversion) - - attention.heads → gdn heads (DIL conversion) +**Partial Surgery (P)** - An incomplete config specifying changes. + May contain ``init`` fields (``transfer`` or ``random``). -Module Structure -================ +**Transition Spec (T)** - A complete config WITH ``init`` fields. + The result of applying surgery to a state. Describes both target + structure and weight initialization mode. + +Algebraic Structure +=================== + +**Monoid**: Partial surgeries compose via deep merge:: + + compose_configs : P × P → P + +**Action**: Surgeries act on states to produce transition specs:: + + compose_configs : S × P → T + compose_configs : T × P → T + +**Extraction**: Strip init to get a state:: + + strip_init_fields : T → S + +**Planning**: Build weight transformation from source state + transition spec:: + + plan_surgery : S × T → Plan + +The ``init`` Field +================== + +The ``init`` field in surgeries specifies weight initialization: -- `config.py` - Config composition (compose_configs, apply_surgery) -- `converters.py` - Plan builders (plan_surgery, plan_mil_attention_to_mamba, etc.) -- `expr.py` - Expression types and plan class (Ref, Slice, Concat, Init, ExprPlan) -- `executor.py` - Plan execution (StreamingExecutor, execute) -- `io.py` - Streaming I/O (SafetensorLoader, ShardedSafetensorWriter) -- `llava/` - Source-specific converter for Llava → Apriel2 +- ``init: transfer`` → transfer/convert weights from source +- ``init: random`` → randomly initialize weights -Example Usage +This field is preserved through ``compose_configs`` so ``plan_surgery`` can read it. +Use ``strip_init_fields`` before saving configs to disk. + +Typical Usage ============= +:: + from fast_llm_external_models.apriel2.conversion import ( compose_configs, plan_surgery, + strip_init_fields, execute, ) - # 1. 
Compose configs to get target architecture - target_config = compose_configs(source_config, surgery_spec) + # Load source state + source_state = load_config(...) # S - # 2. Build plan for weight transformation - plan = plan_surgery(source_config, target_config) + # Apply surgery + surgery = {"decoder": {"block": {"mixer": {"type": "gdn", "init": "random"}}}} # P + transition = compose_configs(source_state, surgery) # T - # 3. Execute plan to transform weights - target_weights = execute(plan, source_weights, seed=42) + # Build and execute plan + plan = plan_surgery(source_state, transition) + weights = execute(plan, source_weights, seed=42) -For streaming I/O with large models: + # Save (strip init first) + target_state = strip_init_fields(transition) # S + save_config(target_state) - from fast_llm_external_models.apriel2.conversion import ( - StreamingExecutor, - SafetensorLoader, - ShardedSafetensorWriter, - ) +For chained surgeries:: + + current_state = source_state # S + current_plan = identity_plan + + for surgery in surgery_chain: # each P + transition = compose_configs(current_state, surgery) # T + plan = plan_surgery(current_state, transition) + current_plan = compose(current_plan, plan) + current_state = strip_init_fields(transition) # S <- IMPORTANT! + +**Note**: The ``strip_init_fields`` call is critical. It ensures that ``init: random`` +applies only to the surgery that introduces a component. Without stripping, subsequent +surgeries would re-randomize existing components. See ``config.py`` docstring for details. + +Key Design Decisions +==================== + +**Declarative Plans** + Plans are data (expressions), not functions. Enables inspection, + serialization, and composition via substitution. + +**Inheritance Semantics** + When S × P → T, unspecified fields inherit from source. + Cross-type derivation maps geometry (attention.heads → gdn.value_heads). + +**Additive vs Replacement Surgeries** + Additive surgeries (no ``type:`` declaration) satisfy the action law. + Replacement surgeries (explicit ``type:``) use last-write-wins. 
+ +Module Structure +================ - with SafetensorLoader(source_files) as loader: - executor = StreamingExecutor(plan, loader) - with ShardedSafetensorWriter(output_dir) as writer: - for key, tensor in executor.execute(seed=42): - writer.add(key, tensor) +- ``config.py`` - Config composition (compose_configs, strip_init_fields) +- ``converters.py`` - Plan builders (plan_surgery, plan_mil_attention_to_mamba) +- ``expr.py`` - Expression types (Ref, Slice, Concat, Init, ExprPlan) +- ``executor.py`` - Plan execution (StreamingExecutor, execute) +- ``io.py`` - Streaming I/O (SafetensorLoader, ShardedSafetensorWriter) """ # Core types and plan operations @@ -127,7 +163,7 @@ ) # Config composition -from fast_llm_external_models.apriel2.conversion.config import compose_configs +from fast_llm_external_models.apriel2.conversion.config import compose_configs, strip_init_fields # Source-specific converters from fast_llm_external_models.apriel2.conversion.llava import ( @@ -175,6 +211,7 @@ "plan_kil_attention_to_kda", # Config composition "compose_configs", + "strip_init_fields", # Source-specific converters "convert_llava_config", "plan_llava_to_apriel2", diff --git a/fast_llm_external_models/apriel2/conversion/config.py b/fast_llm_external_models/apriel2/conversion/config.py index f5b19e208..3752688c1 100644 --- a/fast_llm_external_models/apriel2/conversion/config.py +++ b/fast_llm_external_models/apriel2/conversion/config.py @@ -1,59 +1,136 @@ """Config composition for Apriel2 architecture transformations. -This module handles STRUCTURAL composition of configs, independent of weight handling. -The `init` field in surgery specs is metadata for plan_surgery(), not for composition. +Conceptual Types +================ + +The system operates on three conceptual types, all represented as ``dict``: + +**State (S)** + A complete structural description of a model. Has ``hidden_size`` and ``decoder``. + Does NOT contain ``init`` fields. Represents WHAT a model looks like. + + Example: A saved config.json, or a model you're about to transform. + +**Partial Surgery (P)** + An incomplete config specifying fields to change. Missing ``hidden_size`` or + ``decoder``. May contain ``init`` fields specifying weight initialization mode. + + Example: ``{"decoder": {"block": {"mixer": {"type": "gdn", "init": "random"}}}}`` + +**Transition Spec (T)** + A complete config WITH ``init`` fields. Describes both the target structure + AND how to initialize weights. This is the output of applying a surgery to + a state - it's a complete specification of the transformation. + + Example: The result of ``compose_configs(state, surgery)`` before stripping. + +The distinction between S and T is semantic (presence of ``init``), not structural. +Both are "complete" in the sense of having ``hidden_size`` and ``decoder``. 
Algebraic Structure =================== -The system has a precise algebraic structure with two interacting components: +**Partial Surgeries form a Monoid (P, ∘, {})**:: + + compose_configs : P × P → P (deep merge, overlay wins) + + Identity: compose_configs(p, {}) = compose_configs({}, p) = p + Associativity: compose_configs(compose_configs(a, b), c) + = compose_configs(a, compose_configs(b, c)) + +**Surgeries act on States to produce Transition Specs**:: + + compose_configs : S × P → T (apply surgery with inheritance) + compose_configs : T × P → T (extend transition with more surgery) -**Surgery Specs (Monoid)** - Partial config dicts form a monoid under deep merge: - - Identity: {} (empty dict) - - Operation: compose_configs(partial1, partial2) = deep_merge(partial1, partial2) - - Associativity: (a ∘ b) ∘ c = a ∘ (b ∘ c) +**Action Law (for additive surgeries)**:: -**Complete Configs (Monoid Action)** - Surgery specs ACT on complete configs: - - Action: compose_configs(complete, partial) → complete - - For additive surgeries: (s · t₁) · t₂ = s · (t₁ ∘ t₂) - - For replacement surgeries: action law intentionally fails (last-write-wins) + compose_configs(compose_configs(s, p₁), p₂) = compose_configs(s, compose_configs(p₁, p₂)) -This separation is fundamental: surgery specs compose declaratively (what fields to -merge), while the action on configs interprets those fields with inheritance semantics. +This law holds when surgeries are "additive" (modifying existing structure without +declaring new types). For "replacement" surgeries (explicitly declaring ``type:``), +the action law intentionally fails - this is last-write-wins semantics. -Composition Cases -================= +**State Extraction**:: -compose_configs(base, overlay) dispatches based on completeness: + strip_init_fields : T → S (remove init metadata for saving) -1. **Complete + Partial** → Monoid action (inheritance, cross-type derivation) -2. **Partial + Partial** → Monoid operation (deep merge) -3. **Partial + Complete** → Overlay wins (complete replaces partial) -4. **Complete + Complete** → Deep merge, strip `init` fields +Operations Summary +================== -A config is "complete" if it has `hidden_size` and `decoder`. +``compose_configs(base, overlay)`` dispatches based on completeness: + +1. **S × P → T** : Apply surgery to state (inheritance, cross-type derivation) +2. **T × P → T** : Extend transition spec with more surgery +3. **P × P → P** : Merge partial surgeries (monoid operation) +4. **S × S → S** : Merge states (deep merge, rare) +5. **P × S → S** : Overlay wins (complete replaces partial) + +``strip_init_fields(config)`` removes all ``init`` fields, converting T → S. Inheritance Semantics ===================== -When the monoid action applies a surgery to a complete config: +When applying a surgery (S × P → T): -- Unspecified fields inherit from source -- New blocks inherit from the "default" block +- Unspecified fields inherit from source state +- New decoder blocks inherit from the "default" block - Cross-type derivation maps geometry (attention.heads → gdn.value_heads, etc.) 
-- Stochastic mixers: additive (no type decl) preserves source, replacement replaces +- Stochastic mixers: additive surgery preserves source mixers, replacement replaces -The `init` Field -================ +The ``init`` Field +================== + +The ``init`` field specifies weight initialization mode for ``plan_surgery()``: + +- ``init: transfer`` → transfer weights from source (possibly with conversion) +- ``init: random`` → randomly initialize weights + +**Key invariant**: ``init`` is preserved through composition so ``plan_surgery()`` +can read it. Use ``strip_init_fields()`` to obtain a pure state for: + +- Saving to disk (config.json should not contain ``init``) +- Starting the next surgery iteration (current_state should be S, not T) + +Typical Usage Pattern +===================== + +:: + + current_state: S = load_config(...) + + for surgery: P in surgery_chain: + transition: T = compose_configs(current_state, surgery) # S × P → T + plan = plan_surgery(current_state, transition) # plan reads init from T + current_state: S = strip_init_fields(transition) # T → S for next iteration -The `init` field is metadata for plan_surgery(), NOT for config composition: -- `init: transfer` → plan uses weight transfer/conversion -- `init: random` → plan uses random initialization + save_config(current_state) # S has no init fields -After composition produces a complete config, ALL `init` fields are stripped. -This ensures configs are purely structural and plan creation is Markovian. +Sequential vs Merged Surgery Application +======================================== + +**IMPORTANT**: Applying surgeries sequentially (with stripping) differs from merging +surgeries first then applying once. This affects ``init`` semantics: + +**Sequential** (recommended):: + + t1 = compose_configs(s, p1) # GDN gets init: random + s1 = strip_init_fields(t1) # GDN loses init + t2 = compose_configs(s1, p2) # GDN has init: None → transfer mode + +**Merged**:: + + merged = compose_configs(p1, p2) # GDN keeps init: random from p1 + t = compose_configs(s, merged) # GDN has init: random → random mode + +The sequential approach means ``init: random`` applies **only to the surgery that +introduces a component**. Subsequent surgeries transfer existing weights by default. + +This is the intended behavior: if surgery 1 adds GDN with random init, and surgery 2 +adds sliding window (not mentioning GDN), GDN keeps its weights from surgery 1. + +The merged approach would re-randomize GDN in every execution, which is rarely desired. +Always use the sequential pattern shown in "Typical Usage Pattern" above. """ from __future__ import annotations @@ -68,49 +145,42 @@ def is_complete(config: dict) -> bool: def compose_configs(base: dict, overlay: dict | None) -> dict: - """Compose two configs using monoid or monoid action semantics. + """Compose configs. Dispatches based on completeness of arguments. - This function implements two algebraic operations depending on argument types: + Type Signatures (see module docstring for S, P, T definitions):: - 1. **Monoid Action** (complete + partial): Apply surgery to a complete config. - Unspecified fields inherit from base; `init` fields are stripped from result. + S × P → T Apply surgery to state, get transition spec + T × P → T Extend transition spec with more surgery + P × P → P Merge partial surgeries (monoid operation) + S × S → S Merge states (deep merge) + P × S → S Overlay wins - 2. **Monoid Operation** (partial + partial): Merge two surgery specs. 
- Deep merge with overlay winning on conflicts; `init` fields preserved. + The ``init`` field is preserved in all cases. Use ``strip_init_fields()`` + to convert T → S for saving or iteration. Args: - base: Base config (complete) or surgery spec (partial). - overlay: Surgery spec to apply (partial) or config to merge. + base: State (S), transition spec (T), or partial surgery (P). + overlay: Partial surgery (P) or state (S). Returns: - - If base is complete: Complete config with surgery applied, `init` stripped. - - If both partial: Merged surgery spec with `init` preserved. + Composed config. Type depends on inputs (see signatures above). Algebraic Properties: - Surgery specs form a monoid: (a ∘ b) ∘ c = a ∘ (b ∘ c), identity = {} - - For additive surgeries, the action law holds: - compose(compose(s, t1), t2) == compose(s, compose(t1, t2)) - - For replacement surgeries (declaring type:), action law intentionally fails. + Monoid: ``compose(compose(p1, p2), p3) == compose(p1, compose(p2, p3))`` - Example: - # Apply surgery to complete config (monoid action) - source = {"hidden_size": 256, "decoder": {...}} # complete - surgery = {"decoder": {"block": {"mixer": {"type": "mamba"}}}} # partial + Action law (additive surgeries): + ``compose(compose(s, p1), p2) == compose(s, compose(p1, p2))`` - target = compose_configs(source, surgery) - # target is complete with inherited fields, init stripped + Example:: - # Merge two surgery specs (monoid operation) - s1 = {"decoder": {"block": {"mixer": {"mixers": {"a": {...}}}}}} - s2 = {"decoder": {"block": {"mixer": {"mixers": {"b": {...}}}}}} + # S × P → T (apply surgery to state) + state = {"hidden_size": 256, "decoder": {...}} + surgery = {"decoder": {"block": {"mixer": {"init": "random"}}}} + transition = compose_configs(state, surgery) # T, has init - merged = compose_configs(s1, s2) - # merged has both mixers a and b, init preserved - - # Use composed config with plan_surgery - plan = plan_surgery(source, target) + # Build plan, then extract state + plan = plan_surgery(state, transition) + new_state = strip_init_fields(transition) # S, no init """ if not overlay: return copy.deepcopy(base) @@ -132,9 +202,8 @@ def compose_configs(base: dict, overlay: dict | None) -> dict: if not base_complete and overlay_complete: return copy.deepcopy(overlay) - # Case 4: Both complete -> deep merge + # Case 4: Both complete -> deep merge (init preserved for plan_surgery) result = _deep_merge(base, overlay) - _strip_keys(result, {"init"}) return result @@ -166,6 +235,29 @@ def _strip_keys(config: Any, keys_to_strip: set[str]) -> None: _strip_keys(item, keys_to_strip) +def strip_init_fields(config: dict) -> dict: + """Return a copy of config with all ``init`` fields stripped (T → S). + + Converts a transition spec (T) to a state (S) by removing ``init`` metadata. + Use this: + + 1. Before saving configs to disk (config.json should be purely structural) + 2. Between surgery iterations (so subsequent surgeries don't re-randomize) + + See module docstring section "Sequential vs Merged Surgery Application" for + why stripping between iterations is critical. + + Args: + config: Config dict (not modified). Typically a transition spec (T). + + Returns: + A deep copy with all ``init`` fields recursively removed (a state S). 
+ """ + result = copy.deepcopy(config) + _strip_keys(result, {"init"}) + return result + + # ============================================================================= # Surgery application with full semantics # ============================================================================= @@ -182,14 +274,14 @@ def apply_surgery(source_config: dict, surgery_config: dict | None) -> dict: - Unspecified fields inherit from source - Cross-type derivation maps geometry (attention → gdn, etc.) - Stochastic sub-mixers inherit from source's main mixer - - All `init` fields stripped from result + - `init` fields are PRESERVED for plan_surgery() to see Args: source_config: Complete Apriel2 config (the state being acted on). surgery_config: Partial surgery spec (the monoid element acting). Returns: - Complete config with surgery applied, `init` fields stripped. + Complete config with surgery applied. `init` fields preserved. """ if not surgery_config: return copy.deepcopy(source_config) @@ -231,8 +323,9 @@ def apply_surgery(source_config: dict, surgery_config: dict | None) -> dict: surgery_config["vision_encoder"], ) - # Strip init keys from final result - _strip_keys(result, {"init"}) + # NOTE: We do NOT strip init keys here. The `init` field is preserved through + # composition so that plan_surgery() can see it and decide between transfer + # vs random initialization. The caller (convert.py) strips init before saving. return result diff --git a/fast_llm_external_models/apriel2/conversion/converters.py b/fast_llm_external_models/apriel2/conversion/converters.py index b54bb5a87..c8b83f657 100644 --- a/fast_llm_external_models/apriel2/conversion/converters.py +++ b/fast_llm_external_models/apriel2/conversion/converters.py @@ -831,44 +831,69 @@ def plan_surgery( source_config: dict, target_config: dict, ) -> ExprPlan: - """Build a weight conversion plan between two Apriel2 configurations. + """Build a weight conversion plan: S × T → Plan. - This function creates an ExprPlan that maps source weight keys to expressions - defining how to compute target weights. The plan handles same-type passthrough, - cross-type conversions (MIL, DIL, KIL), and stochastic mixer routing. + Creates an ExprPlan mapping target weight keys to expressions over source weights. + Handles same-type passthrough, cross-type conversions (MIL, DIL, KIL), and + stochastic mixer routing. + + Type Signature:: + + plan_surgery : S × T → Plan + + Where S is a state (source) and T is a transition spec (target with ``init`` fields). + + The ``init`` Field + ------------------ + + The ``init`` field in ``target_config`` controls weight initialization: + + - ``init: transfer`` (or absent) → create Ref expressions (transfer from source) + - ``init: random`` → create Init expressions (random initialization) + + This is why ``target_config`` should be a transition spec (T) from ``compose_configs``, + not a stripped state (S). If ``init`` fields are missing, all components default to + transfer mode. Args: - source_config: Complete Apriel2 config dict describing the source architecture. - Must have all structural fields (hidden_size, decoder, etc.) fully specified. - target_config: Complete Apriel2 config dict describing the target architecture. - Must be fully specified with all inherited fields resolved. Use - compose_configs(source_config, surgery_spec) to produce this from a - partial surgery specification. + source_config: State (S) - complete config describing source architecture. + Must have hidden_size, decoder, etc. 
No ``init`` fields expected. + target_config: Transition spec (T) - complete config with ``init`` fields. + Use ``compose_configs(source, surgery)`` to produce this. Returns: ExprPlan mapping target weight keys to expressions over source weights. - Example: + Example:: + # Apply a surgery that wraps attention in a stochastic mixer surgery_spec = { "decoder": {"block": {"mixer": { "type": "stochastic", - "mixers": {"attention": {"type": "attention", "init": "transfer"}} + "mixers": { + "attention": {"init": "transfer"}, + "gdn": {"type": "gdn", "init": "random"}, + } }}} } - # First compose to get complete target config with inherited fields - target_config = compose_configs(source_config, surgery_spec) + # S × P → T + transition = compose_configs(source_config, surgery_spec) + + # S × T → Plan + plan = plan_surgery(source_config, transition) + + # Execute + new_weights = execute(plan, source_weights, seed=42) - # Then build the plan from two complete configs - plan = plan_surgery(source_config, target_config) - new_weights = execute(plan, source_weights, seed=0) + # T → S for saving + target_state = strip_init_fields(transition) Note: - Both arguments must be complete configs. The target_config determines the - full target architecture including all inherited fields (bias settings, - rotary config, etc.). Passing a partial surgery spec directly will result - in missing inherited fields and incorrect plans. + Both arguments must be complete (have hidden_size and decoder). + The target_config should retain ``init`` fields from the surgery spec. + Passing a stripped state as target will cause all components to use + transfer mode, which may not be intended. """ hidden_size = target_config.get("hidden_size", source_config.get("hidden_size")) assert hidden_size is not None, "hidden_size must be specified in source or target config" @@ -922,8 +947,10 @@ def _plan_non_decoder_weights(config: dict) -> ExprPlan: embed = W("model", "embed_tokens", "weight") mappings[embed] = Ref(key=embed) - head = W("lm_head", "weight") - mappings[head] = Ref(key=head) + # lm_head only if not tied to embeddings + if not config.get("tie_word_embeddings", False): + head = W("lm_head", "weight") + mappings[head] = Ref(key=head) norm = W("model", "norm", "weight") mappings[norm] = Ref(key=norm) diff --git a/fast_llm_external_models/apriel2/convert.py b/fast_llm_external_models/apriel2/convert.py index 05c38c7ce..60786d22c 100644 --- a/fast_llm_external_models/apriel2/convert.py +++ b/fast_llm_external_models/apriel2/convert.py @@ -43,6 +43,7 @@ compose, compose_configs, plan_surgery, + strip_init_fields, ) # Import source-specific converters @@ -149,15 +150,19 @@ def build_plan( # Apply surgery chain if requested if surgery_configs: for i, surgery_config in enumerate(surgery_configs, 1): - surgery_plan = plan_surgery(current_config, surgery_config) + # S × P → T: compose state with surgery to get transition spec + target_config = compose_configs(current_config, surgery_config) + + # S × T → Plan: build plan from source state and transition spec + surgery_plan = plan_surgery(current_config, target_config) logger.info(f"Built surgery plan [{i}/{len(surgery_configs)}]: {surgery_plan.summary()['num_targets']} targets") - # Compose: current -> surgery + # Compose plans current_plan = compose(current_plan, surgery_plan) logger.info(f"Composed plan [{i}/{len(surgery_configs)}]: {current_plan.summary()['num_targets']} targets") - # Compose configs: merge surgery spec into current config - current_config = 
compose_configs(current_config, surgery_config) + # T → S: strip init for next iteration (init is consumed by plan_surgery) + current_config = strip_init_fields(target_config) return current_plan, current_config @@ -407,11 +412,11 @@ def main(): show_plan=args.show_plan or args.verbose, ) - # Save config + # Save config (build_plan returns S which has no init, but strip defensively) output_config_file = args.output_dir / "config.json" logger.info(f"Saving config to {output_config_file}") with open(output_config_file, "w") as f: - json.dump(apriel2_config, f, indent=2) + json.dump(strip_init_fields(apriel2_config), f, indent=2) # Copy tokenizer files copy_tokenizer_files(input_dir, args.output_dir) diff --git a/fast_llm_external_models/apriel2/examples/train_supernet_small.yaml b/fast_llm_external_models/apriel2/examples/train_supernet_small.yaml index 78c22e57f..be4d06e0a 100644 --- a/fast_llm_external_models/apriel2/examples/train_supernet_small.yaml +++ b/fast_llm_external_models/apriel2/examples/train_supernet_small.yaml @@ -107,7 +107,7 @@ model: lr_scale: 0.0 # Freeze layer norms (norm_1 and norm_2 in each block) # Activation-level distillation: teach mixers to mimic teacher's attention outputs distillation_model: teacher - activation_distillation_factor: 0.1 + activation_distillation_factor: 0.8 embeddings: lr_scale: 0.0 # Freeze word embeddings head: From 8933953d61c2efa499bdd94d55bfe43e2b613955 Mon Sep 17 00:00:00 2001 From: Torsten Scholak Date: Wed, 17 Dec 2025 17:58:41 +0000 Subject: [PATCH 10/25] Fix RangeSample.from_documents and loss mask distillation bugs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - range.py: Use append() instead of extend() for tuple pairs. The extend() call was flattening tuples into individual integers, causing "cannot unpack non-iterable numpy.int64" errors when iterating over ranges. - model.py: Fix attribute name from output_layer to head. The config uses 'head' for the language model head configuration. 
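Illustration of the flattening behaviour (a standalone Python sketch, not part
of the patch; the values are made up):

    ranges = []
    ranges.extend((3, 7))   # extend() iterates the tuple -> ranges == [3, 7]
    # "for begin, end in ranges" then tries to unpack bare ints and fails

    ranges = []
    ranges.append((3, 7))   # append() keeps the pair -> ranges == [(3, 7)]
    # "for begin, end in ranges" now yields begin == 3, end == 7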
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- fast_llm/data/sample/range.py | 2 +- fast_llm/models/gpt/model.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fast_llm/data/sample/range.py b/fast_llm/data/sample/range.py index 8dd351e1f..22d5e8992 100644 --- a/fast_llm/data/sample/range.py +++ b/fast_llm/data/sample/range.py @@ -38,7 +38,7 @@ def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: sample_size = 0 for document in documents: for begin, end in document.ranges: - ranges.extend((begin + sample_size, end + sample_size)) + ranges.append((begin + sample_size, end + sample_size)) sample_size += document.sample_size return cls(ranges, sample_size) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index a0c381439..fd8d2af1b 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -247,7 +247,7 @@ def preprocess_batch( for sample_index, loss_masking_spans in enumerate(loss_masking_spans.ranges): for begin, end in loss_masking_spans: loss_mask[sample_index, begin:end] = False - if self._config.output_layer.distillation_model is not None: + if self._config.head.distillation_model is not None: kwargs[LanguageModelKwargs.loss_mask] = loss_mask labels = torch.where(loss_mask, labels, -100) From 1277894a690f94f9183b8d1fc1338328a95b5b23 Mon Sep 17 00:00:00 2001 From: Torsten Scholak Date: Fri, 19 Dec 2025 15:41:19 +0000 Subject: [PATCH 11/25] Skip roundtrip integration tests on CPU-only CI MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Integration tests should run on realistic hardware. Roundtrip tests (Apriel2 -> Fast-LLM -> Apriel2) now skip when CUDA is unavailable. Changes: - Add CUDA check to roundtrip_converted fixture - Lazy-load roundtrip fixture in converted_model to avoid eager evaluation - Apriel2 and supernet tests still run on CPU (16 tests) - Roundtrip tests skip on CPU-only CI (8 tests) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- .../tests/test_apriel2/test_integration.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/fast_llm_external_models/tests/test_apriel2/test_integration.py b/fast_llm_external_models/tests/test_apriel2/test_integration.py index c11302d22..b90f0774e 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_integration.py +++ b/fast_llm_external_models/tests/test_apriel2/test_integration.py @@ -136,6 +136,9 @@ def supernet_converted(qwen2_source, apriel2_converted): @pytest.fixture(scope="module") def roundtrip_converted(supernet_converted, qwen2_source): """Stage 3: Supernet -> Fast-LLM -> Supernet.""" + if not torch.cuda.is_available(): + pytest.skip("Roundtrip conversion requires CUDA (integration tests need realistic hardware)") + from fast_llm.engine.checkpoint.config import ( CheckpointLoadConfig, CheckpointSaveConfig, @@ -181,18 +184,22 @@ def roundtrip_converted(supernet_converted, qwen2_source): @pytest.fixture(params=["apriel2", "supernet", "roundtrip"]) -def converted_model(request, apriel2_converted, supernet_converted, roundtrip_converted): +def converted_model(request, apriel2_converted, supernet_converted): """Parameterized fixture providing each conversion stage for testing. This allows a single test to run against all stages automatically. 
""" if request.param == "roundtrip": pytest.importorskip("fast_llm") + if not torch.cuda.is_available(): + pytest.skip("Roundtrip tests require CUDA (integration tests need realistic hardware)") + # Lazy-load to avoid fixture evaluation when CUDA unavailable + roundtrip_converted = request.getfixturevalue("roundtrip_converted") + return roundtrip_converted return { "apriel2": apriel2_converted, "supernet": supernet_converted, - "roundtrip": roundtrip_converted, }[request.param] From 9273966076de95ec1dc57154120c355a0b5cb88c Mon Sep 17 00:00:00 2001 From: Torsten Scholak Date: Sat, 20 Dec 2025 22:25:40 +0000 Subject: [PATCH 12/25] Refactor conversation format handling and tokenize_chat MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Split LanguageModelSourceConfig into abstract base + DocumentSourceConfig - Remove has_conversation property, use isinstance checks instead - Move _mask_to_spans to tokenizer module as _train_mask_to_loss_spans - tokenize_chat now returns (tokens, loss_masking_spans) directly - Safer BOS/EOS handling: check anywhere in tokens, not just first/last - Remove unused add_generation_prompt parameter from tokenize_chat 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- fast_llm/data/preparator/gpt_memmap/config.py | 138 ++++++------ .../data/preparator/gpt_memmap/prepare.py | 207 +++++++++--------- fast_llm/data/preprocessing/tokenizer.py | 51 ++++- .../apriel2/examples/prepare_tulu3.yaml | 2 +- tests/data/test_tokenizer.py | 29 ++- 5 files changed, 226 insertions(+), 201 deletions(-) diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 2aa0fbf31..a1aadf40a 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -18,30 +18,78 @@ @config_class(registry=True) class LanguageModelSourceConfig(Config): """ - A schema holding the name of each relevant column in the dataset. - Setting optional entries will enable the associated feature. + Abstract base class for data source schemas. - This is the base class for source schemas. Use `type: text` (default) for - plain text datasets, or `type: conversation` for chat/conversation datasets. + Use `type: document` (default) for documents with text, optional span annotations, and optional images. + Use `type: conversation` for structured chat/conversation datasets. + """ + + @classmethod + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: + if cls is LanguageModelSourceConfig and cls.get_subclass(default.get("type")) is None: + # Default to DocumentSourceConfig when type is not specified + return DocumentSourceConfig._from_dict(default, strict) + return super()._from_dict(default, strict=strict) + + @functools.cached_property + def columns(self) -> list[str]: + """Columns to read from the dataset.""" + raise NotImplementedError + + @functools.cached_property + def has_loss_masking_span(self) -> bool: + return False + + @functools.cached_property + def has_preference_spans(self) -> bool: + return False + + @functools.cached_property + def has_images(self) -> bool: + return False + + +@config_class(dynamic_type={LanguageModelSourceConfig: "document"}) +class DocumentSourceConfig(LanguageModelSourceConfig): + """ + Source schema for document datasets with text, optional span annotations, and optional images. + + The dataset should have a text column containing the document text. 
+ Optionally, it can have additional columns for: + - Loss masking spans: character ranges to mask from loss computation + - Preference spans: chosen/rejected text for DPO training + - Images: image data with character positions for multimodal training """ text: str = Field( default="text", - desc="Field of the dataset to use.", + desc="Field containing the document text.", hint=FieldHint.optional, ) - loss_masking_spans: None | str = Field( - default=None, desc="Field containing character spans to mask for loss computation", hint=FieldHint.optional + loss_masking_spans: str | None = Field( + default=None, + desc="Field containing character spans to mask for loss computation.", + hint=FieldHint.optional, ) - chosen_span: None | str = Field( - default=None, desc="Field containing chosen text for preference optimization", hint=FieldHint.optional + chosen_span: str | None = Field( + default=None, + desc="Field containing chosen text for preference optimization.", + hint=FieldHint.optional, ) - rejected_span: None | str = Field( - default=None, desc="Field containing rejected text for preference optimization", hint=FieldHint.optional + rejected_span: str | None = Field( + default=None, + desc="Field containing rejected text for preference optimization.", + hint=FieldHint.optional, + ) + images: str | None = Field( + default=None, + desc="Field containing images.", + hint=FieldHint.optional, ) - images: None | str = Field(default=None, desc="Field containing images", hint=FieldHint.optional) - image_positions: None | str = Field( - default=None, desc="Field containing image positions in the text.", hint=FieldHint.optional + image_positions: str | None = Field( + default=None, + desc="Field containing image positions in the text.", + hint=FieldHint.optional, ) @functools.cached_property @@ -69,28 +117,10 @@ def has_images(self) -> bool: Assert.eq(self.images is None, self.image_positions is None) return self.images is not None - @functools.cached_property - def has_conversation(self) -> bool: - """Whether this is a conversation source schema.""" - return False - def _validate(self): super()._validate() if self.has_preference_spans and self.has_loss_masking_span: - raise ValueError(f"Can not enable both loss masking and preference spans.") - - -@config_class(dynamic_type={LanguageModelSourceConfig: "text"}) -class TextSourceConfig(LanguageModelSourceConfig): - """ - Source schema for plain text datasets (default). - - The dataset should have a text column containing the document text. - Optionally, it can have additional columns for loss masking spans, - preference spans (for DPO), or images. - """ - - pass + raise ValueError("Cannot enable both loss masking and preference spans.") @config_class(dynamic_type={LanguageModelSourceConfig: "conversation"}) @@ -120,59 +150,21 @@ class ConversationSourceConfig(LanguageModelSourceConfig): } """ - # Override text field - not used directly for conversation format - text: None | str = Field( - default=None, - desc="Not used for conversation format. Text is generated from messages.", - hint=FieldHint.optional, - ) - - # Conversation-specific fields messages: str = Field( default="messages", desc="Field containing the conversation messages list. Each message should have 'role' and 'content' keys.", hint=FieldHint.core, ) - add_generation_prompt: bool = Field( - default=False, - desc="Whether to add a generation prompt at the end of the conversation. 
" - "Typically False for training data.", - hint=FieldHint.optional, - ) - @functools.cached_property def columns(self) -> list[str]: - # For conversation format, we read the messages column, not text - columns = [self.messages] - # Images can still be used with conversation format - if self.has_images: - columns.extend([self.images, self.image_positions]) - return columns - - @functools.cached_property - def has_conversation(self) -> bool: - return True + return [self.messages] @functools.cached_property def has_loss_masking_span(self) -> bool: - # Conversation format always generates loss masking spans + # Conversation format always generates loss masking spans from chat template markers return True - def _validate(self): - # Skip parent validation that checks text field - Config._validate(self) - if self.has_preference_spans: - raise ValueError("Preference spans are not supported with conversation format.") - if self.has_images: - # Images with conversation format would require computing image positions in the - # chat-template-formatted text, which is complex and format-dependent. - # For VLM training with conversations, preprocess the data to plain text format first. - raise ValueError( - "Images are not yet supported with conversation format. " - "For multimodal conversation data, preprocess to plain text format with image positions." - ) - @config_class() class GPTHuggingfaceDatasetConfig(Config): diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index a9beca42f..eeb925591 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -28,7 +28,12 @@ ) from fast_llm.data.dataset.memmap import MemmapDataset from fast_llm.data.preparator.config import DatasetPreparator -from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig, LanguageModelSourceConfig +from fast_llm.data.preparator.gpt_memmap.config import ( + ConversationSourceConfig, + GPTMemmapDatasetPreparatorConfig, + LanguageModelSourceConfig, + DocumentSourceConfig, +) from fast_llm.data.preprocessing.abstract import NullPreprocessingConfig from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.preprocessing.tokenizer import Tokenizer @@ -133,7 +138,7 @@ def run(self) -> None: self._tokenizer = self._config.tokenizer.get_tokenizer() # Validate chat template for conversation format - if self._source_schema.has_conversation: + if isinstance(self._source_schema, ConversationSourceConfig): self._tokenizer.validate_chat_template() # Decide the datatype based on the tokenizer vocabulary size @@ -220,108 +225,110 @@ def _preprocessing_config(self) -> LanguageModelPreprocessingConfig: ) def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample: - if self._source_schema.has_conversation: - tokens, train_mask = self._tokenizer.tokenize_chat( + token_spans_by_type = collections.defaultdict(list) + image_patches = image_token_maps = image_position_ids = patch_counts = None + + if isinstance(self._source_schema, ConversationSourceConfig): + # Conversation format: tokenize messages and get loss masking spans from chat template + tokens, loss_masking_spans = self._tokenizer.tokenize_chat( sample[self._source_schema.messages], - self._source_schema.add_generation_prompt, + True, + True, data_type=self._data_type, ) - return LanguageModelSample( - TokenSample(tokens, [len(tokens)]), - RangeSample(_mask_to_spans(train_mask), 
len(tokens)), - None, - None, - None, - ) + token_spans_by_type[SpanType.loss_masking] = loss_masking_spans + elif isinstance(self._source_schema, DocumentSourceConfig): + # Document format: use the text-spans pipeline + text = sample[self._source_schema.text] + all_spans = [] + + if self._source_schema.has_loss_masking_span: + # Spans are typically stored in the (begin, last) format. We convert to (begin, end) range format. + loss_masking_spans = _sort_spans( + (SpanType.loss_masking, (begin, last + 1)) + for begin, last in np.array(sample[self._source_schema.loss_masking_spans], dtype=np.int32) + .reshape(-1, 2) + .tolist() + ) + all_spans.extend(loss_masking_spans) - # Text format: use the text-spans pipeline - text = sample[self._source_schema.text] - all_spans = [] - - if self._source_schema.has_loss_masking_span: - # Spans are typically stored in the (begin, last) format. We convert to (begin, end) range format. - loss_masking_spans = _sort_spans( - (SpanType.loss_masking, (begin, last + 1)) - for begin, last in np.array(sample[self._source_schema.loss_masking_spans], dtype=np.int32) - .reshape(-1, 2) - .tolist() - ) - all_spans.extend(loss_masking_spans) - - if self._source_schema.has_preference_spans: - full_chosen_text = text + sample[self._source_schema.chosen_span] + self._tokenizer.tokenizer.eos_token - full_rejected_text = self._tokenizer.tokenizer.bos_token + text + sample[self._source_schema.rejected_span] - # compute chosen span - chosen_spans = [(SpanType.chosen, (len(text), len(full_chosen_text)))] - - # compute rejected span - rejected_span = [ - ( - SpanType.rejected, - ( - len(full_chosen_text) + len(self._tokenizer.tokenizer.bos_token) + len(text), - len(full_chosen_text) + len(full_rejected_text), - ), + if self._source_schema.has_preference_spans: + full_chosen_text = text + sample[self._source_schema.chosen_span] + self._tokenizer.tokenizer.eos_token + full_rejected_text = ( + self._tokenizer.tokenizer.bos_token + text + sample[self._source_schema.rejected_span] ) - ] - # pack texts - text = full_chosen_text + full_rejected_text - all_spans.extend(chosen_spans + rejected_span) - - if self._source_schema.has_images: - # Get the images and positions, sorted by position. - images, image_positions = ( - zip( - *sorted( - zip( - sample[self._source_schema.images], - sample[self._source_schema.image_positions], - strict=True, + # compute chosen span + chosen_spans = [(SpanType.chosen, (len(text), len(full_chosen_text)))] + + # compute rejected span + rejected_span = [ + ( + SpanType.rejected, + ( + len(full_chosen_text) + len(self._tokenizer.tokenizer.bos_token) + len(text), + len(full_chosen_text) + len(full_rejected_text), ), - key=lambda x: x[1], ) + ] + # pack texts + text = full_chosen_text + full_rejected_text + all_spans.extend(chosen_spans + rejected_span) + + if self._source_schema.has_images: + # Get the images and positions, sorted by position. + images, image_positions = ( + zip( + *sorted( + zip( + sample[self._source_schema.images], + sample[self._source_schema.image_positions], + strict=True, + ), + key=lambda x: x[1], + ) + ) + if len(sample[self._source_schema.images]) > 0 + else ([], []) ) - if len(sample[self._source_schema.images]) > 0 - else ([], []) - ) - # Get the image patches and associated data. - image_patches, image_position_ids, image_token_maps, image_token_ids, patch_counts = ( - self._config.image_patches.get_patches_from_images(images, self._data_type) + # Get the image patches and associated data. 
+ image_patches, image_position_ids, image_token_maps, image_token_ids, patch_counts = ( + self._config.image_patches.get_patches_from_images(images, self._data_type) + ) + patch_count_cumsum = padded_cumsum(patch_counts).tolist() + # Add an empty "span" at each image position so we know where to insert them in the tokenized sequence. + all_spans.extend([(SpanType.image, (position, position)) for position in image_positions]) + + # Sort the spans by location (begin), keeping track of their type. + # Note: overlapping spans are not supported (explicit assertion in the tokenizer). + span_types, spans = zip(*_sort_spans(all_spans)) if all_spans else ([], []) + # Tokenize the text, and determine the span locations in the tokenized text. + tokens, token_spans = self._tokenizer.tokenize_with_spans( + text, True, True, text_spans=spans, data_type=self._data_type ) - patch_count_cumsum = padded_cumsum(patch_counts).tolist() - # Add an empty "span" at each image position so we know where to insert them in the tokenized sequence. - all_spans.extend([(SpanType.image, (position, position)) for position in image_positions]) - - # Sort the spans by location (begin), keeping track of their type. - # Note: overlapping spans are not supported (explicit assertion in the tokenizer). - span_types, spans = zip(*_sort_spans(all_spans)) if all_spans else ([], []) - # Tokenize the text, and determine the span locations in the tokenized text. - tokens, token_spans = self._tokenizer.tokenize_with_spans( - text, True, True, text_spans=spans, data_type=self._data_type - ) - # Gather token spans by type. - token_spans_by_type = collections.defaultdict(list) - if self._source_schema.has_images: - # Insert the image token ids in the token sequence and shift the spans accordingly. - tokens_shift = 0 - image_index = 0 - for span_type, (begin, end) in zip(span_types, token_spans, strict=True): - # Account for the tokens already inserted. - begin = begin + tokens_shift - end = end + tokens_shift - if span_type == SpanType.image: - # Shift the token map to the image location. - image_token_maps[patch_count_cumsum[image_index] : patch_count_cumsum[image_index + 1]] += begin - # Insert the placeholder and image break tokens. - tokens = torch.cat([tokens[:begin], image_token_ids[image_index], tokens[begin:]]) - tokens_shift += len(image_token_ids[image_index]) - image_index += 1 - else: - token_spans_by_type[span_type].append((begin, end)) + # Gather token spans by type. + if self._source_schema.has_images: + # Insert the image token ids in the token sequence and shift the spans accordingly. + tokens_shift = 0 + image_index = 0 + for span_type, (begin, end) in zip(span_types, token_spans, strict=True): + # Account for the tokens already inserted. + begin = begin + tokens_shift + end = end + tokens_shift + if span_type == SpanType.image: + # Shift the token map to the image location. + image_token_maps[patch_count_cumsum[image_index] : patch_count_cumsum[image_index + 1]] += begin + # Insert the placeholder and image break tokens. 
+ tokens = torch.cat([tokens[:begin], image_token_ids[image_index], tokens[begin:]]) + tokens_shift += len(image_token_ids[image_index]) + image_index += 1 + else: + token_spans_by_type[span_type].append((begin, end)) + else: + for span_type, token_span in zip(span_types, token_spans, strict=True): + token_spans_by_type[span_type].append(token_span) else: - for span_type, token_span in zip(span_types, token_spans, strict=True): - token_spans_by_type[span_type].append(token_span) + raise NotImplementedError(f"Unsupported source schema type: {type(self._source_schema)}") sample_size = len(tokens) @@ -501,17 +508,3 @@ def _get_nearest_split(cumsum: np.ndarray, value: float) -> int: return left.item() + 1 if (value - cumsum[left]) / (cumsum[left + 1] - cumsum[left]) > 0.5 else left.item() -def _mask_to_spans(mask: list[bool]) -> list[tuple[int, int]]: - """Convert a boolean train mask to loss masking spans (where mask[i] == False).""" - spans = [] - start = None - for i, value in enumerate(mask): - if not value: - if start is None: - start = i - elif start is not None: - spans.append((start, i)) - start = None - if start is not None: - spans.append((start, len(mask))) - return spans diff --git a/fast_llm/data/preprocessing/tokenizer.py b/fast_llm/data/preprocessing/tokenizer.py index f3b5a51a8..2d27c3853 100644 --- a/fast_llm/data/preprocessing/tokenizer.py +++ b/fast_llm/data/preprocessing/tokenizer.py @@ -245,12 +245,17 @@ def validate_chat_template(self) -> None: def tokenize_chat( self, messages: list[dict[str, str]], - add_generation_prompt: bool = False, begin: bool = True, end: bool = True, data_type: DataType = DataType.int64, - ) -> tuple["torch.Tensor", list[bool]]: - """Apply chat template and return (tokens, train_mask) where train_mask[i]=True means train on token i.""" + ) -> tuple["torch.Tensor", list[tuple[int, int]]]: + """ + Apply chat template and return (tokens, loss_masking_spans). + + The loss_masking_spans mark token ranges to EXCLUDE from training (where the model + should not learn). These are derived from the chat template's generation markers - + tokens outside {% generation %}...{% endgeneration %} blocks are masked. + """ import torch result = self.tokenizer.apply_chat_template( @@ -258,17 +263,22 @@ def tokenize_chat( tokenize=True, return_assistant_tokens_mask=True, return_dict=True, - add_generation_prompt=add_generation_prompt, + add_generation_prompt=False, ) tokens = result["input_ids"] train_mask = result["assistant_masks"] - # Prepend BOS / append EOS if needed (avoid O(n) insert) - prepend_bos = begin and (not tokens or tokens[0] != self.bod_id) - append_eos = end and (not tokens or tokens[-1] != self.eod_id) + # Prepend BOS / append EOS if not already present anywhere in the sequence. + # We check anywhere (not just first/last) because some chat templates add trailing + # whitespace after the final EOS token, e.g. "<|im_end|>\n". 
+ prepend_bos = begin and self.bod_id not in tokens + append_eos = end and self.eod_id not in tokens tokens = [self.bod_id] * prepend_bos + list(tokens) + [self.eod_id] * append_eos train_mask = [False] * prepend_bos + [bool(m) for m in train_mask] + [False] * append_eos + # Convert boolean train mask to loss masking spans (spans where train_mask[i] == False) + loss_masking_spans = _train_mask_to_loss_spans(train_mask) + if self._config.max_vocab_size is not None: tokens = ( torch.tensor( @@ -279,5 +289,30 @@ def tokenize_chat( ).to(data_type.torch) else: tokens = torch.tensor(tokens, dtype=data_type.torch) - return tokens, train_mask + return tokens, loss_masking_spans + + +def _train_mask_to_loss_spans(train_mask: list[bool]) -> list[tuple[int, int]]: + """ + Convert a boolean train mask to loss masking spans. + + Args: + train_mask: Boolean list where True = train on this token, False = don't train + + Returns: + List of (begin, end) spans marking token ranges to EXCLUDE from training + (i.e., where train_mask[i] == False). + """ + spans = [] + start = None + for i, should_train in enumerate(train_mask): + if not should_train: + if start is None: + start = i + elif start is not None: + spans.append((start, i)) + start = None + if start is not None: + spans.append((start, len(train_mask))) + return spans diff --git a/fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml b/fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml index ba85c1aed..34672916c 100644 --- a/fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml +++ b/fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml @@ -81,7 +81,7 @@ dataset: # Source schema for conversation format source_schema: - # Use conversation type (vs default "text" type) + # Use conversation type (vs default "document" type) type: conversation # Column containing the messages list diff --git a/tests/data/test_tokenizer.py b/tests/data/test_tokenizer.py index 97f16c6d6..f8f07ef0f 100644 --- a/tests/data/test_tokenizer.py +++ b/tests/data/test_tokenizer.py @@ -76,15 +76,17 @@ def test_validate_chat_template_with_markers(common_tokenizer): @pytest.mark.parametrize( - ("messages", "expected_tokens", "expected_trainable_indices"), + ("messages", "expected_tokens", "expected_loss_masking_spans"), ( # Single turn: full assistant turn (Hello) is trainable + # 15 tokens, trainable indices 7-13, loss mask spans cover 0-6 and 14 ( [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello"}], [49152, 27, 789, 29, 16946, 750, 789, 2293, 17822, 29, 7371, 750, 17822, 29, 49152], - [7, 8, 9, 10, 11, 12, 13], + [(0, 7), (14, 15)], ), # Multi-turn: both assistant turns are fully trainable + # 27 tokens, trainable indices 7-13 and 19-25 ( [ {"role": "user", "content": "A"}, @@ -93,9 +95,10 @@ def test_validate_chat_template_with_markers(common_tokenizer): {"role": "assistant", "content": "D"}, ], [49152, 27, 789, 29, 32, 750, 789, 2293, 17822, 29, 33, 750, 17822, 2293, 789, 29, 34, 750, 789, 2293, 17822, 29, 35, 750, 17822, 29, 49152], - [7, 8, 9, 10, 11, 12, 13, 19, 20, 21, 22, 23, 24, 25], + [(0, 7), (14, 19), (26, 27)], ), # System + user + assistant: full assistant turn trainable + # 23 tokens, trainable indices 15-21 ( [ {"role": "system", "content": "You are helpful."}, @@ -103,15 +106,17 @@ def test_validate_chat_template_with_markers(common_tokenizer): {"role": "assistant", "content": "Hello"}, ], [49152, 27, 3144, 29, 5815, 1139, 44569, 6928, 3144, 2293, 789, 29, 16946, 750, 789, 2293, 17822, 29, 7371, 750, 17822, 
29, 49152], - [15, 16, 17, 18, 19, 20, 21], + [(0, 15), (22, 23)], ), # User only: no trainable tokens + # 9 tokens, no trainable indices ( [{"role": "user", "content": "Hi"}], [49152, 27, 789, 29, 16946, 750, 789, 29, 49152], - [], + [(0, 9)], ), # Long multi-turn (85 tokens, 3 assistant responses with tags, tests span machinery) + # Trainable: indices 27-40, 49-62, 70-83 ( [ {"role": "system", "content": "You are a helpful assistant that answers questions."}, @@ -123,15 +128,15 @@ def test_validate_chat_template_with_markers(common_tokenizer): {"role": "assistant", "content": "The capital of Italy is Rome."}, ], [49152, 27, 3144, 29, 5815, 1139, 373, 44569, 2424, 11886, 954, 15737, 14516, 6928, 3144, 2293, 789, 29, 13938, 438, 331, 25016, 457, 12409, 562, 35838, 789, 2293, 17822, 29, 2111, 25016, 457, 12409, 562, 438, 4235, 280, 6928, 17822, 2293, 789, 29, 13938, 5028, 759, 42226, 35838, 789, 2293, 17822, 29, 2111, 25016, 457, 759, 42226, 438, 29784, 3556, 6928, 17822, 2293, 789, 29, 1996, 4413, 3326, 35838, 789, 2293, 17822, 29, 2111, 25016, 457, 4413, 3326, 438, 613, 1361, 6928, 17822, 29, 49152], - list(range(27, 41)) + list(range(49, 63)) + list(range(70, 84)), + [(0, 27), (41, 49), (63, 70), (84, 85)], ), ), ) -def test_tokenize_chat(common_tokenizer, messages, expected_tokens, expected_trainable_indices): +def test_tokenize_chat(common_tokenizer, messages, expected_tokens, expected_loss_masking_spans): common_tokenizer.tokenizer.chat_template = CHAT_TEMPLATE - tokens, train_mask = common_tokenizer.tokenize_chat(messages) + tokens, loss_masking_spans = common_tokenizer.tokenize_chat(messages) Assert.eq(tokens.tolist(), expected_tokens) - Assert.eq([i for i, m in enumerate(train_mask) if m], expected_trainable_indices) + Assert.eq(loss_masking_spans, expected_loss_masking_spans) @pytest.mark.parametrize( @@ -153,7 +158,7 @@ def test_tokenize_chat(common_tokenizer, messages, expected_tokens, expected_tra ([False, True, False, True, False], [(0, 1), (2, 3), (4, 5)]), ), ) -def test_mask_to_spans(train_mask, expected_loss_spans): - from fast_llm.data.preparator.gpt_memmap.prepare import _mask_to_spans +def test_train_mask_to_loss_spans(train_mask, expected_loss_spans): + from fast_llm.data.preprocessing.tokenizer import _train_mask_to_loss_spans - Assert.eq(_mask_to_spans(train_mask), expected_loss_spans) + Assert.eq(_train_mask_to_loss_spans(train_mask), expected_loss_spans) From e5172d59b2e9e9c05bf20b64ee814177907484e2 Mon Sep 17 00:00:00 2001 From: Torsten Scholak Date: Mon, 5 Jan 2026 16:55:11 +0000 Subject: [PATCH 13/25] Fix GDN mixer dtype mismatches in Apriel2 model MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix _recurrent_gated_delta_rule tensor shape: transpose from [batch, seq, heads, dim] to [batch, heads, seq, dim] for einsum ops - Fix dtype after g.exp() which returns float32 even with bfloat16 input - Ensure recurrent_state dtype matches hidden_states before/after FLA kernel - Ensure last_recurrent_state converted to initial_dtype when returned These fixes resolve dtype mismatch errors during inference with mixed precision (bfloat16) when using the GDN mixer. 
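Shape sketch for the recurrent einsum update (an illustrative, standalone
snippet; dimensions are invented and it simplifies the real mixer code). It
shows why the [batch, seq, heads, dim] projections need a transpose before the
einsum; the dtype casts in the patch additionally keep the cached state in the
same dtype as these operands:

    import torch

    q = torch.randn(2, 1, 4, 16, dtype=torch.bfloat16)       # [batch, seq=1, heads, key_dim]
    state = torch.randn(2, 4, 16, 32, dtype=torch.bfloat16)  # [batch, heads, key_dim, value_dim]

    # einsum("bhk,bhkv->bhv") expects q as [batch, heads, key_dim], so the seq
    # and heads axes are swapped and the seq=1 axis dropped first:
    q = q.transpose(1, 2).squeeze(2)                  # -> [batch, heads, key_dim]
    out = torch.einsum("bhk,bhkv->bhv", q, state)     # -> [batch, heads, value_dim]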
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- .../apriel2/modeling_apriel2.py | 29 +++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/fast_llm_external_models/apriel2/modeling_apriel2.py b/fast_llm_external_models/apriel2/modeling_apriel2.py index 878677653..930b158f3 100644 --- a/fast_llm_external_models/apriel2/modeling_apriel2.py +++ b/fast_llm_external_models/apriel2/modeling_apriel2.py @@ -1033,6 +1033,8 @@ def torch_chunk_gated_delta_rule( if not output_final_state: last_recurrent_state = None + elif last_recurrent_state is not None: + last_recurrent_state = last_recurrent_state.to(initial_dtype) core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1]) core_attn_out = core_attn_out[:, :, :sequence_length] core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) @@ -1286,8 +1288,14 @@ def forward(self, hidden_states: torch.Tensor, past_key_values=None, attention_m output_final_state=past_key_values is not None, use_qk_l2norm_in_kernel=True, ) + # Ensure state is in same dtype as hidden_states (fla kernel may return float32) + if last_recurrent_state is not None: + last_recurrent_state = last_recurrent_state.to(hidden_states.dtype) else: # Recurrent mode for single token decode + # Convert recurrent_state to match hidden_states dtype if needed + if recurrent_state is not None and recurrent_state.dtype != hidden_states.dtype: + recurrent_state = recurrent_state.to(hidden_states.dtype) output, last_recurrent_state = self._recurrent_gated_delta_rule( query, key, value, g, beta_gate, recurrent_state ) @@ -1310,7 +1318,16 @@ def forward(self, hidden_states: torch.Tensor, past_key_values=None, attention_m return (output,) def _recurrent_gated_delta_rule(self, query, key, value, g, beta, state): - """Single-step recurrent update for cached inference.""" + """Single-step recurrent update for cached inference. 
+ + Input shapes: [batch, seq=1, heads, dim] + Need shapes: [batch, heads, dim] for einsum operations + """ + # Transpose from [batch, seq, heads, dim] to [batch, heads, seq, dim] + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + # L2 normalize query and key query = _l2norm(query, dim=-1, eps=1e-6) key = _l2norm(key, dim=-1, eps=1e-6) @@ -1323,7 +1340,9 @@ def _recurrent_gated_delta_rule(self, query, key, value, g, beta, state): beta = beta.squeeze(1) # Update state: S = exp(g) * S + beta * k^T @ v - decay = g.exp().unsqueeze(-1).unsqueeze(-1) # [batch, heads, 1, 1] + # Keep everything in the same dtype as input (exp() returns float32, need to convert back) + input_dtype = query.dtype + decay = g.exp().to(input_dtype).unsqueeze(-1).unsqueeze(-1) # [batch, heads, 1, 1] k_outer_v = torch.einsum("bhk,bhv->bhkv", key * beta.unsqueeze(-1), value) state = decay * state + k_outer_v @@ -1331,6 +1350,12 @@ def _recurrent_gated_delta_rule(self, query, key, value, g, beta, state): output = torch.einsum("bhk,bhkv->bhv", query, state) output = output.unsqueeze(2) # [batch, heads, 1, v_dim] + # Transpose back to [batch, seq=1, heads, v_dim] + output = output.transpose(1, 2) + + # Ensure state matches output dtype + state = state.to(output.dtype) + return output, state @classmethod From ef990a5fb73591d4b6bda8f047397f865694be06 Mon Sep 17 00:00:00 2001 From: Torsten Scholak Date: Mon, 5 Jan 2026 17:02:10 +0000 Subject: [PATCH 14/25] Run code formatters (black, isort, autoflake, pyupgrade) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Apply automatic formatting fixes: - black: code style formatting - isort: import sorting - autoflake: remove unused imports - pyupgrade: modernize Python syntax 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- .github/ISSUE_TEMPLATE/feature_request.md | 20 +- .../data/preparator/gpt_memmap/prepare.py | 8 +- fast_llm/data/preprocessing/tokenizer.py | 1 - fast_llm/models/gpt/conversion/apriel2.py | 68 ++-- .../models/multimodal/conversion/apriel2.py | 9 +- fast_llm_external_models/apriel2/cache.py | 1 + .../apriel2/conversion/__init__.py | 38 +- .../apriel2/conversion/converters.py | 380 +++++++++--------- .../apriel2/conversion/executor.py | 11 +- .../apriel2/conversion/expr.py | 22 +- .../apriel2/conversion/io.py | 7 +- .../apriel2/conversion/llava/plan.py | 7 +- .../apriel2/conversion/qwen2/plan.py | 15 +- .../apriel2/conversion/render.py | 28 +- fast_llm_external_models/apriel2/convert.py | 31 +- .../apriel2/modeling_apriel2.py | 8 +- .../tests/test_apriel2/conftest.py | 32 +- .../test_cache_apriel2_specific.py | 3 +- .../test_apriel2/test_cache_contracts.py | 25 +- .../tests/test_apriel2/test_causal_conv1d.py | 23 +- .../test_apriel2/test_compose_configs.py | 10 +- .../tests/test_apriel2/test_conversion_e2e.py | 131 ++---- .../test_apriel2/test_convert_from_llava.py | 18 +- .../tests/test_apriel2/test_equivalence.py | 15 +- .../tests/test_apriel2/test_expr_plan.py | 291 ++++++++------ .../tests/test_apriel2/test_integration.py | 28 +- .../test_apriel2/test_mixer_equivalence.py | 108 +++-- .../test_apriel2/test_model_structure.py | 69 ++-- .../tests/test_apriel2/test_modeling.py | 59 ++- .../tests/test_apriel2/test_plan_execution.py | 69 ++-- setup.py | 6 +- tests/data/test_tokenizer.py | 144 ++++++- 32 files changed, 898 insertions(+), 787 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/feature_request.md 
b/.github/ISSUE_TEMPLATE/feature_request.md index 50c5a2c1c..a09f78c6c 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.md +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -8,26 +8,26 @@ assignees: '' --- # 🎯 **Goal (What & Why)** -> **Clearly state the purpose of this feature.** +> **Clearly state the purpose of this feature.** > _(Example: Add FP8 support using torchao to improve training throughput by 1.5x.)_ # 🚀 **Execution Plan** -> _(This section may start as an incomplete draft but must be defined before implementation begins.)_ +> _(This section may start as an incomplete draft but must be defined before implementation begins.)_ ### **Step 1: What is the smallest working version?** -> _(Describe the simplest way to implement this feature with minimal effort.)_ +> _(Describe the simplest way to implement this feature with minimal effort.)_ -### **Step 2: What additional optimizations are possible (but optional)?** -> _(List potential refinements that can be added in later PRs if needed.)_ +### **Step 2: What additional optimizations are possible (but optional)?** +> _(List potential refinements that can be added in later PRs if needed.)_ # 📌 **Acceptance Criteria** (Must-Haves for Completion) -* The feature must be **functional and tested**. -* The implementation must be **documented in practical terms**. -* The PR must include a **performance/impact summary**. -* **No refactors unless directly necessary** for feature completion. +* The feature must be **functional and tested**. +* The implementation must be **documented in practical terms**. +* The PR must include a **performance/impact summary**. +* **No refactors unless directly necessary** for feature completion. # 🛠️ **Project Management** - [ ] **Assign the project to the Fast-LLM project.** - [ ] **Set the `Estimate` field (in days) in the GitHub project.** - [ ] **Use the `Size` field to categorize the PR size (Small/Medium/Large).** -- [ ] **Assign an owner when opening the issue.** +- [ ] **Assign an owner when opening the issue.** diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 285e36d22..325d33c43 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -30,9 +30,9 @@ from fast_llm.data.preparator.config import DatasetPreparator from fast_llm.data.preparator.gpt_memmap.config import ( ConversationSourceConfig, + DocumentSourceConfig, GPTMemmapDatasetPreparatorConfig, LanguageModelSourceConfig, - DocumentSourceConfig, ) from fast_llm.data.preprocessing.abstract import NullPreprocessingConfig from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig @@ -317,7 +317,9 @@ def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample: end = end + tokens_shift if span_type == SpanType.image: # Shift the token map to the image location. - image_token_maps[patch_count_cumsum[image_index] : patch_count_cumsum[image_index + 1]] += begin + image_token_maps[ + patch_count_cumsum[image_index] : patch_count_cumsum[image_index + 1] + ] += begin # Insert the placeholder and image break tokens. 
tokens = torch.cat([tokens[:begin], image_token_ids[image_index], tokens[begin:]]) tokens_shift += len(image_token_ids[image_index]) @@ -509,5 +511,3 @@ def _get_nearest_split(cumsum: np.ndarray, value: float) -> int: if left == len(cumsum): return left.item() return left.item() + 1 if (value - cumsum[left]) / (cumsum[left + 1] - cumsum[left]) > 0.5 else left.item() - - diff --git a/fast_llm/data/preprocessing/tokenizer.py b/fast_llm/data/preprocessing/tokenizer.py index 2d27c3853..157744f51 100644 --- a/fast_llm/data/preprocessing/tokenizer.py +++ b/fast_llm/data/preprocessing/tokenizer.py @@ -315,4 +315,3 @@ def _train_mask_to_loss_spans(train_mask: list[bool]) -> list[tuple[int, int]]: if start is not None: spans.append((start, len(train_mask))) return spans - diff --git a/fast_llm/models/gpt/conversion/apriel2.py b/fast_llm/models/gpt/conversion/apriel2.py index dc2d4b4ad..91e3be508 100644 --- a/fast_llm/models/gpt/conversion/apriel2.py +++ b/fast_llm/models/gpt/conversion/apriel2.py @@ -686,41 +686,45 @@ def get_mlp_layer_bias(layer_config, default: bool) -> bool: if config.mlp.gated: # Gated MLP: gate_proj + up_proj -> layer_1 (split), down_proj -> layer_2 - converters.extend([ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_1", - (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), - layer_1_bias, - SplitWeightConverter, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_2", - f"{hf_prefix}.mlp.down_proj", - layer_2_bias, - MLPLayer2Converter, - drop_on_export=drop_on_export, - ), - ]) + converters.extend( + [ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_1", + (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), + layer_1_bias, + SplitWeightConverter, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_2", + f"{hf_prefix}.mlp.down_proj", + layer_2_bias, + MLPLayer2Converter, + drop_on_export=drop_on_export, + ), + ] + ) else: # Non-gated MLP: up_proj -> layer_1, down_proj -> layer_2 # Note: layer_2 still needs MLPLayer2Converter for the transpose - converters.extend([ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_1", - f"{hf_prefix}.mlp.up_proj", - layer_1_bias, - WeightConverter, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_2", - f"{hf_prefix}.mlp.down_proj", - layer_2_bias, - MLPLayer2Converter, - drop_on_export=drop_on_export, - ), - ]) + converters.extend( + [ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_1", + f"{hf_prefix}.mlp.up_proj", + layer_1_bias, + WeightConverter, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_2", + f"{hf_prefix}.mlp.down_proj", + layer_2_bias, + MLPLayer2Converter, + drop_on_export=drop_on_export, + ), + ] + ) converters.extend( [ diff --git a/fast_llm/models/multimodal/conversion/apriel2.py b/fast_llm/models/multimodal/conversion/apriel2.py index b4147a8bf..307a67c63 100644 --- a/fast_llm/models/multimodal/conversion/apriel2.py +++ b/fast_llm/models/multimodal/conversion/apriel2.py @@ -326,9 +326,7 @@ class Apriel2MultimodalBaseModelConverter: @classmethod def import_config(cls, config: dict) -> dict: text_config = Apriel2BaseModelConverter.import_config(config) - vision_config = ( - cls.vision_model_converter_class.import_config(config) if "vision_encoder" in config else None - ) + vision_config = 
cls.vision_model_converter_class.import_config(config) if "vision_encoder" in config else None result = safe_merge_dicts( text_config, @@ -388,10 +386,7 @@ def get_transformers_configuration_class(cls): @classmethod def get_model_files(cls) -> tuple[str, str, str | None]: - from fast_llm_external_models.apriel2 import ( - configuration_apriel2, - modeling_apriel2, - ) + from fast_llm_external_models.apriel2 import configuration_apriel2, modeling_apriel2 return configuration_apriel2.__file__, modeling_apriel2.__file__, None diff --git a/fast_llm_external_models/apriel2/cache.py b/fast_llm_external_models/apriel2/cache.py index 32db547b9..f83ae87d6 100644 --- a/fast_llm_external_models/apriel2/cache.py +++ b/fast_llm_external_models/apriel2/cache.py @@ -1,4 +1,5 @@ from __future__ import annotations + import torch from transformers.cache_utils import Cache diff --git a/fast_llm_external_models/apriel2/conversion/__init__.py b/fast_llm_external_models/apriel2/conversion/__init__.py index c6bad6626..2c28d1e87 100644 --- a/fast_llm_external_models/apriel2/conversion/__init__.py +++ b/fast_llm_external_models/apriel2/conversion/__init__.py @@ -119,6 +119,20 @@ - ``io.py`` - Streaming I/O (SafetensorLoader, ShardedSafetensorWriter) """ +# Config composition +from fast_llm_external_models.apriel2.conversion.config import compose_configs, strip_init_fields + +# Plan builders (generic) +from fast_llm_external_models.apriel2.conversion.converters import ( + plan_dil_attention_to_gdn, + plan_kil_attention_to_kda, + plan_mil_attention_to_mamba, + plan_surgery, +) + +# Execution +from fast_llm_external_models.apriel2.conversion.executor import MAX_SEED, StreamingExecutor, execute + # Core types and plan operations from fast_llm_external_models.apriel2.conversion.expr import ( Concat, @@ -140,13 +154,6 @@ substitute, ) -# Execution -from fast_llm_external_models.apriel2.conversion.executor import ( - MAX_SEED, - StreamingExecutor, - execute, -) - # I/O utilities from fast_llm_external_models.apriel2.conversion.io import ( DEFAULT_MAX_SHARD_SIZE, @@ -154,22 +161,9 @@ ShardedSafetensorWriter, ) -# Plan builders (generic) -from fast_llm_external_models.apriel2.conversion.converters import ( - plan_mil_attention_to_mamba, - plan_dil_attention_to_gdn, - plan_kil_attention_to_kda, - plan_surgery, -) - -# Config composition -from fast_llm_external_models.apriel2.conversion.config import compose_configs, strip_init_fields - # Source-specific converters -from fast_llm_external_models.apriel2.conversion.llava import ( - convert_config as convert_llava_config, - plan_llava_to_apriel2, -) +from fast_llm_external_models.apriel2.conversion.llava import convert_config as convert_llava_config +from fast_llm_external_models.apriel2.conversion.llava import plan_llava_to_apriel2 # Rendering (optional, imported lazily by ExprPlan.render_tree) # from fast_llm_external_models.apriel2.conversion.render import render_tree diff --git a/fast_llm_external_models/apriel2/conversion/converters.py b/fast_llm_external_models/apriel2/conversion/converters.py index c8b83f657..9c9238bb0 100644 --- a/fast_llm_external_models/apriel2/conversion/converters.py +++ b/fast_llm_external_models/apriel2/conversion/converters.py @@ -61,16 +61,7 @@ from __future__ import annotations -from fast_llm_external_models.apriel2.conversion.expr import ( - Concat, - Expr, - ExprPlan, - Init, - Ref, - Slice, - W, -) - +from fast_llm_external_models.apriel2.conversion.expr import Concat, Expr, ExprPlan, Init, Ref, Slice, W # 
============================================================================= # SECTION 1: Per-Mixer Plan Functions @@ -195,20 +186,22 @@ def _plan_mamba_mixer( """ if source_prefix is not None: # Passthrough - include all possible weights - return ExprPlan(mappings={ - prefix / name: Ref(key=source_prefix / name) - for name in [ - "in_proj.weight", - "out_proj.weight", - "dt_in_proj.weight", - "dt_proj.weight", - "dt_proj.bias", - "conv1d.weight", - "conv1d.bias", - "A_log", - "D", - ] - }) + return ExprPlan( + mappings={ + prefix / name: Ref(key=source_prefix / name) + for name in [ + "in_proj.weight", + "out_proj.weight", + "dt_in_proj.weight", + "dt_proj.weight", + "dt_proj.bias", + "conv1d.weight", + "conv1d.bias", + "A_log", + "D", + ] + } + ) # Random init d_inner = config["d_inner"] @@ -226,9 +219,7 @@ def _plan_mamba_mixer( conv_channels = d_inner if repeat_kv_before_conv else d_xb mappings: dict[W, Expr] = { - prefix / "in_proj" / "weight": Init( - shape=(2 * d_inner + 2 * d_xb, hidden_size), init_type="kaiming" - ), + prefix / "in_proj" / "weight": Init(shape=(2 * d_inner + 2 * d_xb, hidden_size), init_type="kaiming"), prefix / "out_proj" / "weight": Init(shape=(hidden_size, d_inner), init_type="kaiming"), prefix / "dt_in_proj" / "weight": Init(shape=(dt_rank, hidden_size), init_type="kaiming"), prefix / "dt_proj" / "weight": Init(shape=(d_inner, dt_rank), init_type="kaiming"), @@ -275,18 +266,20 @@ def _plan_gdn_mixer( """ if source_prefix is not None: # Passthrough - return ExprPlan(mappings={ - prefix / name: Ref(key=source_prefix / name) - for name in [ - "in_proj_qkvz.weight", - "in_proj_ba.weight", - "out_proj.weight", - "convolution.weight", - "A_log", - "dt_bias", - "norm.weight", - ] - }) + return ExprPlan( + mappings={ + prefix / name: Ref(key=source_prefix / name) + for name in [ + "in_proj_qkvz.weight", + "in_proj_ba.weight", + "out_proj.weight", + "convolution.weight", + "A_log", + "dt_bias", + "norm.weight", + ] + } + ) # Random init num_v_heads = config["value_heads"] @@ -300,17 +293,19 @@ def _plan_gdn_mixer( conv_dim = key_dim * 2 + value_dim qkvz_size = key_dim * 2 + value_dim * 2 # Q, K both key_dim; V, Z both value_dim - return ExprPlan(mappings={ - prefix / "in_proj_qkvz" / "weight": Init(shape=(qkvz_size, hidden_size), init_type="kaiming"), - prefix / "in_proj_ba" / "weight": Init(shape=(num_v_heads * 2, hidden_size), init_type="zeros"), - prefix / "out_proj" / "weight": Init(shape=(hidden_size, value_dim), init_type="kaiming"), - prefix / "convolution" / "weight": Init( - shape=(conv_dim, 1, conv_kernel_size), init_type="scaled_identity_conv" - ), - prefix / "A_log": Init(shape=(num_v_heads,), init_type="slow_decay"), - prefix / "dt_bias": Init(shape=(num_v_heads,), init_type="zeros"), - prefix / "norm" / "weight": Init(shape=(head_v_dim,), init_type="ones"), - }) + return ExprPlan( + mappings={ + prefix / "in_proj_qkvz" / "weight": Init(shape=(qkvz_size, hidden_size), init_type="kaiming"), + prefix / "in_proj_ba" / "weight": Init(shape=(num_v_heads * 2, hidden_size), init_type="zeros"), + prefix / "out_proj" / "weight": Init(shape=(hidden_size, value_dim), init_type="kaiming"), + prefix + / "convolution" + / "weight": Init(shape=(conv_dim, 1, conv_kernel_size), init_type="scaled_identity_conv"), + prefix / "A_log": Init(shape=(num_v_heads,), init_type="slow_decay"), + prefix / "dt_bias": Init(shape=(num_v_heads,), init_type="zeros"), + prefix / "norm" / "weight": Init(shape=(head_v_dim,), init_type="ones"), + } + ) def _plan_kda_mixer( @@ -343,26 +338,28 
@@ def _plan_kda_mixer( """ if source_prefix is not None: # Passthrough - return ExprPlan(mappings={ - prefix / name: Ref(key=source_prefix / name) - for name in [ - "q_proj.weight", - "k_proj.weight", - "v_proj.weight", - "o_proj.weight", - "q_conv.weight", - "k_conv.weight", - "v_conv.weight", - "f_a_proj.weight", - "f_b_proj.weight", - "g_a_proj.weight", - "g_b_proj.weight", - "beta_proj.weight", - "A_log", - "dt_bias", - "norm.weight", - ] - }) + return ExprPlan( + mappings={ + prefix / name: Ref(key=source_prefix / name) + for name in [ + "q_proj.weight", + "k_proj.weight", + "v_proj.weight", + "o_proj.weight", + "q_conv.weight", + "k_conv.weight", + "v_conv.weight", + "f_a_proj.weight", + "f_b_proj.weight", + "g_a_proj.weight", + "g_b_proj.weight", + "beta_proj.weight", + "A_log", + "dt_bias", + "norm.weight", + ] + } + ) # Random init num_heads = config["heads"] @@ -370,36 +367,38 @@ def _plan_kda_mixer( projection_size = num_heads * head_dim conv_kernel_size = config.get("convolution_layer", {}).get("kernel_size", 4) - return ExprPlan(mappings={ - # Main projections - prefix / "q_proj" / "weight": Init(shape=(projection_size, hidden_size), init_type="kaiming"), - prefix / "k_proj" / "weight": Init(shape=(projection_size, hidden_size), init_type="kaiming"), - prefix / "v_proj" / "weight": Init(shape=(projection_size, hidden_size), init_type="kaiming"), - prefix / "o_proj" / "weight": Init(shape=(hidden_size, projection_size), init_type="kaiming"), - # Convolutions - prefix / "q_conv" / "weight": Init( - shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv" - ), - prefix / "k_conv" / "weight": Init( - shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv" - ), - prefix / "v_conv" / "weight": Init( - shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv" - ), - # Gate kernels (low-rank factorization) - prefix / "f_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"), - prefix / "f_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"), - # Output gate (low-rank factorization) - prefix / "g_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"), - prefix / "g_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"), - # Beta projection - prefix / "beta_proj" / "weight": Init(shape=(num_heads, hidden_size), init_type="kaiming"), - # Learnable parameters - prefix / "A_log": Init(shape=(num_heads,), init_type="slow_decay"), - prefix / "dt_bias": Init(shape=(projection_size,), init_type="zeros"), - # Normalization - prefix / "norm" / "weight": Init(shape=(head_dim,), init_type="ones"), - }) + return ExprPlan( + mappings={ + # Main projections + prefix / "q_proj" / "weight": Init(shape=(projection_size, hidden_size), init_type="kaiming"), + prefix / "k_proj" / "weight": Init(shape=(projection_size, hidden_size), init_type="kaiming"), + prefix / "v_proj" / "weight": Init(shape=(projection_size, hidden_size), init_type="kaiming"), + prefix / "o_proj" / "weight": Init(shape=(hidden_size, projection_size), init_type="kaiming"), + # Convolutions + prefix + / "q_conv" + / "weight": Init(shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"), + prefix + / "k_conv" + / "weight": Init(shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"), + prefix + / "v_conv" + / "weight": Init(shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"), + # Gate kernels 
(low-rank factorization) + prefix / "f_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"), + prefix / "f_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"), + # Output gate (low-rank factorization) + prefix / "g_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"), + prefix / "g_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"), + # Beta projection + prefix / "beta_proj" / "weight": Init(shape=(num_heads, hidden_size), init_type="kaiming"), + # Learnable parameters + prefix / "A_log": Init(shape=(num_heads,), init_type="slow_decay"), + prefix / "dt_bias": Init(shape=(projection_size,), init_type="zeros"), + # Normalization + prefix / "norm" / "weight": Init(shape=(head_dim,), init_type="ones"), + } + ) # Dispatcher for per-mixer plan functions @@ -454,16 +453,13 @@ def plan_mil_attention_to_mamba( exprs=( Init(shape=(d_inner, hidden_size), init_type="kaiming"), # z: random Slice( - expr=Ref(key=source_prefix / "v_proj" / "weight"), - slices=((0, d_xb, None), (None, None, None)) + expr=Ref(key=source_prefix / "v_proj" / "weight"), slices=((0, d_xb, None), (None, None, None)) ), # x <- V Slice( - expr=Ref(key=source_prefix / "k_proj" / "weight"), - slices=((0, d_xb, None), (None, None, None)) + expr=Ref(key=source_prefix / "k_proj" / "weight"), slices=((0, d_xb, None), (None, None, None)) ), # B <- K Slice( - expr=Ref(key=source_prefix / "q_proj" / "weight"), - slices=((0, d_inner, None), (None, None, None)) + expr=Ref(key=source_prefix / "q_proj" / "weight"), slices=((0, d_inner, None), (None, None, None)) ), # C <- Q ), dim=0, @@ -577,19 +573,21 @@ def plan_dil_attention_to_gdn( dim=0, ) - return ExprPlan(mappings={ - target_prefix / "in_proj_qkvz" / "weight": in_proj_qkvz_expr, - target_prefix / "in_proj_ba" / "weight": Init( - shape=(2 * num_v_heads, hidden_size), init_type="zeros" - ), # b=a=0 → β=0.5 - target_prefix / "out_proj" / "weight": Ref(key=source_prefix / "o_proj" / "weight"), - target_prefix / "convolution" / "weight": Init( - shape=(conv_dim, 1, conv_kernel_size), init_type="scaled_identity_conv" - ), - target_prefix / "A_log": Init(shape=(num_v_heads,), init_type="slow_decay"), - target_prefix / "dt_bias": Init(shape=(num_v_heads,), init_type="zeros"), - target_prefix / "norm" / "weight": Init(shape=(head_v_dim,), init_type="ones"), - }) + return ExprPlan( + mappings={ + target_prefix / "in_proj_qkvz" / "weight": in_proj_qkvz_expr, + target_prefix + / "in_proj_ba" + / "weight": Init(shape=(2 * num_v_heads, hidden_size), init_type="zeros"), # b=a=0 → β=0.5 + target_prefix / "out_proj" / "weight": Ref(key=source_prefix / "o_proj" / "weight"), + target_prefix + / "convolution" + / "weight": Init(shape=(conv_dim, 1, conv_kernel_size), init_type="scaled_identity_conv"), + target_prefix / "A_log": Init(shape=(num_v_heads,), init_type="slow_decay"), + target_prefix / "dt_bias": Init(shape=(num_v_heads,), init_type="zeros"), + target_prefix / "norm" / "weight": Init(shape=(head_v_dim,), init_type="ones"), + } + ) def plan_kil_attention_to_kda( @@ -640,9 +638,7 @@ def plan_kil_attention_to_kda( for h in range(num_heads): src_h = h % source_num_q_heads row_start = src_h * source_head_dim - q_slices.append( - Slice(expr=q_ref, slices=((row_start, row_start + head_dim, None), (None, None, None))) - ) + q_slices.append(Slice(expr=q_ref, slices=((row_start, row_start + head_dim, None), (None, None, None)))) q_expr = Concat(exprs=tuple(q_slices), dim=0) # K: tile source KV heads 
to fill target projection_size @@ -653,9 +649,7 @@ def plan_kil_attention_to_kda( for h in range(num_heads): src_h = h % source_num_kv_heads row_start = src_h * source_head_dim - k_slices.append( - Slice(expr=k_ref, slices=((row_start, row_start + head_dim, None), (None, None, None))) - ) + k_slices.append(Slice(expr=k_ref, slices=((row_start, row_start + head_dim, None), (None, None, None)))) k_expr = Concat(exprs=tuple(k_slices), dim=0) # V: tile source KV heads to fill target projection_size @@ -666,41 +660,41 @@ def plan_kil_attention_to_kda( for h in range(num_heads): src_h = h % source_num_kv_heads row_start = src_h * source_head_dim - v_slices.append( - Slice(expr=v_ref, slices=((row_start, row_start + head_dim, None), (None, None, None))) - ) + v_slices.append(Slice(expr=v_ref, slices=((row_start, row_start + head_dim, None), (None, None, None)))) v_expr = Concat(exprs=tuple(v_slices), dim=0) - return ExprPlan(mappings={ - # Transfer main projections - target_prefix / "q_proj" / "weight": q_expr, - target_prefix / "k_proj" / "weight": k_expr, - target_prefix / "v_proj" / "weight": v_expr, - target_prefix / "o_proj" / "weight": Ref(key=source_prefix / "o_proj" / "weight"), - # Random init: convolutions (scaled identity for near-passthrough initially) - target_prefix / "q_conv" / "weight": Init( - shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv" - ), - target_prefix / "k_conv" / "weight": Init( - shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv" - ), - target_prefix / "v_conv" / "weight": Init( - shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv" - ), - # Random init: gate kernels (low-rank factorization) - target_prefix / "f_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"), - target_prefix / "f_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"), - # Random init: output gate (low-rank factorization) - target_prefix / "g_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"), - target_prefix / "g_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"), - # Random init: beta projection - target_prefix / "beta_proj" / "weight": Init(shape=(num_heads, hidden_size), init_type="kaiming"), - # Random init: learnable parameters - target_prefix / "A_log": Init(shape=(num_heads,), init_type="slow_decay"), - target_prefix / "dt_bias": Init(shape=(projection_size,), init_type="zeros"), - # Random init: normalization - target_prefix / "norm" / "weight": Init(shape=(head_dim,), init_type="ones"), - }) + return ExprPlan( + mappings={ + # Transfer main projections + target_prefix / "q_proj" / "weight": q_expr, + target_prefix / "k_proj" / "weight": k_expr, + target_prefix / "v_proj" / "weight": v_expr, + target_prefix / "o_proj" / "weight": Ref(key=source_prefix / "o_proj" / "weight"), + # Random init: convolutions (scaled identity for near-passthrough initially) + target_prefix + / "q_conv" + / "weight": Init(shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"), + target_prefix + / "k_conv" + / "weight": Init(shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"), + target_prefix + / "v_conv" + / "weight": Init(shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"), + # Random init: gate kernels (low-rank factorization) + target_prefix / "f_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"), + target_prefix / 
"f_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"), + # Random init: output gate (low-rank factorization) + target_prefix / "g_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"), + target_prefix / "g_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"), + # Random init: beta projection + target_prefix / "beta_proj" / "weight": Init(shape=(num_heads, hidden_size), init_type="kaiming"), + # Random init: learnable parameters + target_prefix / "A_log": Init(shape=(num_heads,), init_type="slow_decay"), + target_prefix / "dt_bias": Init(shape=(projection_size,), init_type="zeros"), + # Random init: normalization + target_prefix / "norm" / "weight": Init(shape=(head_dim,), init_type="ones"), + } + ) # ============================================================================= @@ -912,18 +906,24 @@ def plan_surgery( target_block = _get_block_config(target_decoder, target_layer_idx) plan += _plan_mixer( - target_layer_idx, source_layer_idx, - source_block.get("mixer", {}), target_block.get("mixer", {}), + target_layer_idx, + source_layer_idx, + source_block.get("mixer", {}), + target_block.get("mixer", {}), hidden_size, ) plan += _plan_mlp( - target_layer_idx, source_layer_idx, - source_block.get("mlp", {}), target_block.get("mlp", {}), + target_layer_idx, + source_layer_idx, + source_block.get("mlp", {}), + target_block.get("mlp", {}), hidden_size, ) plan += _plan_norms( - target_layer_idx, source_layer_idx, - source_block, target_block, + target_layer_idx, + source_layer_idx, + source_block, + target_block, hidden_size, ) @@ -1060,9 +1060,13 @@ def _plan_mixer( source_prefix = source_mixer_base plan += _plan_mixer_transfer( - matched_source_type, sub_type, - matched_source, sub_config, - source_prefix, target_prefix, hidden_size, + matched_source_type, + sub_type, + matched_source, + sub_config, + source_prefix, + target_prefix, + hidden_size, ) # Passthrough source sub-mixers not in target spec @@ -1073,8 +1077,13 @@ def _plan_mixer( source_prefix = source_layer / "mixer" / "mixers" / sub_name target_prefix = target_layer / "mixer" / "mixers" / sub_name plan += _plan_mixer_transfer( - sub_type, sub_type, sub_config, sub_config, - source_prefix, target_prefix, hidden_size, + sub_type, + sub_type, + sub_config, + sub_config, + source_prefix, + target_prefix, + hidden_size, ) return plan @@ -1090,9 +1099,13 @@ def _plan_mixer( source_prefix = source_layer / "mixer" return _plan_mixer_transfer( - main_source_type, target_type, - main_source, target_mixer, - source_prefix, target_prefix, hidden_size, + main_source_type, + target_type, + main_source, + target_mixer, + source_prefix, + target_prefix, + hidden_size, ) @@ -1163,8 +1176,7 @@ def _plan_mlp_transfer( weight_projs = ["up_proj", "down_proj"] mappings: dict[W, Expr] = { - target_mlp_path / proj / "weight": Ref(key=source_mlp_path / proj / "weight") - for proj in weight_projs + target_mlp_path / proj / "weight": Ref(key=source_mlp_path / proj / "weight") for proj in weight_projs } # Passthrough biases if enabled @@ -1259,10 +1271,12 @@ def _plan_norms_transfer( f"Use 'init: random' to initialize randomly." 
) - return ExprPlan(mappings={ - target_layer / norm_name / "weight": Ref(key=source_layer / norm_name / "weight") - for norm_name in ["input_layernorm", "post_attention_layernorm"] - }) + return ExprPlan( + mappings={ + target_layer / norm_name / "weight": Ref(key=source_layer / norm_name / "weight") + for norm_name in ["input_layernorm", "post_attention_layernorm"] + } + ) def _plan_random_norms( @@ -1271,7 +1285,9 @@ def _plan_random_norms( ) -> ExprPlan: """Random initialization for normalization layers.""" target_layer = W("model", "decoder", "blocks", target_layer_idx) - return ExprPlan(mappings={ - target_layer / norm_name / "weight": Init(shape=(hidden_size,), init_type="ones") - for norm_name in ["input_layernorm", "post_attention_layernorm"] - }) + return ExprPlan( + mappings={ + target_layer / norm_name / "weight": Init(shape=(hidden_size,), init_type="ones") + for norm_name in ["input_layernorm", "post_attention_layernorm"] + } + ) diff --git a/fast_llm_external_models/apriel2/conversion/executor.py b/fast_llm_external_models/apriel2/conversion/executor.py index a6c5672f0..b0779c97f 100644 --- a/fast_llm_external_models/apriel2/conversion/executor.py +++ b/fast_llm_external_models/apriel2/conversion/executor.py @@ -29,7 +29,8 @@ from __future__ import annotations import hashlib -from typing import Callable, Iterator +from collections.abc import Iterator +from typing import Callable import torch from torch import Tensor @@ -81,8 +82,7 @@ def execute( break else: raise ValueError( - "Cannot infer device/dtype: plan has no source references. " - "Provide device and dtype explicitly." + "Cannot infer device/dtype: plan has no source references. " "Provide device and dtype explicitly." ) generator = torch.Generator(device=device) @@ -94,10 +94,7 @@ def execute( # Verify device/dtype consistency for key, tensor in sources.items(): if tensor.device != device or tensor.dtype != dtype: - raise ValueError( - f"Source {key} has {tensor.device}/{tensor.dtype}, " - f"expected {device}/{dtype}" - ) + raise ValueError(f"Source {key} has {tensor.device}/{tensor.dtype}, " f"expected {device}/{dtype}") # Deterministic per-target seed key_offset = int(hashlib.md5(str(target_key).encode()).hexdigest()[:8], 16) diff --git a/fast_llm_external_models/apriel2/conversion/expr.py b/fast_llm_external_models/apriel2/conversion/expr.py index 4867a27ae..34ea106fc 100644 --- a/fast_llm_external_models/apriel2/conversion/expr.py +++ b/fast_llm_external_models/apriel2/conversion/expr.py @@ -52,7 +52,8 @@ import math from collections import defaultdict -from typing import Annotated, Any, Callable, Iterator, Literal, TypedDict, Union, Unpack +from collections.abc import Iterator +from typing import Annotated, Any, Callable, Literal, TypedDict, Union, Unpack import torch from pydantic import BaseModel, ConfigDict, Field, GetCoreSchemaHandler, TypeAdapter @@ -60,7 +61,6 @@ from pydantic_core import CoreSchema, core_schema from torch import Tensor - # ============================================================================= # Weight Path Builder # ============================================================================= @@ -78,7 +78,7 @@ class W(str): mappings[q] = Ref(key=source_q) """ - def __new__(cls, *parts) -> "W": + def __new__(cls, *parts) -> W: # Join parts, stripping any leading/trailing dots from each cleaned = [] for p in parts: @@ -89,12 +89,12 @@ def __new__(cls, *parts) -> "W": cleaned.append(s) return super().__new__(cls, ".".join(cleaned)) - def __truediv__(self, other) -> "W": + def 
__truediv__(self, other) -> W: if isinstance(other, (list, tuple)): return W(self, *other) return W(self, other) - def __rtruediv__(self, other) -> "W": + def __rtruediv__(self, other) -> W: return W(other, self) @classmethod @@ -156,7 +156,7 @@ class Slice(BaseModel): model_config = ConfigDict(frozen=True) type: Literal["slice"] = "slice" - expr: "Expr" + expr: Expr slices: tuple[tuple[int | None, int | None, int | None], ...] def find_refs(self) -> set[W]: @@ -184,7 +184,7 @@ class Concat(BaseModel): model_config = ConfigDict(frozen=True) type: Literal["concat"] = "concat" - exprs: tuple["Expr", ...] + exprs: tuple[Expr, ...] dim: int = 0 def find_refs(self) -> set[W]: @@ -303,7 +303,7 @@ class Reshape(BaseModel): model_config = ConfigDict(frozen=True) type: Literal["reshape"] = "reshape" - expr: "Expr" + expr: Expr shape: tuple[int, ...] def find_refs(self) -> set[W]: @@ -442,10 +442,10 @@ def __getitem__(self, key: W) -> Expr: def __contains__(self, key: W) -> bool: return key in self.mappings - def __or__(self, other: "ExprPlan") -> "ExprPlan": + def __or__(self, other: ExprPlan) -> ExprPlan: return compose(self, other) - def __add__(self, other: "ExprPlan") -> "ExprPlan": + def __add__(self, other: ExprPlan) -> ExprPlan: return merge(self, other) def source_keys(self) -> set[str]: @@ -471,7 +471,7 @@ def summary(self) -> dict[str, Any]: "metadata": self.metadata, } - def fuse(self) -> "ExprPlan": + def fuse(self) -> ExprPlan: return ExprPlan( mappings={k: fuse(v) for k, v in self.mappings.items()}, source_format=self.source_format, diff --git a/fast_llm_external_models/apriel2/conversion/io.py b/fast_llm_external_models/apriel2/conversion/io.py index e1a261d7e..1f64df0b9 100644 --- a/fast_llm_external_models/apriel2/conversion/io.py +++ b/fast_llm_external_models/apriel2/conversion/io.py @@ -62,7 +62,7 @@ def __init__(self, files: list[Path], device: str = "cpu"): self._handles: dict[Path, Any] = {} self._key_index: dict[str, Path] = {} - def __enter__(self) -> "SafetensorLoader": + def __enter__(self) -> SafetensorLoader: # Pre-build index: key -> file (one-time O(n×m), then O(1) lookups) for f in self.files: handle = safe_open(f, framework="pt", device=self.device) @@ -128,7 +128,7 @@ def __init__( self._finalized: bool = False self._result_path: Path | None = None - def __enter__(self) -> "ShardedSafetensorWriter": + def __enter__(self) -> ShardedSafetensorWriter: return self def __exit__(self, exc_type, exc_val, exc_tb) -> None: @@ -180,8 +180,7 @@ def _flush(self) -> None: shard_file = self.output_dir / f"{self.base_name}-{self._shard_index:05d}.safetensors.tmp" logger.debug( - f"Writing shard {self._shard_index}: {len(self._buffer)} tensors, " - f"{self._buffer_bytes / 1e9:.2f} GB" + f"Writing shard {self._shard_index}: {len(self._buffer)} tensors, " f"{self._buffer_bytes / 1e9:.2f} GB" ) save_file(self._buffer, shard_file) self._shard_files.append(shard_file) diff --git a/fast_llm_external_models/apriel2/conversion/llava/plan.py b/fast_llm_external_models/apriel2/conversion/llava/plan.py index df485efbd..a97e46c1a 100644 --- a/fast_llm_external_models/apriel2/conversion/llava/plan.py +++ b/fast_llm_external_models/apriel2/conversion/llava/plan.py @@ -1,11 +1,6 @@ """Llava to Apriel2 weight conversion plan.""" -from fast_llm_external_models.apriel2.conversion.expr import ( - Expr, - ExprPlan, - Ref, - W, -) +from fast_llm_external_models.apriel2.conversion.expr import Expr, ExprPlan, Ref, W def plan_llava_to_apriel2(llava_config: dict) -> ExprPlan: diff --git 
a/fast_llm_external_models/apriel2/conversion/qwen2/plan.py b/fast_llm_external_models/apriel2/conversion/qwen2/plan.py index 7752d37c9..c1ec4af8b 100644 --- a/fast_llm_external_models/apriel2/conversion/qwen2/plan.py +++ b/fast_llm_external_models/apriel2/conversion/qwen2/plan.py @@ -1,11 +1,6 @@ """Qwen2/Qwen2.5 to Apriel2 weight conversion plan.""" -from fast_llm_external_models.apriel2.conversion.expr import ( - Expr, - ExprPlan, - Ref, - W, -) +from fast_llm_external_models.apriel2.conversion.expr import Expr, ExprPlan, Ref, W def plan_qwen2_to_apriel2(qwen2_config: dict) -> ExprPlan: @@ -55,9 +50,7 @@ def plan_qwen2_to_apriel2(qwen2_config: dict) -> ExprPlan: # lm_head - only if not tied if not qwen2_config.get("tie_word_embeddings", False): - static_mappings.append( - (W("lm_head", "weight"), W("lm_head", "weight")) - ) + static_mappings.append((W("lm_head", "weight"), W("lm_head", "weight"))) for src, tgt in static_mappings: mappings[tgt] = Ref(key=src) @@ -89,9 +82,7 @@ def plan_qwen2_to_apriel2(qwen2_config: dict) -> ExprPlan: mappings[tgt] = Ref(key=src) # Layer norms - mappings[apriel_layer / "input_layernorm" / "weight"] = Ref( - key=qwen_layer / "input_layernorm" / "weight" - ) + mappings[apriel_layer / "input_layernorm" / "weight"] = Ref(key=qwen_layer / "input_layernorm" / "weight") mappings[apriel_layer / "post_attention_layernorm" / "weight"] = Ref( key=qwen_layer / "post_attention_layernorm" / "weight" ) diff --git a/fast_llm_external_models/apriel2/conversion/render.py b/fast_llm_external_models/apriel2/conversion/render.py index d71fa03e1..f9a0c8ac1 100644 --- a/fast_llm_external_models/apriel2/conversion/render.py +++ b/fast_llm_external_models/apriel2/conversion/render.py @@ -8,17 +8,11 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING +from fast_llm_external_models.apriel2.conversion.expr import Concat, Init, Ref, Reshape, Slice + if TYPE_CHECKING: from fast_llm_external_models.apriel2.conversion.expr import Expr, ExprPlan -from fast_llm_external_models.apriel2.conversion.expr import ( - Concat, - Init, - Ref, - Reshape, - Slice, -) - @dataclass class PlanTreeNode: @@ -28,10 +22,10 @@ class PlanTreeNode: After merging, leaf nodes contain aggregated values from multiple siblings. """ - children: dict[str, "PlanTreeNode"] = field(default_factory=dict) + children: dict[str, PlanTreeNode] = field(default_factory=dict) # For leaf nodes: list of (sibling_key, expr) pairs # Before merge: single item, after merge: multiple items from merged siblings - values: list[tuple[str, "Expr"]] = field(default_factory=list) + values: list[tuple[str, Expr]] = field(default_factory=list) def is_leaf(self) -> bool: return len(self.children) == 0 @@ -61,7 +55,7 @@ def _build_plan_tree(plan: ExprPlan) -> PlanTreeNode: return root -def _expr_signature(expr: "Expr") -> tuple: +def _expr_signature(expr: Expr) -> tuple: """Get a signature for an expression that determines merge compatibility. Expressions with different signatures should not be merged together. @@ -453,7 +447,7 @@ def _render_plan_tree( ) -def _format_aggregated_leaf(values: list[tuple[str, "Expr"]]) -> str: +def _format_aggregated_leaf(values: list[tuple[str, Expr]]) -> str: """Format a leaf with aggregated values using pattern discovery. 
Args: @@ -494,7 +488,7 @@ def _format_aggregated_leaf(values: list[tuple[str, "Expr"]]) -> str: return _format_single_expr(first_expr) -def _format_single_expr(expr: "Expr") -> str: +def _format_single_expr(expr: Expr) -> str: """Format a single expression using ML notation.""" match expr: case Ref(key=key): @@ -531,7 +525,7 @@ def _format_single_expr(expr: "Expr") -> str: return f"= {type(expr).__name__}" -def _format_concat_part(expr: "Expr") -> str: +def _format_concat_part(expr: Expr) -> str: """Format a single part of a concat (for short display).""" match expr: case Ref(key=key): @@ -570,7 +564,7 @@ def _format_slice_notation(slices: tuple) -> str: return f"[{', '.join(slice_strs)}]" -def _format_aggregated_concat(values: list[tuple[str, "Expr"]]) -> str: +def _format_aggregated_concat(values: list[tuple[str, Expr]]) -> str: """Format aggregated Concat expressions with pattern discovery.""" # Get the first concat to understand structure first_concat = values[0][1] @@ -590,7 +584,7 @@ def _format_aggregated_concat(values: list[tuple[str, "Expr"]]) -> str: return f"= [{sep.join(formatted_parts)}]" -def _format_aggregated_concat_part(values: list[tuple[str, "Expr"]]) -> str: +def _format_aggregated_concat_part(values: list[tuple[str, Expr]]) -> str: """Format a single part of an aggregated concat.""" if len(values) == 1: return _format_concat_part(values[0][1]) @@ -619,7 +613,7 @@ def _format_aggregated_concat_part(values: list[tuple[str, "Expr"]]) -> str: return _format_concat_part(first_expr) -def _format_aggregated_slice(values: list[tuple[str, "Expr"]]) -> str: +def _format_aggregated_slice(values: list[tuple[str, Expr]]) -> str: """Format aggregated Slice expressions with pattern discovery.""" first_slice = values[0][1] if not isinstance(first_slice, Slice): diff --git a/fast_llm_external_models/apriel2/convert.py b/fast_llm_external_models/apriel2/convert.py index 60786d22c..66c419dfd 100644 --- a/fast_llm_external_models/apriel2/convert.py +++ b/fast_llm_external_models/apriel2/convert.py @@ -30,10 +30,7 @@ import yaml from tqdm import tqdm -# Allow running as script or module -if __name__ == "__main__": - sys.path.insert(0, str(Path(__file__).parent.parent.parent)) - +# Import source-specific converters from fast_llm_external_models.apriel2.conversion import ( DEFAULT_MAX_SHARD_SIZE, ExprPlan, @@ -42,13 +39,16 @@ StreamingExecutor, compose, compose_configs, - plan_surgery, - strip_init_fields, ) - -# Import source-specific converters from fast_llm_external_models.apriel2.conversion import llava as llava_converter +from fast_llm_external_models.apriel2.conversion import plan_surgery from fast_llm_external_models.apriel2.conversion import qwen2 as qwen2_converter +from fast_llm_external_models.apriel2.conversion import strip_init_fields + +# Allow running as script or module +if __name__ == "__main__": + sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + logger = logging.getLogger(__name__) @@ -155,7 +155,9 @@ def build_plan( # S × T → Plan: build plan from source state and transition spec surgery_plan = plan_surgery(current_config, target_config) - logger.info(f"Built surgery plan [{i}/{len(surgery_configs)}]: {surgery_plan.summary()['num_targets']} targets") + logger.info( + f"Built surgery plan [{i}/{len(surgery_configs)}]: {surgery_plan.summary()['num_targets']} targets" + ) # Compose plans current_plan = compose(current_plan, surgery_plan) @@ -223,9 +225,7 @@ def convert( executor = StreamingExecutor(full_plan, loader) with ShardedSafetensorWriter(output_dir, 
max_shard_size=max_shard_size) as writer: - for target_key, tensor in tqdm( - executor.execute(seed), desc="Converting", total=len(full_plan) - ): + for target_key, tensor in tqdm(executor.execute(seed), desc="Converting", total=len(full_plan)): writer.add(target_key, tensor) return final_config @@ -294,9 +294,7 @@ def resolve_input(input_path: str) -> Path: def main(): - parser = argparse.ArgumentParser( - description="Convert HuggingFace checkpoint to Apriel2 HF format" - ) + parser = argparse.ArgumentParser(description="Convert HuggingFace checkpoint to Apriel2 HF format") parser.add_argument( "input", type=str, @@ -396,8 +394,7 @@ def main(): safetensor_files = sorted(input_dir.glob("*.safetensors")) if not safetensor_files: raise ValueError( - f"No safetensor files found in {input_dir}. " - "Plan-based conversion requires safetensor files." + f"No safetensor files found in {input_dir}. " "Plan-based conversion requires safetensor files." ) # Convert using plan-based approach with streaming sharded output diff --git a/fast_llm_external_models/apriel2/modeling_apriel2.py b/fast_llm_external_models/apriel2/modeling_apriel2.py index 930b158f3..240240cd6 100644 --- a/fast_llm_external_models/apriel2/modeling_apriel2.py +++ b/fast_llm_external_models/apriel2/modeling_apriel2.py @@ -1243,7 +1243,9 @@ def forward(self, hidden_states: torch.Tensor, past_key_values=None, attention_m mixed_qkv = self.convolution.update( mixed_qkv.squeeze(2), # [batch, conv_dim, 1] -> [batch, conv_dim] conv_state, - ).unsqueeze(2) # [batch, conv_dim] -> [batch, conv_dim, 1] + ).unsqueeze( + 2 + ) # [batch, conv_dim] -> [batch, conv_dim, 1] else: # Prefill mode use_cache = past_key_values is not None @@ -1488,9 +1490,7 @@ def __init__( # Normalization - use GatedRMSNormalization (same wrapper as GDN, with sigmoid activation) self.norm = GatedRMSNormalization(self.head_dim, eps=self.norm_eps, activation=self.norm_activation) - def _apply_conv( - self, x: torch.Tensor, conv: CausalConv1d, conv_state: torch.Tensor | None, use_cache: bool - ): + def _apply_conv(self, x: torch.Tensor, conv: CausalConv1d, conv_state: torch.Tensor | None, use_cache: bool): """ Apply causal convolution with cache support. 
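The decode-mode hunk above keeps a small rolling convolution state and shuffles shapes ([batch, conv_dim, 1] -> [batch, conv_dim] -> update -> unsqueeze) so that single-token steps continue exactly where prefill left off. The standalone sketch below illustrates that contract with plain torch; it is not the project's CausalConv1d API, and all names (causal_conv_prefill, causal_conv_decode_step) are hypothetical. It shows why the chunked-prefill and decode tests later in this patch can demand bit-for-bit-close agreement with a single full prefill.

# Illustrative sketch (assumed semantics, not the project's CausalConv1d):
# a depthwise causal conv with an explicit cache of the last k-1 inputs.
import torch
import torch.nn.functional as F


def causal_conv_prefill(x, weight, state=None):
    """x: [batch, dim, seq]; weight: [dim, 1, k] depthwise kernel.

    Returns (out, new_state), where new_state holds the last k-1 inputs so a
    later chunk or a single decode step can continue seamlessly.
    """
    k = weight.shape[-1]
    if state is None:
        state = x.new_zeros(x.shape[0], x.shape[1], k - 1)
    padded = torch.cat([state, x], dim=-1)             # prepend cached history
    out = F.conv1d(padded, weight, groups=x.shape[1])  # causal: no future leakage
    new_state = padded[:, :, -(k - 1):]                # keep the last k-1 columns
    return out, new_state


def causal_conv_decode_step(x_t, weight, state):
    """x_t: [batch, dim] single token; state: [batch, dim, k-1]."""
    window = torch.cat([state, x_t.unsqueeze(-1)], dim=-1)  # [batch, dim, k]
    out = (window * weight.squeeze(1)).sum(-1)              # depthwise dot product
    return out, window[:, :, 1:]                            # roll the cache forward


if __name__ == "__main__":
    torch.manual_seed(0)
    batch, dim, seq, k = 2, 8, 16, 4
    x = torch.randn(batch, dim, seq)
    w = torch.randn(dim, 1, k)

    full, _ = causal_conv_prefill(x, w)

    # Chunked prefill followed by one-by-one decode reproduces the full output.
    out_a, state = causal_conv_prefill(x[:, :, :10], w)
    steps = []
    for t in range(10, seq):
        y, state = causal_conv_decode_step(x[:, :, t], w, state)
        steps.append(y.unsqueeze(-1))
    recomposed = torch.cat([out_a] + steps, dim=-1)
    torch.testing.assert_close(recomposed, full, atol=1e-5, rtol=1e-5)

The key design point mirrored by the real cache code is that the state is just the trailing k-1 input columns, so prefill, chunked prefill, and single-token decode are different schedules over the same recurrence rather than different computations.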
diff --git a/fast_llm_external_models/tests/test_apriel2/conftest.py b/fast_llm_external_models/tests/test_apriel2/conftest.py index 320813747..21b90b097 100644 --- a/fast_llm_external_models/tests/test_apriel2/conftest.py +++ b/fast_llm_external_models/tests/test_apriel2/conftest.py @@ -1,7 +1,7 @@ """Test fixtures for Apriel2 model tests.""" +from collections.abc import Generator from pathlib import Path -from typing import Generator import pytest import torch @@ -18,7 +18,6 @@ def pytest_configure(config): def _can_import_fast_llm(): """Check if Fast-LLM is available.""" try: - from fast_llm.engine.checkpoint.convert import ConvertConfig return True except ImportError: return False @@ -26,15 +25,11 @@ def _can_import_fast_llm(): # Skip marker for tests that require CUDA for Mamba forward pass requires_cuda = pytest.mark.skipif( - not torch.cuda.is_available(), - reason="SSM mixers (Mamba) require CUDA for forward pass" + not torch.cuda.is_available(), reason="SSM mixers (Mamba) require CUDA for forward pass" ) # Skip marker for tests that require Fast-LLM -requires_fastllm = pytest.mark.skipif( - not _can_import_fast_llm(), - reason="Fast-LLM not available" -) +requires_fastllm = pytest.mark.skipif(not _can_import_fast_llm(), reason="Fast-LLM not available") @pytest.fixture(scope="module", autouse=True) @@ -164,14 +159,11 @@ def model_pair(request, small_pixtral_model, tmp_path): tuple: (source_model, target_model, expected_atol, variant_name) """ import json + from safetensors import safe_open from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config - from fast_llm_external_models.apriel2.conversion import ( - convert_llava_config, - execute, - plan_llava_to_apriel2, - ) + from fast_llm_external_models.apriel2.conversion import convert_llava_config, execute, plan_llava_to_apriel2 from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForConditionalGeneration source = small_pixtral_model @@ -667,12 +659,12 @@ def apriel2_config_comprehensive(): "type": "pattern", "num_blocks": 6, "pattern": [ - "attn", # 0: pure full attention - "swa", # 1: pure sliding window attention - "mamba", # 2: pure mamba - "gdn", # 3: pure gated delta net - "stoch_attn_mamba", # 4: stochastic attention + mamba - "stoch_swa_gdn", # 5: stochastic swa + gated delta net + "attn", # 0: pure full attention + "swa", # 1: pure sliding window attention + "mamba", # 2: pure mamba + "gdn", # 3: pure gated delta net + "stoch_attn_mamba", # 4: stochastic attention + mamba + "stoch_swa_gdn", # 5: stochastic swa + gated delta net ], "blocks": { "attn": { @@ -1031,7 +1023,7 @@ def comprehensive_torture_chain(): # MIL requires: d_inner <= Q rows (256), d_xb <= K/V rows (128) mamba_params = { "d_inner": 256, # Must be <= heads*head_size = 256 - "d_xb": 64, # Must be <= head_groups*head_size = 128 + "d_xb": 64, # Must be <= head_groups*head_size = 128 "dt_rank": 16, "d_state": 16, "d_conv": 4, diff --git a/fast_llm_external_models/tests/test_apriel2/test_cache_apriel2_specific.py b/fast_llm_external_models/tests/test_apriel2/test_cache_apriel2_specific.py index e0e4db2d3..b45779454 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_cache_apriel2_specific.py +++ b/fast_llm_external_models/tests/test_apriel2/test_cache_apriel2_specific.py @@ -18,8 +18,7 @@ import pytest import torch -from fast_llm_external_models.apriel2.cache import Apriel2Cache, _AttentionCache, _SSMCache - +from fast_llm_external_models.apriel2.cache import Apriel2Cache # 
============================================================================= # STOCHASTIC MIXER ROUTING diff --git a/fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py b/fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py index 7c38f75b7..8ceabfb91 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py +++ b/fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py @@ -27,8 +27,7 @@ import pytest import torch -from fast_llm_external_models.apriel2.cache import _AttentionCache, _SSMCache, Apriel2Cache - +from fast_llm_external_models.apriel2.cache import Apriel2Cache, _AttentionCache # ============================================================================= # SECTION 1: FULL ATTENTION - _AttentionCache vs DynamicLayer @@ -78,9 +77,9 @@ def test_get_seq_length_during_decode( hf_dynamic_layer.update(key.clone(), value.clone()) apriel_attention_cache.update(key.clone(), value.clone()) - assert apriel_attention_cache.cumulative_length == hf_dynamic_layer.get_seq_length(), ( - f"Mismatch at decode step {step}" - ) + assert ( + apriel_attention_cache.cumulative_length == hf_dynamic_layer.get_seq_length() + ), f"Mismatch at decode step {step}" # ------------------------------------------------------------------------- # get_mask_sizes: Verify HF behavior for documentation @@ -343,9 +342,9 @@ def test_cumulative_length_tracks_all_tokens( hf_sliding_layer.update(key.clone(), value.clone()) apriel_sliding_cache.update(key.clone(), value.clone()) - assert apriel_sliding_cache.cumulative_length == hf_sliding_layer.get_seq_length(), ( - f"cumulative_length mismatch at step {i}" - ) + assert ( + apriel_sliding_cache.cumulative_length == hf_sliding_layer.get_seq_length() + ), f"cumulative_length mismatch at step {i}" # ============================================================================= @@ -496,7 +495,7 @@ class TestMaskCorrectness: def test_full_attention_decode_can_attend_to_all(self): """During decode, query can attend to all cached positions.""" - from transformers.masking_utils import sdpa_mask, causal_mask_function + from transformers.masking_utils import causal_mask_function, sdpa_mask cache = _AttentionCache(window=None) @@ -559,13 +558,13 @@ def test_sliding_window_decode_respects_window(self, window_size): causal = abs_pos <= query_pos expected = in_window and causal - assert query_mask[kv_idx].item() == expected, ( - f"Position {abs_pos}: expected {expected}, got {query_mask[kv_idx].item()}" - ) + assert ( + query_mask[kv_idx].item() == expected + ), f"Position {abs_pos}: expected {expected}, got {query_mask[kv_idx].item()}" def test_prefill_has_causal_pattern(self): """During prefill, mask has proper causal (lower triangular) pattern.""" - from transformers.masking_utils import sdpa_mask, causal_mask_function + from transformers.masking_utils import causal_mask_function, sdpa_mask cache = _AttentionCache(window=None) cache.update(torch.randn(1, 1, 5, 16), torch.randn(1, 1, 5, 16)) diff --git a/fast_llm_external_models/tests/test_apriel2/test_causal_conv1d.py b/fast_llm_external_models/tests/test_apriel2/test_causal_conv1d.py index ec6abc1d2..0567cd76e 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_causal_conv1d.py +++ b/fast_llm_external_models/tests/test_apriel2/test_causal_conv1d.py @@ -24,7 +24,6 @@ from fast_llm_external_models.apriel2.modeling_apriel2 import CausalConv1d, _causal_conv1d_fn - # ============================================================================= # Fixtures # 
============================================================================= @@ -63,6 +62,7 @@ def kernel_size(): def to_device(conv: CausalConv1d, device: str) -> CausalConv1d: """Create a copy of conv on the specified device.""" import copy + return copy.deepcopy(conv).to(device) @@ -71,7 +71,9 @@ def prefill(conv: CausalConv1d, x: torch.Tensor, state: torch.Tensor = None) -> return conv(x, conv_state=state, return_final_state=True) -def decode_sequence(conv: CausalConv1d, tokens: torch.Tensor, state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: +def decode_sequence( + conv: CausalConv1d, tokens: torch.Tensor, state: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: """Decode multiple tokens one-by-one, return (stacked_outputs, final_state). Args: @@ -223,7 +225,7 @@ def test_chunked_prefill_cpu(self, conv, dim, total_len, chunk_size): outputs = [] state = None for start in range(0, total_len, chunk_size): - chunk = x[:, :, start:start + chunk_size] + chunk = x[:, :, start : start + chunk_size] out, state = prefill(conv, chunk, state) outputs.append(out) @@ -248,7 +250,7 @@ def test_chunked_prefill_cuda(self, conv, dim, total_len, chunk_size): outputs = [] state = None for start in range(0, total_len, chunk_size): - chunk = x[:, :, start:start + chunk_size].cuda() + chunk = x[:, :, start : start + chunk_size].cuda() out, state = prefill(conv_cuda, chunk, state) outputs.append(out) @@ -329,7 +331,7 @@ def test_all_cpu_paths_match(self, conv, dim): outputs = [] state = None for start in range(0, total_len, chunk_size): - chunk = x[:, :, start:start + chunk_size] + chunk = x[:, :, start : start + chunk_size] out, state = prefill(conv, chunk, state) outputs.append(out) path1 = torch.cat(outputs, dim=-1) @@ -374,7 +376,7 @@ def test_all_paths_match_cross_device(self, conv, dim): # CPU chunked outputs, state = [], None for start in range(0, total_len, chunk_size): - out, state = prefill(conv, x[:, :, start:start + chunk_size], state) + out, state = prefill(conv, x[:, :, start : start + chunk_size], state) outputs.append(out) results["cpu_chunked"] = torch.cat(outputs, dim=-1) @@ -393,7 +395,7 @@ def test_all_paths_match_cross_device(self, conv, dim): # CUDA chunked outputs, state = [], None for start in range(0, total_len, chunk_size): - out, state = prefill(conv_cuda, x[:, :, start:start + chunk_size].cuda(), state) + out, state = prefill(conv_cuda, x[:, :, start : start + chunk_size].cuda(), state) outputs.append(out.cpu()) results["cuda_chunked"] = torch.cat(outputs, dim=-1) @@ -431,8 +433,7 @@ def test_all_paths_match_cross_device(self, conv, dim): for name, result in results.items(): tol = tolerances[name] torch.testing.assert_close( - result, reference, atol=tol, rtol=tol, - msg=f"Path '{name}' diverged from reference" + result, reference, atol=tol, rtol=tol, msg=f"Path '{name}' diverged from reference" ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") @@ -468,8 +469,8 @@ def test_long_decode_no_drift(self, conv, dim): # Check no systematic drift (errors shouldn't consistently increase) decode_errors = errors[prefill_len:] - first_half = decode_errors[:len(decode_errors)//2].mean() - second_half = decode_errors[len(decode_errors)//2:].mean() + first_half = decode_errors[: len(decode_errors) // 2].mean() + second_half = decode_errors[len(decode_errors) // 2 :].mean() assert second_half < first_half * 2, "Errors growing over decode steps (drift detected)" diff --git a/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py 
b/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py index b1ee15d54..3413b9d25 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py +++ b/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py @@ -20,7 +20,7 @@ import yaml from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config -from fast_llm_external_models.apriel2.conversion.config import apply_surgery, compose_configs +from fast_llm_external_models.apriel2.conversion.config import compose_configs class TestComposeConfigsLaws: @@ -314,7 +314,13 @@ def test_monoid_action_compatibility(self, source_config, num_surgeries): Parameterized to test with 2 and 3 surgeries. """ surgeries = [ - {"decoder": {"block": {"mixer": {"type": "stochastic", "main_mixer_name": "attention", "mixers": {"attention": {}}}}}}, + { + "decoder": { + "block": { + "mixer": {"type": "stochastic", "main_mixer_name": "attention", "mixers": {"attention": {}}} + } + } + }, {"decoder": {"block": {"mixer": {"mixers": {"sliding_window": {"window_size": 512}}}}}}, {"decoder": {"block": {"mixer": {"mixers": {"gdn": {"type": "gdn"}}}}}}, ][:num_surgeries] diff --git a/fast_llm_external_models/tests/test_apriel2/test_conversion_e2e.py b/fast_llm_external_models/tests/test_apriel2/test_conversion_e2e.py index 09fb9fa13..b91fb7e51 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_conversion_e2e.py +++ b/fast_llm_external_models/tests/test_apriel2/test_conversion_e2e.py @@ -16,21 +16,12 @@ import pytest import torch -from fast_llm_external_models.tests.test_apriel2.conftest import requires_cuda - from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config -from fast_llm_external_models.apriel2.conversion import ( - compose, - compose_configs, - execute, - plan_surgery, -) -from fast_llm_external_models.apriel2.conversion.llava import ( - convert_config as convert_llava_config, - plan_llava_to_apriel2, -) +from fast_llm_external_models.apriel2.conversion import compose, compose_configs, execute, plan_surgery +from fast_llm_external_models.apriel2.conversion.llava import convert_config as convert_llava_config +from fast_llm_external_models.apriel2.conversion.llava import plan_llava_to_apriel2 from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForConditionalGeneration - +from fast_llm_external_models.tests.test_apriel2.conftest import requires_cuda # ============================================================================= # Cycling Surgery Generation @@ -87,40 +78,20 @@ def generate_cycling_surgeries(config: dict) -> list[tuple[dict, str]]: if sub_name != main_mixer: # Build surgery path based on block_path if block_path == "block": - surgery = { - "decoder": { - "block": {"mixer": {"main_mixer_name": sub_name}} - } - } + surgery = {"decoder": {"block": {"mixer": {"main_mixer_name": sub_name}}}} else: # block_path is "blocks.block_name" block_name = block_path.split(".")[1] - surgery = { - "decoder": { - "blocks": { - block_name: {"mixer": {"main_mixer_name": sub_name}} - } - } - } + surgery = {"decoder": {"blocks": {block_name: {"mixer": {"main_mixer_name": sub_name}}}}} surgeries.append((surgery, f"cycle {block_path} to {sub_name}")) # Restore original main_mixer_name if any(sub_name != main_mixer for sub_name in sub_mixer_names): if block_path == "block": - restore = { - "decoder": { - "block": {"mixer": {"main_mixer_name": main_mixer}} - } - } + restore = {"decoder": {"block": {"mixer": {"main_mixer_name": main_mixer}}}} else: block_name = 
block_path.split(".")[1] - restore = { - "decoder": { - "blocks": { - block_name: {"mixer": {"main_mixer_name": main_mixer}} - } - } - } + restore = {"decoder": {"blocks": {block_name: {"mixer": {"main_mixer_name": main_mixer}}}}} surgeries.append((restore, f"restore {block_path} to {main_mixer}")) return surgeries @@ -194,9 +165,7 @@ def source_config(self, llava_pixtral_checkpoint): with open(llava_pixtral_checkpoint / "config.json") as f: return json.load(f) - def test_initial_conversion_produces_working_model( - self, source_config, source_weights - ): + def test_initial_conversion_produces_working_model(self, source_config, source_weights): """Test that Llava → Apriel2 conversion produces a working model.""" # Convert config apriel2_config_dict = convert_llava_config(source_config) @@ -219,9 +188,7 @@ def test_initial_conversion_produces_working_model( assert outputs.logits.shape == (1, 8, config.vocab_size) - def test_each_surgery_step_produces_working_model( - self, source_config, source_weights, additive_surgery_chain - ): + def test_each_surgery_step_produces_working_model(self, source_config, source_weights, additive_surgery_chain): """Test that each surgery step produces a model that can forward pass. Key insight: Surgery plans reference Apriel2 keys, so we must COMPOSE @@ -290,9 +257,7 @@ def test_each_surgery_step_produces_working_model( except Exception as e: pytest.fail(f"Step {i+1}: Forward pass failed - {e}") - def test_all_stochastic_submixers_via_cycling( - self, source_config, source_weights, additive_surgery_chain - ): + def test_all_stochastic_submixers_via_cycling(self, source_config, source_weights, additive_surgery_chain): """Test ALL sub-mixers in stochastic blocks, not just the main mixer. Problem: Forward pass only exercises the main_mixer_name. Other sub-mixers @@ -312,9 +277,7 @@ def test_all_stochastic_submixers_via_cycling( conversion_plan = plan_llava_to_apriel2(source_config) # Expand surgery chain with cycling - expanded_chain = expand_surgery_chain_with_cycling( - additive_surgery_chain, apriel2_config - ) + expanded_chain = expand_surgery_chain_with_cycling(additive_surgery_chain, apriel2_config) # Build cumulative plan: conversion | surgery_1 | cycling_1a | ... | restore_1 | surgery_2 | ... current_plan = conversion_plan @@ -359,9 +322,7 @@ def test_all_stochastic_submixers_via_cycling( except Exception as e: pytest.fail(f"{desc}: Forward pass failed - {e}") - def test_composed_plan_equals_sequential_execution( - self, source_config, source_weights, additive_surgery_chain - ): + def test_composed_plan_equals_sequential_execution(self, source_config, source_weights, additive_surgery_chain): """Test that composing plans gives same result as sequential execution. 
This verifies plan composition associativity: @@ -399,13 +360,9 @@ def test_composed_plan_equals_sequential_execution( # Compare weights for key in seq_weights: if key in composed_weights: - assert torch.allclose( - seq_weights[key], composed_weights[key], atol=1e-5 - ), f"Weight mismatch for {key}" + assert torch.allclose(seq_weights[key], composed_weights[key], atol=1e-5), f"Weight mismatch for {key}" - def test_final_model_structure( - self, source_config, source_weights, additive_surgery_chain - ): + def test_final_model_structure(self, source_config, source_weights, additive_surgery_chain): """Verify the final model has the expected structure.""" # Initial conversion current_config = convert_llava_config(source_config) @@ -504,9 +461,7 @@ def base_setup(self, llava_pixtral_checkpoint): """Set up base config and weights after Llava conversion.""" from safetensors.torch import load_file - from fast_llm_external_models.apriel2.conversion.llava import ( - convert_config as convert_llava_config, - ) + from fast_llm_external_models.apriel2.conversion.llava import convert_config as convert_llava_config # Load source config and weights with open(llava_pixtral_checkpoint / "config.json") as f: @@ -534,9 +489,7 @@ def _merge_surgeries(self, surgeries: list[dict]) -> dict: result = _deep_merge(result, s) return result - def _build_incremental_plans( - self, base_config: dict, surgeries: list[dict] - ) -> tuple[list, list[dict]]: + def _build_incremental_plans(self, base_config: dict, surgeries: list[dict]) -> tuple[list, list[dict]]: """Build incremental plans for each surgery step. Returns (plans, configs) where configs[i] is the config after surgery i. @@ -552,9 +505,7 @@ def _build_incremental_plans( config = target_config return plans, configs - def test_incremental_equals_direct_full_chain( - self, base_setup, additive_surgery_chain - ): + def test_incremental_equals_direct_full_chain(self, base_setup, additive_surgery_chain): """Test that composing all incremental plans equals one direct plan. 
compose(P1, P2, ..., Pn) ≡ plan_surgery(base, final) @@ -575,9 +526,7 @@ def test_incremental_equals_direct_full_chain( direct_plan = plan_surgery(base_config, final_config) # Verify same target keys - assert set(composed_plan.mappings.keys()) == set( - direct_plan.mappings.keys() - ), "Plan keys should match" + assert set(composed_plan.mappings.keys()) == set(direct_plan.mappings.keys()), "Plan keys should match" # Execute both and compare weights composed_weights = execute(composed_plan, base_weights, seed=0) @@ -611,9 +560,7 @@ def test_every_prefix_consistency(self, base_setup, additive_surgery_chain): direct = plan_surgery(base_config, configs[k]) # Verify keys match - assert set(composed.mappings.keys()) == set( - direct.mappings.keys() - ), f"Prefix {k}: keys don't match" + assert set(composed.mappings.keys()) == set(direct.mappings.keys()), f"Prefix {k}: keys don't match" # Execute and compare composed_weights = execute(composed, base_weights, seed=0) @@ -781,9 +728,7 @@ def torture_setup(self, llava_pixtral_checkpoint): """Set up for comprehensive torture tests.""" from safetensors.torch import load_file - from fast_llm_external_models.apriel2.conversion.llava import ( - convert_config as convert_llava_config, - ) + from fast_llm_external_models.apriel2.conversion.llava import convert_config as convert_llava_config # Load source with open(llava_pixtral_checkpoint / "config.json") as f: @@ -801,9 +746,7 @@ def torture_setup(self, llava_pixtral_checkpoint): return base_config, base_weights - def test_each_step_produces_valid_config( - self, torture_setup, comprehensive_torture_chain - ): + def test_each_step_produces_valid_config(self, torture_setup, comprehensive_torture_chain): """Test that each surgery step produces a valid config.""" base_config, _ = torture_setup @@ -818,9 +761,7 @@ def test_each_step_produces_valid_config( pytest.fail(f"Step {i+1} produced invalid config: {e}") @requires_cuda - def test_each_step_produces_working_model( - self, torture_setup, comprehensive_torture_chain - ): + def test_each_step_produces_working_model(self, torture_setup, comprehensive_torture_chain): """Test that each surgery step produces a model that can forward pass. This is the ultimate integration test - config composition + plan building @@ -875,9 +816,7 @@ def test_each_step_produces_working_model( current_weights = new_weights @requires_cuda - def test_final_supernet_structure( - self, torture_setup, comprehensive_torture_chain - ): + def test_final_supernet_structure(self, torture_setup, comprehensive_torture_chain): """Verify the final architecture has supernet blocks with all 4 mixer types.""" base_config, base_weights = torture_setup @@ -914,9 +853,7 @@ def test_final_supernet_structure( assert outputs.logits.shape == (1, 8, config.vocab_size) @requires_cuda - def test_plan_config_consistency_comprehensive( - self, torture_setup, comprehensive_torture_chain - ): + def test_plan_config_consistency_comprehensive(self, torture_setup, comprehensive_torture_chain): """Test that incremental plan composition works for the comprehensive chain. 
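Roughly, the incremental pipeline being exercised (a sketch; names are illustrative, reduce is functools.reduce, and base_config/base_weights stand for the torture_setup fixture values):

    configs = [base_config]
    for surgery in comprehensive_torture_chain:
        configs.append(compose_configs(configs[-1], surgery))
    plans = [plan_surgery(configs[i], configs[i + 1]) for i in range(len(configs) - 1)]
    final_plan = reduce(compose, plans)
    final_weights = execute(final_plan, base_weights, seed=0)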
Note: We cannot compare to a "direct plan" because the comprehensive chain @@ -1106,7 +1043,7 @@ def test_plan_surgery_random_succeeds_for_any_type_pair(self, mamba_config): plan = plan_surgery(mamba_config, surgery) # Verify the plan has the expected target keys - target_keys = set(str(k) for k in plan.mappings.keys()) + target_keys = {str(k) for k in plan.mappings.keys()} assert any("mixer.q_proj" in k for k in target_keys) def test_plan_surgery_transfer_fails_for_unsupported_type_pair(self, mamba_config): @@ -1159,7 +1096,7 @@ def test_plan_surgery_transfer_succeeds_for_supported_type_pair(self, base_confi plan = plan_surgery(base_config, surgery) # Verify the plan has mamba target keys - target_keys = set(str(k) for k in plan.mappings.keys()) + target_keys = {str(k) for k in plan.mappings.keys()} assert any("mixer.in_proj" in k for k in target_keys) def test_stochastic_init_random_succeeds_for_any_submixer_type(self, mamba_config): @@ -1199,7 +1136,7 @@ def test_stochastic_init_random_succeeds_for_any_submixer_type(self, mamba_confi plan = plan_surgery(mamba_config, surgery) # Verify both sub-mixers have target keys - target_keys = set(str(k) for k in plan.mappings.keys()) + target_keys = {str(k) for k in plan.mappings.keys()} assert any("mixers.attention.q_proj" in k for k in target_keys) assert any("mixers.swa.q_proj" in k for k in target_keys) @@ -1234,7 +1171,7 @@ def test_mixed_init_modes_in_stochastic(self, base_config): plan = plan_surgery(base_config, surgery) # Verify both sub-mixers have target keys - target_keys = set(str(k) for k in plan.mappings.keys()) + target_keys = {str(k) for k in plan.mappings.keys()} assert any("mixers.attention.q_proj" in k for k in target_keys) assert any("mixers.gdn.in_proj_qkvz" in k for k in target_keys) @@ -1369,8 +1306,8 @@ def test_different_paths_same_config_same_plan(self, attention_config_dict): plan_from_b = plan_surgery(config_b, final_surgery) # Compare plan mappings - keys_a = set(str(k) for k in plan_from_a.mappings.keys()) - keys_b = set(str(k) for k in plan_from_b.mappings.keys()) + keys_a = {str(k) for k in plan_from_a.mappings.keys()} + keys_b = {str(k) for k in plan_from_b.mappings.keys()} assert keys_a == keys_b, "Plans from same config via different paths should be identical" def test_init_in_source_config_does_not_affect_plan(self, attention_config_dict): @@ -1408,8 +1345,8 @@ def test_init_in_source_config_does_not_affect_plan(self, attention_config_dict) plan_with = plan_surgery(config_with_init, surgery) plan_without = plan_surgery(config_without_init, surgery) - keys_with = set(str(k) for k in plan_with.mappings.keys()) - keys_without = set(str(k) for k in plan_without.mappings.keys()) + keys_with = {str(k) for k in plan_with.mappings.keys()} + keys_without = {str(k) for k in plan_without.mappings.keys()} # Plans should be identical - source's init field is ignored assert keys_with == keys_without, "Plan should not depend on init in source config" @@ -1614,7 +1551,7 @@ def test_expand_surgery_chain_adds_cycling(self): # Verify restore flag assert expanded[0][2] is False # surgery - not restore assert expanded[1][2] is False # cycle - not restore - assert expanded[2][2] is True # restore + assert expanded[2][2] is True # restore def test_expand_surgery_chain_preserves_invariant(self): """Test that cycling leaves the chain state invariant.""" diff --git a/fast_llm_external_models/tests/test_apriel2/test_convert_from_llava.py b/fast_llm_external_models/tests/test_apriel2/test_convert_from_llava.py index a437f920d..f96f5ac40 
100644 --- a/fast_llm_external_models/tests/test_apriel2/test_convert_from_llava.py +++ b/fast_llm_external_models/tests/test_apriel2/test_convert_from_llava.py @@ -14,23 +14,15 @@ """ import json -from pathlib import Path -import pytest import torch from safetensors import safe_open -from safetensors.torch import save_file from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config -from fast_llm_external_models.apriel2.conversion import ( - convert_llava_config as convert_config, - execute, - plan_llava_to_apriel2, - plan_surgery, -) +from fast_llm_external_models.apriel2.conversion import convert_llava_config as convert_config +from fast_llm_external_models.apriel2.conversion import execute, plan_llava_to_apriel2, plan_surgery from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForConditionalGeneration - # ============================================================================= # Config Conversion Tests # ============================================================================= @@ -330,9 +322,9 @@ def test_plan_keys_match_model_state_dict(self, llava_pixtral_checkpoint): extra_in_plan = plan_keys - model_keys # Filter out expected missing keys (caches, positions, etc.) - missing_in_plan = {k for k in missing_in_plan if not any( - skip in k.lower() for skip in ["cache", "position", "mask"] - )} + missing_in_plan = { + k for k in missing_in_plan if not any(skip in k.lower() for skip in ["cache", "position", "mask"]) + } assert not missing_in_plan, f"Model keys not in plan: {sorted(missing_in_plan)[:10]}" assert not extra_in_plan, f"Plan keys not in model: {sorted(extra_in_plan)[:10]}" diff --git a/fast_llm_external_models/tests/test_apriel2/test_equivalence.py b/fast_llm_external_models/tests/test_apriel2/test_equivalence.py index c59ed2000..9b3eb4efe 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_equivalence.py +++ b/fast_llm_external_models/tests/test_apriel2/test_equivalence.py @@ -23,9 +23,6 @@ import torch from transformers import LlavaForConditionalGeneration -from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForConditionalGeneration - - # ============================================================================= # Input Configuration # ============================================================================= @@ -487,8 +484,10 @@ def test_batch_processing_behavior(self, model_pair): batch_tgt = target.get_image_features(pixel_values).view(-1, batch_src.shape[-1]) # Sequential processing - singles_src = [get_pixtral_vision_features(source, pixel_values[i:i+1]) for i in range(3)] - singles_tgt = [target.get_image_features(pixel_values[i:i+1]).view(-1, batch_src.shape[-1]) for i in range(3)] + singles_src = [get_pixtral_vision_features(source, pixel_values[i : i + 1]) for i in range(3)] + singles_tgt = [ + target.get_image_features(pixel_values[i : i + 1]).view(-1, batch_src.shape[-1]) for i in range(3) + ] single_concat_src = torch.cat(singles_src, dim=0) single_concat_tgt = torch.cat(singles_tgt, dim=0) @@ -500,9 +499,9 @@ def test_batch_processing_behavior(self, model_pair): print(f"Apriel2 batch vs sequential: {tgt_diff:.6f}") # Both should have the same behavior (within FP tolerance) - assert abs(src_diff - tgt_diff) < 1e-6, ( - f"Batch processing behavior differs: src={src_diff:.6f}, tgt={tgt_diff:.6f}" - ) + assert ( + abs(src_diff - tgt_diff) < 1e-6 + ), f"Batch processing behavior differs: src={src_diff:.6f}, tgt={tgt_diff:.6f}" if __name__ == "__main__": diff --git 
a/fast_llm_external_models/tests/test_apriel2/test_expr_plan.py b/fast_llm_external_models/tests/test_apriel2/test_expr_plan.py index 569ed88fd..2dccac5ad 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_expr_plan.py +++ b/fast_llm_external_models/tests/test_apriel2/test_expr_plan.py @@ -1,15 +1,13 @@ """Tests for the expression-based plan system.""" import json + import pytest import torch -from fast_llm_external_models.tests.test_apriel2.conftest import requires_cuda - from fast_llm_external_models.apriel2.conversion import ( Concat, EvalKwargs, - Expr, ExprAdapter, ExprPlan, Init, @@ -18,10 +16,9 @@ Slice, StreamingExecutor, W, - compose, execute, - fuse, full_slice, + fuse, make_slice, plan_dil_attention_to_gdn, plan_kil_attention_to_kda, @@ -31,6 +28,7 @@ slice_spec, substitute, ) +from fast_llm_external_models.tests.test_apriel2.conftest import requires_cuda def make_eval_kwargs( @@ -219,10 +217,13 @@ def test_substitute_init_unchanged(self): def test_substitute_complex(self): """Substitute handles complex nested expressions.""" # Concat of Slice(Ref) and Init - expr = Concat(exprs=( - Slice(expr=Ref(key=W("a")), slices=((0, 5, None),)), - Init(shape=(5,), init_type="zeros"), - ), dim=0) + expr = Concat( + exprs=( + Slice(expr=Ref(key=W("a")), slices=((0, 5, None),)), + Init(shape=(5,), init_type="zeros"), + ), + dim=0, + ) bindings = {W("a"): Ref(key=W("source"))} result = substitute(expr, bindings) @@ -238,7 +239,13 @@ class TestFuse: def test_fuse_flatten_concat(self): """Fuse flattens nested Concat with same dim.""" inner = Concat(exprs=(Ref(key=W("a")), Ref(key=W("b"))), dim=0) - outer = Concat(exprs=(inner, Ref(key=W("c")),), dim=0) + outer = Concat( + exprs=( + inner, + Ref(key=W("c")), + ), + dim=0, + ) result = fuse(outer) assert isinstance(result, Concat) @@ -250,7 +257,13 @@ def test_fuse_flatten_concat(self): def test_fuse_no_flatten_different_dim(self): """Fuse doesn't flatten Concat with different dim.""" inner = Concat(exprs=(Ref(key=W("a")), Ref(key=W("b"))), dim=1) - outer = Concat(exprs=(inner, Ref(key=W("c")),), dim=0) + outer = Concat( + exprs=( + inner, + Ref(key=W("c")), + ), + dim=0, + ) result = fuse(outer) assert isinstance(result, Concat) @@ -340,28 +353,34 @@ class TestExprPlan: def test_plan_define_and_access(self): """Plan stores and retrieves expressions.""" - plan = ExprPlan(mappings={ - W("target"): Ref(key=W("source")), - }) + plan = ExprPlan( + mappings={ + W("target"): Ref(key=W("source")), + } + ) assert W("target") in plan assert isinstance(plan[W("target")], Ref) def test_plan_source_keys(self): """Plan identifies all source references.""" - plan = ExprPlan(mappings={ - W("a"): Ref(key=W("x")), - W("b"): Concat(exprs=(Ref(key=W("y")), Ref(key=W("z"))), dim=0), - W("c"): Init(shape=(10,), init_type="zeros"), - }) + plan = ExprPlan( + mappings={ + W("a"): Ref(key=W("x")), + W("b"): Concat(exprs=(Ref(key=W("y")), Ref(key=W("z"))), dim=0), + W("c"): Init(shape=(10,), init_type="zeros"), + } + ) assert plan.source_keys() == {W("x"), W("y"), W("z")} def test_plan_target_keys(self): """Plan identifies all target keys.""" - plan = ExprPlan(mappings={ - W("a"): Ref(key=W("x")), - W("b"): Ref(key=W("y")), - }) + plan = ExprPlan( + mappings={ + W("a"): Ref(key=W("x")), + W("b"): Ref(key=W("y")), + } + ) assert plan.target_keys() == {W("a"), W("b")} @@ -386,9 +405,17 @@ def test_plan_summary(self): def test_plan_fuse(self): """Plan fuse applies optimizations.""" inner = Concat(exprs=(Ref(key=W("a")), Ref(key=W("b"))), dim=0) - plan = 
ExprPlan(mappings={ - W("out"): Concat(exprs=(inner, Ref(key=W("c")),), dim=0), - }) + plan = ExprPlan( + mappings={ + W("out"): Concat( + exprs=( + inner, + Ref(key=W("c")), + ), + dim=0, + ), + } + ) fused = plan.fuse() assert isinstance(fused[W("out")], Concat) @@ -532,9 +559,11 @@ class TestStreamingExecution: def test_execute_simple(self): """Execute simple plan.""" - plan = ExprPlan(mappings={ - W("out"): Ref(key=W("in")), - }) + plan = ExprPlan( + mappings={ + W("out"): Ref(key=W("in")), + } + ) sources = {W("in"): torch.tensor([1.0, 2.0, 3.0])} result = execute(plan, sources, seed=42) @@ -544,9 +573,11 @@ def test_execute_simple(self): def test_execute_concat(self): """Execute plan with Concat.""" - plan = ExprPlan(mappings={ - W("combined"): Concat(exprs=(Ref(key=W("a")), Ref(key=W("b"))), dim=0), - }) + plan = ExprPlan( + mappings={ + W("combined"): Concat(exprs=(Ref(key=W("a")), Ref(key=W("b"))), dim=0), + } + ) sources = { W("a"): torch.ones(2, 3), @@ -559,14 +590,19 @@ def test_execute_concat(self): def test_execute_mil_like(self): """Execute MIL-like Concat of Slices and Init.""" # Simulated MIL: in_proj = [z, x, B, C] - plan = ExprPlan(mappings={ - W("in_proj"): Concat(exprs=( - Init(shape=(4, 8), init_type="zeros"), # z - Slice(expr=Ref(key=W("v")), slices=((0, 2, None), (None, None, None))), # x - Slice(expr=Ref(key=W("k")), slices=((0, 2, None), (None, None, None))), # B - Slice(expr=Ref(key=W("q")), slices=((0, 4, None), (None, None, None))), # C - ), dim=0), - }) + plan = ExprPlan( + mappings={ + W("in_proj"): Concat( + exprs=( + Init(shape=(4, 8), init_type="zeros"), # z + Slice(expr=Ref(key=W("v")), slices=((0, 2, None), (None, None, None))), # x + Slice(expr=Ref(key=W("k")), slices=((0, 2, None), (None, None, None))), # B + Slice(expr=Ref(key=W("q")), slices=((0, 4, None), (None, None, None))), # C + ), + dim=0, + ), + } + ) sources = { W("q"): torch.ones(4, 8), @@ -583,11 +619,13 @@ def test_execute_mil_like(self): def test_streaming_execution(self): """Streaming executor processes all targets.""" - plan = ExprPlan(mappings={ - W("out1"): Ref(key=W("shared")), - W("out2"): Ref(key=W("shared")), - W("out3"): Ref(key=W("unique")), - }) + plan = ExprPlan( + mappings={ + W("out1"): Ref(key=W("shared")), + W("out2"): Ref(key=W("shared")), + W("out3"): Ref(key=W("unique")), + } + ) load_calls = [] @@ -858,25 +896,23 @@ def test_plan_dil_execution(self): key_dim = 64 value_dim = 64 - head_k_dim = 16 - head_v_dim = 16 conv_dim = 2 * key_dim + value_dim # 192 # Create attention weights with per-head distinctive values # Q: each head gets value (head_idx + 1) q_weight = torch.zeros(64, 64) for h in range(4): - q_weight[h*16:(h+1)*16, :] = float(h + 1) + q_weight[h * 16 : (h + 1) * 16, :] = float(h + 1) # K: each head gets value (head_idx + 1) * 10 k_weight = torch.zeros(64, 64) for h in range(4): - k_weight[h*16:(h+1)*16, :] = float((h + 1) * 10) + k_weight[h * 16 : (h + 1) * 16, :] = float((h + 1) * 10) # V: each head gets value (head_idx + 1) * 100 v_weight = torch.zeros(64, 64) for h in range(4): - v_weight[h*16:(h+1)*16, :] = float((h + 1) * 100) + v_weight[h * 16 : (h + 1) * 16, :] = float((h + 1) * 100) sources = { W("attn.q_proj.weight"): q_weight, @@ -894,30 +930,23 @@ def test_plan_dil_execution(self): # Q_all (rows 0-63): heads 0,1,2,3 concatenated for h in range(4): - assert torch.allclose( - in_proj_qkvz[h*16:(h+1)*16], - torch.full((16, 64), float(h + 1)) - ) + assert torch.allclose(in_proj_qkvz[h * 16 : (h + 1) * 16], torch.full((16, 64), float(h + 1))) # K_all 
(rows 64-127): heads 0,1,2,3 concatenated for h in range(4): assert torch.allclose( - in_proj_qkvz[key_dim + h*16:key_dim + (h+1)*16], - torch.full((16, 64), float((h + 1) * 10)) + in_proj_qkvz[key_dim + h * 16 : key_dim + (h + 1) * 16], torch.full((16, 64), float((h + 1) * 10)) ) # V_all (rows 128-191): heads 0,1,2,3 concatenated for h in range(4): assert torch.allclose( - in_proj_qkvz[2*key_dim + h*16:2*key_dim + (h+1)*16], - torch.full((16, 64), float((h + 1) * 100)) + in_proj_qkvz[2 * key_dim + h * 16 : 2 * key_dim + (h + 1) * 16], + torch.full((16, 64), float((h + 1) * 100)), ) # Z_all (rows 192-255): zeros - assert torch.allclose( - in_proj_qkvz[2*key_dim + value_dim:], - torch.zeros(value_dim, 64) - ) + assert torch.allclose(in_proj_qkvz[2 * key_dim + value_dim :], torch.zeros(value_dim, 64)) # in_proj_ba should be zeros in_proj_ba = result[W("in_proj_ba.weight")] @@ -971,17 +1000,17 @@ def test_plan_dil_execution_gqa(self): # Q: 4 heads, each with value (head_idx + 1) q_weight = torch.zeros(64, 64) for h in range(4): - q_weight[h*16:(h+1)*16, :] = float(h + 1) + q_weight[h * 16 : (h + 1) * 16, :] = float(h + 1) # K: 2 kv_heads, each with value (head_idx + 1) * 10 k_weight = torch.zeros(32, 64) for h in range(2): - k_weight[h*16:(h+1)*16, :] = float((h + 1) * 10) + k_weight[h * 16 : (h + 1) * 16, :] = float((h + 1) * 10) # V: 2 kv_heads, each with value (head_idx + 1) * 100 v_weight = torch.zeros(32, 64) for h in range(2): - v_weight[h*16:(h+1)*16, :] = float((h + 1) * 100) + v_weight[h * 16 : (h + 1) * 16, :] = float((h + 1) * 100) sources = { W("attn.q_proj.weight"): q_weight, @@ -1007,22 +1036,22 @@ def test_plan_dil_execution_gqa(self): # K_all (rows 32-63): k_heads 0,1 (maps to source K heads 0,1 via modulo) # k_head 0 → source K head 0 (value 10) - assert torch.allclose(in_proj_qkvz[key_dim:key_dim+16], torch.full((16, 64), 10.0)) + assert torch.allclose(in_proj_qkvz[key_dim : key_dim + 16], torch.full((16, 64), 10.0)) # k_head 1 → source K head 1 (value 20) - assert torch.allclose(in_proj_qkvz[key_dim+16:key_dim+32], torch.full((16, 64), 20.0)) + assert torch.allclose(in_proj_qkvz[key_dim + 16 : key_dim + 32], torch.full((16, 64), 20.0)) # V_all (rows 64-127): 4 v_heads, tiled from 2 source KV heads via modulo # v_head 0 → src_v_head 0 (value 100) - assert torch.allclose(in_proj_qkvz[2*key_dim:2*key_dim+16], torch.full((16, 64), 100.0)) + assert torch.allclose(in_proj_qkvz[2 * key_dim : 2 * key_dim + 16], torch.full((16, 64), 100.0)) # v_head 1 → src_v_head 1 (value 200) - assert torch.allclose(in_proj_qkvz[2*key_dim+16:2*key_dim+32], torch.full((16, 64), 200.0)) + assert torch.allclose(in_proj_qkvz[2 * key_dim + 16 : 2 * key_dim + 32], torch.full((16, 64), 200.0)) # v_head 2 → src_v_head 0 (value 100, tiled) - assert torch.allclose(in_proj_qkvz[2*key_dim+32:2*key_dim+48], torch.full((16, 64), 100.0)) + assert torch.allclose(in_proj_qkvz[2 * key_dim + 32 : 2 * key_dim + 48], torch.full((16, 64), 100.0)) # v_head 3 → src_v_head 1 (value 200, tiled) - assert torch.allclose(in_proj_qkvz[2*key_dim+48:2*key_dim+64], torch.full((16, 64), 200.0)) + assert torch.allclose(in_proj_qkvz[2 * key_dim + 48 : 2 * key_dim + 64], torch.full((16, 64), 200.0)) # Z_all (rows 128-191): zeros - assert torch.allclose(in_proj_qkvz[2*key_dim+value_dim:], torch.zeros(value_dim, 64)) + assert torch.allclose(in_proj_qkvz[2 * key_dim + value_dim :], torch.zeros(value_dim, 64)) def test_plan_kil_attention_to_kda(self): """AIK plan produces correct structure for attention → KDA conversion.""" @@ -1188,6 
+1217,7 @@ def test_compose_llava_to_mamba(self, llava_pixtral_config, apriel2_config_stoch # Build surgery plan (need intermediate config) from fast_llm_external_models.apriel2.conversion.llava import convert_config + intermediate_config = convert_config(llava_pixtral_config) target_config = apriel2_config_stochastic.to_dict() surgery_plan = plan_surgery(intermediate_config, target_config) @@ -1210,6 +1240,7 @@ def test_execute_composed_pipeline(self, llava_pixtral_checkpoint): """ import json from pathlib import Path + from safetensors.torch import load_file # Load config @@ -1448,10 +1479,9 @@ def test_comprehensive_conversion_all_mixer_types(self, llava_pixtral_checkpoint the conversion produced correct keys and shapes. """ import json - from pathlib import Path from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config - from fast_llm_external_models.apriel2.convert import build_plan, convert + from fast_llm_external_models.apriel2.convert import convert from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForConditionalGeneration # Load LLaVA config @@ -1477,11 +1507,11 @@ def test_comprehensive_conversion_all_mixer_types(self, llava_pixtral_checkpoint "type": "pattern", "num_blocks": 5, "pattern": [ - "attn", # 0: attention → attention (passthrough) - "mamba", # 1: attention → mamba (MIL) - "gdn", # 2: attention → gated_delta_net (DIL) - "stoch_am", # 3: attention → stochastic(attention + mamba) - "stoch_sg", # 4: attention → stochastic(swa + gdn) + "attn", # 0: attention → attention (passthrough) + "mamba", # 1: attention → mamba (MIL) + "gdn", # 2: attention → gated_delta_net (DIL) + "stoch_am", # 3: attention → stochastic(attention + mamba) + "stoch_sg", # 4: attention → stochastic(swa + gdn) ], "blocks": { # Pure attention (passthrough from source) @@ -1609,7 +1639,8 @@ def test_comprehensive_conversion_all_mixer_types(self, llava_pixtral_checkpoint "type": "attention", "heads": llava_config["vision_config"]["num_attention_heads"], "head_groups": llava_config["vision_config"]["num_attention_heads"], - "head_size": llava_config["vision_config"]["hidden_size"] // llava_config["vision_config"]["num_attention_heads"], + "head_size": llava_config["vision_config"]["hidden_size"] + // llava_config["vision_config"]["num_attention_heads"], "add_linear_biases": False, "causal": False, "rotary": { @@ -1688,7 +1719,6 @@ def test_conversion_plan_targets_match_model_state_dict(self, llava_pixtral_conf This test validates the plan WITHOUT executing it, by comparing plan target keys against what the model expects. 
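In essence, the check is a set comparison over keys (mirroring the body below):

    expected = set(model.state_dict().keys())
    planned = {str(k) for k in plan.target_keys()}
    # Both differences should be empty for a structurally complete plan.
    assert not expected - planned and not planned - expected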
""" - import json from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config from fast_llm_external_models.apriel2.convert import build_plan @@ -1703,7 +1733,7 @@ def test_conversion_plan_targets_match_model_state_dict(self, llava_pixtral_conf expected_keys = set(model.state_dict().keys()) # Get plan target keys - plan_target_keys = set(str(k) for k in plan.target_keys()) + plan_target_keys = {str(k) for k in plan.target_keys()} # Compare missing_from_plan = expected_keys - plan_target_keys @@ -1763,20 +1793,23 @@ def test_plan_includes_enabled_attention_biases(self, source_config_with_bias): from fast_llm_external_models.apriel2.conversion.config import compose_configs from fast_llm_external_models.apriel2.conversion.converters import plan_surgery - target_config = compose_configs(source_config_with_bias, { - "decoder": { - "block": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": { - "attention": {"init": "transfer"}, + target_config = compose_configs( + source_config_with_bias, + { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + }, }, + "mlp": {"init": "transfer"}, }, - "mlp": {"init": "transfer"}, }, }, - }) + ) plan = plan_surgery(source_config_with_bias, target_config) mapping_strs = [str(k) for k in plan.mappings.keys()] @@ -1795,20 +1828,23 @@ def test_plan_excludes_disabled_attention_biases(self, source_config_with_bias): from fast_llm_external_models.apriel2.conversion.config import compose_configs from fast_llm_external_models.apriel2.conversion.converters import plan_surgery - target_config = compose_configs(source_config_with_bias, { - "decoder": { - "block": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": { - "attention": {"init": "transfer"}, + target_config = compose_configs( + source_config_with_bias, + { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + }, }, + "mlp": {"init": "transfer"}, }, - "mlp": {"init": "transfer"}, }, }, - }) + ) plan = plan_surgery(source_config_with_bias, target_config) mapping_strs = [str(k) for k in plan.mappings.keys()] @@ -1822,20 +1858,23 @@ def test_plan_includes_enabled_mlp_biases(self, source_config_with_bias): from fast_llm_external_models.apriel2.conversion.config import compose_configs from fast_llm_external_models.apriel2.conversion.converters import plan_surgery - target_config = compose_configs(source_config_with_bias, { - "decoder": { - "block": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": { - "attention": {"init": "transfer"}, + target_config = compose_configs( + source_config_with_bias, + { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + }, }, + "mlp": {"init": "transfer"}, }, - "mlp": {"init": "transfer"}, }, }, - }) + ) plan = plan_surgery(source_config_with_bias, target_config) mapping_strs = [str(k) for k in plan.mappings.keys()] @@ -1849,20 +1888,23 @@ def test_plan_excludes_disabled_mlp_biases(self, source_config_with_bias): from fast_llm_external_models.apriel2.conversion.config import compose_configs from fast_llm_external_models.apriel2.conversion.converters import plan_surgery - target_config = compose_configs(source_config_with_bias, { - "decoder": { - "block": { - "mixer": { - 
"type": "stochastic", - "main_mixer_name": "attention", - "mixers": { - "attention": {"init": "transfer"}, + target_config = compose_configs( + source_config_with_bias, + { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + }, }, + "mlp": {"init": "transfer"}, }, - "mlp": {"init": "transfer"}, }, }, - }) + ) plan = plan_surgery(source_config_with_bias, target_config) mapping_strs = [str(k) for k in plan.mappings.keys()] @@ -1903,10 +1945,7 @@ def test_plan_random_init_creates_init_expressions_for_bias(self, source_config_ plan = plan_surgery(source_config_with_bias, surgery) # Check that new_attention biases use Init expressions - new_mixer_bias_keys = [ - k for k in plan.mappings.keys() - if "new_attention" in str(k) and "bias" in str(k) - ] + new_mixer_bias_keys = [k for k in plan.mappings.keys() if "new_attention" in str(k) and "bias" in str(k)] assert len(new_mixer_bias_keys) > 0, "Should have bias mappings for new_attention" diff --git a/fast_llm_external_models/tests/test_apriel2/test_integration.py b/fast_llm_external_models/tests/test_apriel2/test_integration.py index b90f0774e..e84fa06ef 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_integration.py +++ b/fast_llm_external_models/tests/test_apriel2/test_integration.py @@ -20,20 +20,14 @@ import torch from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config -from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM -from fast_llm_external_models.apriel2.conversion import ( - compose, - compose_configs, - execute, - plan_surgery, -) +from fast_llm_external_models.apriel2.conversion import compose, compose_configs, execute, plan_surgery from fast_llm_external_models.apriel2.conversion.expr import W from fast_llm_external_models.apriel2.conversion.qwen2.config import convert_config as convert_qwen2_config from fast_llm_external_models.apriel2.conversion.qwen2.plan import plan_qwen2_to_apriel2 +from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM from .conftest import requires_fastllm - # ============================================================================= # Test Input Variations # ============================================================================= @@ -56,13 +50,11 @@ @pytest.fixture(scope="module") def qwen2_source(): """Load Qwen2.5-0.5B as the source/reference model.""" - from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig + from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer model_name = "Qwen/Qwen2.5-0.5B" tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) - model = AutoModelForCausalLM.from_pretrained( - model_name, torch_dtype=torch.float32, trust_remote_code=True - ) + model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32, trust_remote_code=True) config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) model.eval() @@ -139,11 +131,7 @@ def roundtrip_converted(supernet_converted, qwen2_source): if not torch.cuda.is_available(): pytest.skip("Roundtrip conversion requires CUDA (integration tests need realistic hardware)") - from fast_llm.engine.checkpoint.config import ( - CheckpointLoadConfig, - CheckpointSaveConfig, - FastLLMCheckpointFormat, - ) + from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, CheckpointSaveConfig, FastLLMCheckpointFormat from fast_llm.engine.checkpoint.convert import 
ConvertConfig from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.models.gpt.conversion.config import Apriel2TextCheckpointFormat @@ -302,9 +290,9 @@ def test_logits_match(self, qwen2_source, converted_model, prompts, max_new_toke ).logits.cpu() max_diff = (ref_logits - test_logits).abs().max().item() - assert torch.allclose(ref_logits, test_logits, rtol=1e-4, atol=1e-4), ( - f"{stage} logits mismatch: max diff = {max_diff:.6f}" - ) + assert torch.allclose( + ref_logits, test_logits, rtol=1e-4, atol=1e-4 + ), f"{stage} logits mismatch: max diff = {max_diff:.6f}" @TEST_INPUTS def test_generation_match(self, qwen2_source, converted_model, prompts, max_new_tokens): diff --git a/fast_llm_external_models/tests/test_apriel2/test_mixer_equivalence.py b/fast_llm_external_models/tests/test_apriel2/test_mixer_equivalence.py index 1aa8a56d9..c6f3337e8 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_mixer_equivalence.py +++ b/fast_llm_external_models/tests/test_apriel2/test_mixer_equivalence.py @@ -28,15 +28,7 @@ import torch import torch.nn as nn -from fast_llm_external_models.apriel2.conversion import ( - Concat, - ExprPlan, - Ref, - Slice, - W, - execute, -) - +from fast_llm_external_models.apriel2.conversion import Concat, ExprPlan, Ref, Slice, W, execute # ============================================================================= # Shared Fixtures @@ -69,10 +61,10 @@ def hidden_size(request): @pytest.fixture( params=[ - pytest.param((8, 8, 32), id="mha-8h-32d"), # MHA: 8 heads, 8 kv heads, 32 head_dim - pytest.param((8, 4, 32), id="gqa-8h4kv-32d"), # GQA: 8 heads, 4 kv heads, 32 head_dim - pytest.param((8, 2, 64), id="gqa-8h2kv-64d"), # GQA: 8 heads, 2 kv heads, 64 head_dim - pytest.param((4, 1, 64), id="mqa-4h1kv-64d"), # MQA: 4 heads, 1 kv head, 64 head_dim + pytest.param((8, 8, 32), id="mha-8h-32d"), # MHA: 8 heads, 8 kv heads, 32 head_dim + pytest.param((8, 4, 32), id="gqa-8h4kv-32d"), # GQA: 8 heads, 4 kv heads, 32 head_dim + pytest.param((8, 2, 64), id="gqa-8h2kv-64d"), # GQA: 8 heads, 2 kv heads, 64 head_dim + pytest.param((4, 1, 64), id="mqa-4h1kv-64d"), # MQA: 4 heads, 1 kv head, 64 head_dim ] ) def attention_config(request): @@ -90,7 +82,7 @@ def attention_config(request): params=[ pytest.param((8, 4, 32, 32), id="8v-4k-32d"), # 8 value heads, 4 key heads, symmetric dims pytest.param((8, 2, 64, 64), id="8v-2k-64d"), # 8 value heads, 2 key heads, larger dims - pytest.param((4, 2, 32, 64), id="4v-2k-asym"), # Asymmetric key/value dims + pytest.param((4, 2, 32, 64), id="4v-2k-asym"), # Asymmetric key/value dims ] ) def gdn_config(request): @@ -100,9 +92,9 @@ def gdn_config(request): @pytest.fixture( params=[ - pytest.param((4, 8), id="4h-8d"), # 4 heads, 8 head_dim (small) - pytest.param((8, 16), id="8h-16d"), # 8 heads, 16 head_dim (medium) - pytest.param((4, 32), id="4h-32d"), # 4 heads, 32 head_dim (large head_dim) + pytest.param((4, 8), id="4h-8d"), # 4 heads, 8 head_dim (small) + pytest.param((8, 16), id="8h-16d"), # 8 heads, 16 head_dim (medium) + pytest.param((4, 32), id="4h-32d"), # 4 heads, 32 head_dim (large head_dim) ] ) def kda_config(request): @@ -283,9 +275,21 @@ def plan_qwen3next_gdn_to_apriel2( for g in range(num_k_heads): base = g * group_size q_slices.append(Slice(expr=qkvz_ref, slices=((base, base + head_k_dim, None), (None, None, None)))) - k_slices.append(Slice(expr=qkvz_ref, slices=((base + head_k_dim, base + 2 * head_k_dim, None), (None, None, None)))) - v_slices.append(Slice(expr=qkvz_ref, slices=((base + 2 * head_k_dim, base + 2 * 
head_k_dim + v_per_group, None), (None, None, None)))) - z_slices.append(Slice(expr=qkvz_ref, slices=((base + 2 * head_k_dim + v_per_group, base + group_size, None), (None, None, None)))) + k_slices.append( + Slice(expr=qkvz_ref, slices=((base + head_k_dim, base + 2 * head_k_dim, None), (None, None, None))) + ) + v_slices.append( + Slice( + expr=qkvz_ref, + slices=((base + 2 * head_k_dim, base + 2 * head_k_dim + v_per_group, None), (None, None, None)), + ) + ) + z_slices.append( + Slice( + expr=qkvz_ref, + slices=((base + 2 * head_k_dim + v_per_group, base + group_size, None), (None, None, None)), + ) + ) in_proj_qkvz_expr = Concat( exprs=( @@ -304,8 +308,15 @@ def plan_qwen3next_gdn_to_apriel2( b_slices, a_slices = [], [] for g in range(num_k_heads): base = g * ba_per_group - b_slices.append(Slice(expr=ba_ref, slices=((base, base + num_v_heads // num_k_heads, None), (None, None, None)))) - a_slices.append(Slice(expr=ba_ref, slices=((base + num_v_heads // num_k_heads, base + ba_per_group, None), (None, None, None)))) + b_slices.append( + Slice(expr=ba_ref, slices=((base, base + num_v_heads // num_k_heads, None), (None, None, None))) + ) + a_slices.append( + Slice( + expr=ba_ref, + slices=((base + num_v_heads // num_k_heads, base + ba_per_group, None), (None, None, None)), + ) + ) in_proj_ba_expr = Concat( exprs=(Concat(exprs=tuple(b_slices), dim=0), Concat(exprs=tuple(a_slices), dim=0)), @@ -565,6 +576,7 @@ def test_causal_vs_mistral( ): """Verify Apriel2Attention (causal) matches MistralAttention output.""" from transformers.models.mistral.modeling_mistral import MistralAttention, MistralRotaryEmbedding + from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2Attention mixer_config = apriel2_config.decoder["block"]["mixer"] @@ -593,13 +605,20 @@ def test_causal_vs_mistral( apriel2_attn.eval() with torch.no_grad(): - mistral_out = mistral_attn(hidden_states, position_embeddings=position_embeddings, attention_mask=causal_mask)[0] - apriel2_out = apriel2_attn(hidden_states, attention_mask=causal_mask, position_embeddings=position_embeddings)[0] + mistral_out = mistral_attn( + hidden_states, position_embeddings=position_embeddings, attention_mask=causal_mask + )[0] + apriel2_out = apriel2_attn( + hidden_states, attention_mask=causal_mask, position_embeddings=position_embeddings + )[0] rtol, atol = tolerance assert_close( - apriel2_out, mistral_out, rtol=rtol, atol=atol, - msg=f"Apriel2Attention vs MistralAttention (batch={batch_size}, seq={seq_len}, hidden={hidden_size})" + apriel2_out, + mistral_out, + rtol=rtol, + atol=atol, + msg=f"Apriel2Attention vs MistralAttention (batch={batch_size}, seq={seq_len}, hidden={hidden_size})", ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA") @@ -613,8 +632,9 @@ def test_noncausal_vs_pixtral( tolerance, ): """Verify Apriel2Attention (non-causal) matches PixtralAttention output.""" - from transformers.models.pixtral.modeling_pixtral import PixtralAttention, PixtralRotaryEmbedding from transformers.models.pixtral.configuration_pixtral import PixtralVisionConfig + from transformers.models.pixtral.modeling_pixtral import PixtralAttention, PixtralRotaryEmbedding + from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2TextConfig from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2Attention @@ -689,8 +709,11 @@ def test_noncausal_vs_pixtral( rtol, atol = tolerance assert_close( - apriel2_out, pixtral_out, rtol=rtol, atol=atol, - msg=f"Apriel2Attention (non-causal) vs PixtralAttention 
(batch={batch_size}, seq={seq_len})" + apriel2_out, + pixtral_out, + rtol=rtol, + atol=atol, + msg=f"Apriel2Attention (non-causal) vs PixtralAttention (batch={batch_size}, seq={seq_len})", ) @@ -737,6 +760,7 @@ def test_vs_qwen3next( ): """Verify Apriel2GatedDeltaNet matches Qwen3NextGatedDeltaNet output.""" from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextGatedDeltaNet + from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2GatedDeltaNet value_heads, key_heads, key_head_dim, value_head_dim = gdn_config @@ -758,8 +782,10 @@ def test_vs_qwen3next( # Transfer weights plan = plan_qwen3next_gdn_to_apriel2( - num_k_heads=key_heads, num_v_heads=value_heads, - head_k_dim=key_head_dim, head_v_dim=value_head_dim, + num_k_heads=key_heads, + num_v_heads=value_heads, + head_k_dim=key_head_dim, + head_v_dim=value_head_dim, ) source_weights = extract_module_weights(qwen_gdn) target_weights = execute(plan, source_weights, seed=seed) @@ -778,8 +804,11 @@ def test_vs_qwen3next( rtol, atol = tolerance assert_close( - apriel2_out, qwen_out, rtol=rtol, atol=atol, - msg=f"Apriel2GatedDeltaNet vs Qwen3NextGatedDeltaNet (batch={batch_size}, seq={seq_len})" + apriel2_out, + qwen_out, + rtol=rtol, + atol=atol, + msg=f"Apriel2GatedDeltaNet vs Qwen3NextGatedDeltaNet (batch={batch_size}, seq={seq_len})", ) @@ -803,6 +832,7 @@ def test_vs_fla( ): """Verify Apriel2 KimiDeltaAttention matches FLA KimiDeltaAttention output.""" from fla.layers.kda import KimiDeltaAttention as FLA_KDA + from fast_llm_external_models.apriel2.modeling_apriel2 import KimiDeltaAttention as Apriel2_KDA num_heads, head_dim = kda_config @@ -853,8 +883,11 @@ def test_vs_fla( rtol, atol = tolerance assert_close( - apriel2_out, fla_out, rtol=rtol, atol=atol, - msg=f"Apriel2 KDA vs FLA KDA (batch={batch_size}, seq={seq_len}, hidden={hidden_size})" + apriel2_out, + fla_out, + rtol=rtol, + atol=atol, + msg=f"Apriel2 KDA vs FLA KDA (batch={batch_size}, seq={seq_len}, hidden={hidden_size})", ) @@ -913,7 +946,4 @@ def test_gdn_fast_vs_slow(self, gdn_config, batch_size): slow_out = model(hidden_states)[0].clone() # Looser tolerance for kernel vs reference comparison - assert_close( - fast_out, slow_out, rtol=1e-3, atol=1e-3, - msg="GDN fast path (CUDA) vs slow path (PyTorch)" - ) + assert_close(fast_out, slow_out, rtol=1e-3, atol=1e-3, msg="GDN fast path (CUDA) vs slow path (PyTorch)") diff --git a/fast_llm_external_models/tests/test_apriel2/test_model_structure.py b/fast_llm_external_models/tests/test_apriel2/test_model_structure.py index 23856be30..56d2bc6a6 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_model_structure.py +++ b/fast_llm_external_models/tests/test_apriel2/test_model_structure.py @@ -1,9 +1,9 @@ """Tests for Apriel2 model structure and architecture validation.""" -import pytest import torch -from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM + from fast_llm_external_models.apriel2.cache import Apriel2Cache, _AttentionCache, _SSMCache +from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM class TestStochasticMixerStructure: @@ -14,20 +14,27 @@ def test_all_submixers_present(self, apriel2_config_all_mixers): model = Apriel2ForCausalLM(apriel2_config_all_mixers) stochastic_layer = model.model.decoder.blocks[1] # Layer 1 is the "all_mixers" layer - assert hasattr(stochastic_layer.mixer, 'mixers'), "Stochastic mixer should have 'mixers' attribute" + assert hasattr(stochastic_layer.mixer, "mixers"), "Stochastic mixer should have 
'mixers' attribute" assert set(stochastic_layer.mixer.mixers.keys()) == { - 'attention', 'swa', 'mamba', 'gdn' + "attention", + "swa", + "mamba", + "gdn", }, "Stochastic mixer should contain all 4 configured mixer types" # Verify each mixer is the correct type from fast_llm_external_models.apriel2.modeling_apriel2 import ( - Apriel2Attention, Apriel2Mamba, Apriel2GatedDeltaNet + Apriel2Attention, + Apriel2GatedDeltaNet, + Apriel2Mamba, ) - assert isinstance(stochastic_layer.mixer.mixers['attention'], Apriel2Attention) - assert isinstance(stochastic_layer.mixer.mixers['swa'], Apriel2Attention) # SWA is Apriel2Attention with sliding_window - assert isinstance(stochastic_layer.mixer.mixers['mamba'], Apriel2Mamba) - assert isinstance(stochastic_layer.mixer.mixers['gdn'], Apriel2GatedDeltaNet) + assert isinstance(stochastic_layer.mixer.mixers["attention"], Apriel2Attention) + assert isinstance( + stochastic_layer.mixer.mixers["swa"], Apriel2Attention + ) # SWA is Apriel2Attention with sliding_window + assert isinstance(stochastic_layer.mixer.mixers["mamba"], Apriel2Mamba) + assert isinstance(stochastic_layer.mixer.mixers["gdn"], Apriel2GatedDeltaNet) def test_main_mixer_is_configured(self, apriel2_config_all_mixers): """Verify main_mixer_name is set correctly.""" @@ -44,7 +51,10 @@ def test_cache_has_all_submixer_slots(self, apriel2_config_all_mixers): assert isinstance(layer_cache, dict), "Stochastic layer cache should be a dict" assert set(layer_cache.keys()) == { - 'attention', 'swa', 'mamba', 'gdn' + "attention", + "swa", + "mamba", + "gdn", }, "Cache should have slots for all 4 mixers" def test_attention_mixers_use_attention_cache(self, apriel2_config_all_mixers): @@ -53,12 +63,12 @@ def test_attention_mixers_use_attention_cache(self, apriel2_config_all_mixers): layer_cache = cache.layers[1] # Attention-based mixers use AttentionCache - assert isinstance(layer_cache['attention'], _AttentionCache) - assert isinstance(layer_cache['swa'], _AttentionCache) + assert isinstance(layer_cache["attention"], _AttentionCache) + assert isinstance(layer_cache["swa"], _AttentionCache) # SSM-based mixers use SSMCache - assert isinstance(layer_cache['mamba'], _SSMCache) - assert isinstance(layer_cache['gdn'], _SSMCache) + assert isinstance(layer_cache["mamba"], _SSMCache) + assert isinstance(layer_cache["gdn"], _SSMCache) def test_parameter_counts_differ_by_config(self): """Different configs create models with different parameter counts.""" @@ -74,8 +84,10 @@ def test_parameter_counts_differ_by_config(self): } config_tiny = Apriel2Config( - vocab_size=100, hidden_size=64, - num_attention_heads=4, num_key_value_heads=2, + vocab_size=100, + hidden_size=64, + num_attention_heads=4, + num_key_value_heads=2, decoder={ "type": "fixed", "num_blocks": 2, @@ -88,8 +100,10 @@ def test_parameter_counts_differ_by_config(self): ) config_stochastic = Apriel2Config( - vocab_size=100, hidden_size=64, - num_attention_heads=4, num_key_value_heads=2, + vocab_size=100, + hidden_size=64, + num_attention_heads=4, + num_key_value_heads=2, decoder={ "type": "pattern", "num_blocks": 2, @@ -106,14 +120,14 @@ def test_parameter_counts_differ_by_config(self): "main_mixer_name": "attention", "mixers": { "attention": attn_config, - "mamba": {"type": "mamba", "conv_bias": True, "dt_proj_bias": True} - } + "mamba": {"type": "mamba", "conv_bias": True, "dt_proj_bias": True}, + }, }, "mlp": {"type": "mlp", "intermediate_size": 256, "gated": True}, "normalization": {"type": "rms_norm"}, - } - } - } + }, + }, + }, ) model_tiny = 
Apriel2ForCausalLM(config_tiny) @@ -122,8 +136,9 @@ def test_parameter_counts_differ_by_config(self): params_tiny = sum(p.numel() for p in model_tiny.parameters()) params_stochastic = sum(p.numel() for p in model_stochastic.parameters()) - assert params_stochastic > params_tiny, \ - "Stochastic mixer should have more parameters (has both attention and mamba)" + assert ( + params_stochastic > params_tiny + ), "Stochastic mixer should have more parameters (has both attention and mamba)" def test_weights_are_initialized(self, apriel2_config_all_mixers): """Verify model weights are initialized (not all zeros/constant).""" @@ -136,9 +151,7 @@ def test_weights_are_initialized(self, apriel2_config_all_mixers): # Basic sanity: at least some parameters should be non-zero non_zero_params = sum( - not torch.all(p == 0) - for mixer in stochastic_layer.mixer.mixers.values() - for p in mixer.parameters() + not torch.all(p == 0) for mixer in stochastic_layer.mixer.mixers.values() for p in mixer.parameters() ) assert non_zero_params > 0, "At least some mixer parameters should be non-zero" diff --git a/fast_llm_external_models/tests/test_apriel2/test_modeling.py b/fast_llm_external_models/tests/test_apriel2/test_modeling.py index 47c877d09..8e2f610bb 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_modeling.py +++ b/fast_llm_external_models/tests/test_apriel2/test_modeling.py @@ -2,19 +2,23 @@ import pytest import torch + from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM class TestApriel2Modeling: """End-to-end tests for Apriel2 model with different configurations.""" - @pytest.mark.parametrize("config_name", [ - "apriel2_config_tiny", - "apriel2_config_stochastic", - "apriel2_config_multi_mixer", - "apriel2_config_all_mixers", # Tests all 4 mixer types - "apriel2_config_with_bias", # Tests per-layer bias and non-gated MLP - ]) + @pytest.mark.parametrize( + "config_name", + [ + "apriel2_config_tiny", + "apriel2_config_stochastic", + "apriel2_config_multi_mixer", + "apriel2_config_all_mixers", # Tests all 4 mixer types + "apriel2_config_with_bias", # Tests per-layer bias and non-gated MLP + ], + ) def test_model_end_to_end(self, config_name, request): """Test instantiation, forward pass, cache correctness, and generation. @@ -43,7 +47,7 @@ def test_model_end_to_end(self, config_name, request): # 2. Forward pass - basic shape validation outputs = model(input_ids, use_cache=False) assert outputs.logits.shape == (2, seq_len, config.vocab_size) - assert hasattr(outputs, 'logits') + assert hasattr(outputs, "logits") # 3. 
Verify cache is actually being used (not dormant) split_pos = 30 @@ -53,28 +57,23 @@ def test_model_end_to_end(self, config_name, request): assert outputs_part1.past_key_values is not None outputs_correct_cache = model( - input_ids[:, split_pos:split_pos+1], - past_key_values=outputs_part1.past_key_values, - use_cache=True + input_ids[:, split_pos : split_pos + 1], past_key_values=outputs_part1.past_key_values, use_cache=True ) # Test 1: Empty cache should give different results than filled cache # This verifies cache is being used at all from fast_llm_external_models.apriel2.cache import Apriel2Cache, _AttentionCache + empty_cache = Apriel2Cache(config) outputs_empty_cache = model( - input_ids[:, split_pos:split_pos+1], - past_key_values=empty_cache, - use_cache=True + input_ids[:, split_pos : split_pos + 1], past_key_values=empty_cache, use_cache=True ) - cache_affects_output = not torch.allclose( - outputs_correct_cache.logits, - outputs_empty_cache.logits, - atol=1e-3 - ) - assert cache_affects_output, f"Cache appears dormant for {config_name} - empty cache gives same results as filled cache" + cache_affects_output = not torch.allclose(outputs_correct_cache.logits, outputs_empty_cache.logits, atol=1e-3) + assert ( + cache_affects_output + ), f"Cache appears dormant for {config_name} - empty cache gives same results as filled cache" # Test 2: Corrupted cache (zeros) should give different results than correct cache # This verifies the actual cache VALUES are being used @@ -99,17 +98,15 @@ def test_model_end_to_end(self, config_name, request): corrupted_layer[name].value = torch.zeros_like(correct_sub.value) outputs_corrupted_cache = model( - input_ids[:, split_pos:split_pos+1], - past_key_values=corrupted_cache, - use_cache=True + input_ids[:, split_pos : split_pos + 1], past_key_values=corrupted_cache, use_cache=True ) cache_values_matter = not torch.allclose( - outputs_correct_cache.logits, - outputs_corrupted_cache.logits, - atol=1e-3 + outputs_correct_cache.logits, outputs_corrupted_cache.logits, atol=1e-3 ) - assert cache_values_matter, f"Cache values not used for {config_name} - zeroed cache gives same results as correct cache" + assert ( + cache_values_matter + ), f"Cache values not used for {config_name} - zeroed cache gives same results as correct cache" # 4. Cache correctness - validate cache produces same results as no-cache # Compute full sequence without cache @@ -118,18 +115,14 @@ def test_model_end_to_end(self, config_name, request): # Compute in two steps with cache outputs_part1 = model(input_ids[:, :split_pos], use_cache=True) outputs_part2 = model( - input_ids[:, split_pos:split_pos+1], - past_key_values=outputs_part1.past_key_values, - use_cache=True + input_ids[:, split_pos : split_pos + 1], past_key_values=outputs_part1.past_key_values, use_cache=True ) # Logits should match between cached and non-cached # Note: GPU execution with bfloat16/float16 has lower precision than CPU float32, # so we use a looser tolerance here. assert torch.allclose( - outputs_full.logits[:, split_pos, :], - outputs_part2.logits[:, 0, :], - atol=1e-3 + outputs_full.logits[:, split_pos, :], outputs_part2.logits[:, 0, :], atol=1e-3 ), f"Cache correctness failed for {config_name}: cached and non-cached logits differ" # 5. 
Generation - end-to-end validation diff --git a/fast_llm_external_models/tests/test_apriel2/test_plan_execution.py b/fast_llm_external_models/tests/test_apriel2/test_plan_execution.py index 9a98ec13b..ca0c8739f 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_plan_execution.py +++ b/fast_llm_external_models/tests/test_apriel2/test_plan_execution.py @@ -61,27 +61,27 @@ - Tests document the laws they verify in their docstrings """ +from functools import reduce + import pytest import torch -from functools import reduce from fast_llm_external_models.apriel2.conversion import ( + Concat, + ExprPlan, + Init, + Ref, + Slice, + W, compose, compose_configs, execute, plan_surgery, - ExprPlan, - W, - Ref, - Concat, - Slice, - Init, ) # Import shared helper from conftest from fast_llm_external_models.tests.test_apriel2.conftest import make_weights_for_config - # ============================================================================= # Fixtures: Use shared fixtures from conftest.py where possible # ============================================================================= @@ -125,7 +125,9 @@ def test_associativity(self, expr_type): p3 = ExprPlan(mappings={W("out"): Ref(key=W("part"))}) elif expr_type == "with_init": p1 = ExprPlan(mappings={W("x"): Ref(key=W("a"))}) - p2 = ExprPlan(mappings={W("y"): Concat(exprs=(Ref(key=W("x")), Init(shape=(5,), init_type="zeros")), dim=0)}) + p2 = ExprPlan( + mappings={W("y"): Concat(exprs=(Ref(key=W("x")), Init(shape=(5,), init_type="zeros")), dim=0)} + ) p3 = ExprPlan(mappings={W("z"): Ref(key=W("y"))}) left = compose(compose(p1, p2), p3) @@ -194,7 +196,7 @@ def test_functoriality( configs.append(compose_configs(configs[-1], s)) # Build incremental plans: Pₖ = plan_surgery(Cₖ₋₁, Cₖ) - plans = [plan_surgery(configs[i], configs[i+1]) for i in range(len(surgeries))] + plans = [plan_surgery(configs[i], configs[i + 1]) for i in range(len(surgeries))] # Compose all incremental plans composed_plan = reduce(compose, plans) @@ -208,12 +210,14 @@ def test_functoriality( direct_weights = execute(direct_plan, weights, seed=42) # Verify semantic equivalence - assert set(composed_weights.keys()) == set(direct_weights.keys()), \ - f"Key sets differ for chain_length={chain_length}, use_bias={use_bias}" + assert set(composed_weights.keys()) == set( + direct_weights.keys() + ), f"Key sets differ for chain_length={chain_length}, use_bias={use_bias}" for key in composed_weights: - assert torch.allclose(composed_weights[key], direct_weights[key], atol=1e-6), \ - f"Weight mismatch for {key} with chain_length={chain_length}, use_bias={use_bias}" + assert torch.allclose( + composed_weights[key], direct_weights[key], atol=1e-6 + ), f"Weight mismatch for {key} with chain_length={chain_length}, use_bias={use_bias}" @pytest.mark.parametrize("split_point", [1, 2]) def test_arbitrary_grouping( @@ -240,7 +244,7 @@ def test_arbitrary_grouping( configs.append(compose_configs(configs[-1], s)) # Build incremental plans - plans = [plan_surgery(configs[i], configs[i+1]) for i in range(3)] + plans = [plan_surgery(configs[i], configs[i + 1]) for i in range(3)] # Different groupings left_grouped = compose(compose(plans[0], plans[1]), plans[2]) @@ -296,21 +300,19 @@ def test_qkv_biases_preserved_through_chain( for s in surgeries: configs.append(compose_configs(configs[-1], s)) - plans = [plan_surgery(configs[i], configs[i+1]) for i in range(num_surgeries)] + plans = [plan_surgery(configs[i], configs[i + 1]) for i in range(num_surgeries)] final_plan = reduce(compose, plans) if 
len(plans) > 1 else plans[0] # Check bias keys present target_keys = {str(k) for k in final_plan.target_keys()} - assert any("q_proj.bias" in k for k in target_keys), \ - f"q_proj.bias missing after {num_surgeries} surgeries" - assert any("k_proj.bias" in k for k in target_keys), \ - f"k_proj.bias missing after {num_surgeries} surgeries" - assert any("v_proj.bias" in k for k in target_keys), \ - f"v_proj.bias missing after {num_surgeries} surgeries" + assert any("q_proj.bias" in k for k in target_keys), f"q_proj.bias missing after {num_surgeries} surgeries" + assert any("k_proj.bias" in k for k in target_keys), f"k_proj.bias missing after {num_surgeries} surgeries" + assert any("v_proj.bias" in k for k in target_keys), f"v_proj.bias missing after {num_surgeries} surgeries" # O bias should NOT be present (disabled in source) - assert not any("o_proj.bias" in k for k in target_keys), \ - f"o_proj.bias should not be present (disabled in source)" + assert not any( + "o_proj.bias" in k for k in target_keys + ), f"o_proj.bias should not be present (disabled in source)" def test_bias_values_preserved( self, @@ -331,8 +333,7 @@ def test_bias_values_preserved( dst_key = W(f"model.decoder.blocks.{i}.mixer.mixers.attention.q_proj.bias") assert dst_key in result, f"Missing {dst_key}" - assert torch.allclose(weights[src_key], result[dst_key]), \ - f"Bias values differ for block {i}" + assert torch.allclose(weights[src_key], result[dst_key]), f"Bias values differ for block {i}" # ============================================================================= @@ -379,8 +380,9 @@ def test_build_plan_preserves_inherited_fields( # Verify bias mappings in plan target_keys = {str(k) for k in plan.target_keys()} - assert any("q_proj.bias" in k for k in target_keys), \ - f"build_plan with {num_surgeries} surgeries missing q_proj.bias" + assert any( + "q_proj.bias" in k for k in target_keys + ), f"build_plan with {num_surgeries} surgeries missing q_proj.bias" # ============================================================================= @@ -435,7 +437,7 @@ def test_init_random_produces_init_expression(self, base_config_with_bias_dict): has_init_expr = True break # Also check inside Concat/other composite expressions - if hasattr(expr, 'exprs'): + if hasattr(expr, "exprs"): for sub in expr.exprs: if isinstance(sub, Init): has_init_expr = True @@ -533,9 +535,7 @@ def test_build_plan_strips_init_between_iterations(self, base_config_with_bias_d substitutes Ref → Init, but the semantics are correct: GDN is initialized once (in surgery 1), not re-randomized in surgery 2. 
""" - from fast_llm_external_models.apriel2.conversion import ( - compose_configs, strip_init_fields, plan_surgery, compose - ) + from fast_llm_external_models.apriel2.conversion import compose_configs, plan_surgery, strip_init_fields # Surgery 1: Add GDN with random init surgery1 = { @@ -581,8 +581,9 @@ def test_build_plan_strips_init_between_iterations(self, base_config_with_bias_d # Iteration 2: s1 has no init for GDN t2 = compose_configs(s1, surgery2) - assert t2["decoder"]["block"]["mixer"]["mixers"]["gdn"].get("init") is None, \ - "GDN should have no init in T2 (wasn't in surgery2, stripped from s1)" + assert ( + t2["decoder"]["block"]["mixer"]["mixers"]["gdn"].get("init") is None + ), "GDN should have no init in T2 (wasn't in surgery2, stripped from s1)" # plan_surgery(s1, t2) should use Ref for GDN (transfer, not random) plan2 = plan_surgery(s1, t2) diff --git a/setup.py b/setup.py index b273e077e..5c4d0def6 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,6 @@ -import sys -import re import pathlib +import re +import sys try: import pybind11 @@ -18,6 +18,7 @@ print(f"Error: setuptools version {_SETUPTOOLS_MIN_VERSION} " "or greater is required") sys.exit(1) + def get_version(): """Read version from fast_llm/__init__.py""" init_file = pathlib.Path(__file__).parent.joinpath("fast_llm", "__init__.py").read_text() @@ -26,6 +27,7 @@ def get_version(): return version_match.group(1) raise RuntimeError("Unable to find version string in fast_llm/__init__.py") + cpp_extension = setuptools.Extension( "fast_llm.csrc.data", sources=["fast_llm/csrc/data.cpp"], diff --git a/tests/data/test_tokenizer.py b/tests/data/test_tokenizer.py index f8f07ef0f..4e9e2fdd5 100644 --- a/tests/data/test_tokenizer.py +++ b/tests/data/test_tokenizer.py @@ -94,7 +94,35 @@ def test_validate_chat_template_with_markers(common_tokenizer): {"role": "user", "content": "C"}, {"role": "assistant", "content": "D"}, ], - [49152, 27, 789, 29, 32, 750, 789, 2293, 17822, 29, 33, 750, 17822, 2293, 789, 29, 34, 750, 789, 2293, 17822, 29, 35, 750, 17822, 29, 49152], + [ + 49152, + 27, + 789, + 29, + 32, + 750, + 789, + 2293, + 17822, + 29, + 33, + 750, + 17822, + 2293, + 789, + 29, + 34, + 750, + 789, + 2293, + 17822, + 29, + 35, + 750, + 17822, + 29, + 49152, + ], [(0, 7), (14, 19), (26, 27)], ), # System + user + assistant: full assistant turn trainable @@ -105,7 +133,31 @@ def test_validate_chat_template_with_markers(common_tokenizer): {"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello"}, ], - [49152, 27, 3144, 29, 5815, 1139, 44569, 6928, 3144, 2293, 789, 29, 16946, 750, 789, 2293, 17822, 29, 7371, 750, 17822, 29, 49152], + [ + 49152, + 27, + 3144, + 29, + 5815, + 1139, + 44569, + 6928, + 3144, + 2293, + 789, + 29, + 16946, + 750, + 789, + 2293, + 17822, + 29, + 7371, + 750, + 17822, + 29, + 49152, + ], [(0, 15), (22, 23)], ), # User only: no trainable tokens @@ -127,7 +179,93 @@ def test_validate_chat_template_with_markers(common_tokenizer): {"role": "user", "content": "And Italy?"}, {"role": "assistant", "content": "The capital of Italy is Rome."}, ], - [49152, 27, 3144, 29, 5815, 1139, 373, 44569, 2424, 11886, 954, 15737, 14516, 6928, 3144, 2293, 789, 29, 13938, 438, 331, 25016, 457, 12409, 562, 35838, 789, 2293, 17822, 29, 2111, 25016, 457, 12409, 562, 438, 4235, 280, 6928, 17822, 2293, 789, 29, 13938, 5028, 759, 42226, 35838, 789, 2293, 17822, 29, 2111, 25016, 457, 759, 42226, 438, 29784, 3556, 6928, 17822, 2293, 789, 29, 1996, 4413, 3326, 35838, 789, 2293, 17822, 29, 2111, 25016, 457, 4413, 3326, 
438, 613, 1361, 6928, 17822, 29, 49152], + [ + 49152, + 27, + 3144, + 29, + 5815, + 1139, + 373, + 44569, + 2424, + 11886, + 954, + 15737, + 14516, + 6928, + 3144, + 2293, + 789, + 29, + 13938, + 438, + 331, + 25016, + 457, + 12409, + 562, + 35838, + 789, + 2293, + 17822, + 29, + 2111, + 25016, + 457, + 12409, + 562, + 438, + 4235, + 280, + 6928, + 17822, + 2293, + 789, + 29, + 13938, + 5028, + 759, + 42226, + 35838, + 789, + 2293, + 17822, + 29, + 2111, + 25016, + 457, + 759, + 42226, + 438, + 29784, + 3556, + 6928, + 17822, + 2293, + 789, + 29, + 1996, + 4413, + 3326, + 35838, + 789, + 2293, + 17822, + 29, + 2111, + 25016, + 457, + 4413, + 3326, + 438, + 613, + 1361, + 6928, + 17822, + 29, + 49152, + ], [(0, 27), (41, 49), (63, 70), (84, 85)], ), ), From b1b0c31c0e9e0dc1ad2c58a1d3b6e9b7e5fa0081 Mon Sep 17 00:00:00 2001 From: Torsten Scholak Date: Sun, 21 Dec 2025 03:29:43 +0000 Subject: [PATCH 15/25] Add forward KL evaluator for teacher trace evaluation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a new evaluator type that computes forward KL divergence by comparing student log-probs against pre-computed teacher log-probs from a HuggingFace dataset of traces. The evaluator bypasses Fast-LLM's data pipeline and loads traces directly, making it suitable for monitoring distillation quality during training. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- fast_llm/engine/evaluation/config.py | 45 +++++ .../engine/evaluation/forward_kl/__init__.py | 0 .../engine/evaluation/forward_kl/evaluator.py | 161 ++++++++++++++++++ 3 files changed, 206 insertions(+) create mode 100644 fast_llm/engine/evaluation/forward_kl/__init__.py create mode 100644 fast_llm/engine/evaluation/forward_kl/evaluator.py diff --git a/fast_llm/engine/evaluation/config.py b/fast_llm/engine/evaluation/config.py index df7ab0f51..4ae39e03d 100644 --- a/fast_llm/engine/evaluation/config.py +++ b/fast_llm/engine/evaluation/config.py @@ -8,6 +8,7 @@ if typing.TYPE_CHECKING: from fast_llm.engine.evaluation.evaluator import Evaluator, EvaluatorLmEval, LossEvaluator + from fast_llm.engine.evaluation.forward_kl.evaluator import ForwardKLEvaluator @config_class() @@ -119,3 +120,47 @@ def get_evaluator( from fast_llm.engine.evaluation.lm_eval.evaluator import LmEvalEvaluator return LmEvalEvaluator(name, self, batch_config, data_load_num_proc, train_iters) + + +@config_class(dynamic_type={EvaluatorConfig: "forward_kl"}) +class ForwardKLEvaluatorConfig(EvaluatorConfig): + _abstract: typing.ClassVar[bool] = False + + dataset_path: str | None = Field( + default=None, + desc="HuggingFace dataset path containing teacher traces.", + hint=FieldHint.core, + ) + task: str | None = Field( + default=None, + desc="Dataset configuration/task name.", + hint=FieldHint.optional, + ) + num_samples: int | None = Field( + default=None, + desc="Maximum number of traces to evaluate. 
None for all.", + hint=FieldHint.optional, + valid=skip_valid_if_none(check_field(Assert.gt, 0)), + ) + batch_size: int = Field( + default=8, + desc="Batch size for forward passes.", + hint=FieldHint.performance, + valid=check_field(Assert.gt, 0), + ) + trust_remote_code: bool = Field( + default=False, + desc="Trust remote code when loading dataset.", + hint=FieldHint.optional, + ) + + def get_evaluator( + self, + name: str, + batch_config: BatchConfig, + data_load_num_proc: int, + train_iters: int | None = None, + ) -> "ForwardKLEvaluator": + from fast_llm.engine.evaluation.forward_kl.evaluator import ForwardKLEvaluator + + return ForwardKLEvaluator(name, self, batch_config, data_load_num_proc, train_iters) diff --git a/fast_llm/engine/evaluation/forward_kl/__init__.py b/fast_llm/engine/evaluation/forward_kl/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/engine/evaluation/forward_kl/evaluator.py b/fast_llm/engine/evaluation/forward_kl/evaluator.py new file mode 100644 index 000000000..298e7204c --- /dev/null +++ b/fast_llm/engine/evaluation/forward_kl/evaluator.py @@ -0,0 +1,161 @@ +import logging +import typing + +import datasets +import torch +import torch.nn.functional as F + +from fast_llm.core.distributed import safe_barrier +from fast_llm.data.data.abstract import Data +from fast_llm.engine.config_utils.run import Run, log_main_rank +from fast_llm.engine.distributed.config import PhaseType +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.engine.evaluation.evaluator import ( + EvaluationMetrics, + Evaluator, + EvaluatorSamplingParameters, + TrainingProgress, +) +from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel +from fast_llm.engine.schedule.runner import ScheduleRunner + +if typing.TYPE_CHECKING: + from fast_llm.engine.evaluation.config import ForwardKLEvaluatorConfig + from fast_llm.engine.inference.huggingface import HuggingfacePreTrainedModel + +logger = logging.getLogger(__name__) + + +class ForwardKLEvaluator[ConfigType: "ForwardKLEvaluatorConfig"](Evaluator[ConfigType]): + _hf_model: "HuggingfacePreTrainedModel" = None + + def setup( + self, + distributed: Distributed, + run: Run, + multi_stage: FastLLMModel, + runner: ScheduleRunner, + data: Data, + phase: PhaseType, + ) -> None: + super().setup(distributed, run, multi_stage, runner, data, phase) + + self._hf_model = self._multi_stage.config_class.get_huggingface_model_for_causal_lm_class()( + self._multi_stage, runner=self._runner + ) + self._is_setup = True + + def get_sampling_parameters(self) -> EvaluatorSamplingParameters | None: + return None + + def run( + self, + training_progress: TrainingProgress | None = None, + run_index: int | None = None, + ) -> EvaluationMetrics: + assert self._is_setup + + safe_barrier(self._distributed.world_group, f"forward_kl_{self._name} begin") + + traces = self._load_traces() + if len(traces) == 0: + return EvaluationMetrics() + + forward_kl, num_traces = self._compute_forward_kl(traces) + + safe_barrier(self._distributed.world_group, f"forward_kl_{self._name} end") + + metrics = { + f"validation.{self._name}": { + "forward_kl": forward_kl, + "num_traces": num_traces, + } + } + + if training_progress is not None: + metrics[f"validation.{self._name}"]["iteration"] = training_progress.completed_steps + + formatted = f"Forward KL ({self._name}): {forward_kl:.4f} ({num_traces} traces)" + log_main_rank(formatted) + + return EvaluationMetrics(metrics, formatted) + + def _load_traces(self) -> datasets.Dataset: + 
if self._config.dataset_path is None: + return [] + + return datasets.load_dataset( + self._config.dataset_path, + name=self._config.task, + split="validation", + trust_remote_code=self._config.trust_remote_code, + ) + + @torch.inference_mode() + def _compute_forward_kl(self, traces: datasets.Dataset) -> tuple[float, int]: + device = self._hf_model.device + total_kl = 0.0 + num_traces = 0 + + num_samples = min(len(traces), self._config.num_samples) if self._config.num_samples else len(traces) + + for i in range(0, num_samples, self._config.batch_size): + batch_end = min(i + self._config.batch_size, num_samples) + batch = traces.select(range(i, batch_end)) + + student_log_probs = self._compute_batch_log_probs(batch, device) + + for j, trace in enumerate(batch): + teacher_lp = trace["teacher_log_prob"] + student_lp = student_log_probs[j] + total_kl += teacher_lp - student_lp + num_traces += 1 + + torch.cuda.empty_cache() + + return total_kl / num_traces if num_traces > 0 else 0.0, num_traces + + def _compute_batch_log_probs(self, batch: datasets.Dataset, device: torch.device) -> list[float]: + max_len = max(len(t["prompt_tokens"]) + len(t["completion_tokens"]) for t in batch) + pad_token_id = getattr(self._hf_model.config, "pad_token_id", 0) or 0 + + input_ids_list = [] + attention_mask_list = [] + prompt_lengths = [] + completion_lengths = [] + + for trace in batch: + prompt = trace["prompt_tokens"] + completion = trace["completion_tokens"] + full = prompt + completion + padding = [pad_token_id] * (max_len - len(full)) + + input_ids_list.append(full + padding) + attention_mask_list.append([1] * len(full) + [0] * len(padding)) + prompt_lengths.append(len(prompt)) + completion_lengths.append(len(completion)) + + input_ids = torch.tensor(input_ids_list, device=device) + attention_mask = torch.tensor(attention_mask_list, device=device) + + output = self._hf_model( + input_ids=input_ids, + attention_mask=attention_mask, + use_cache=False, + return_dict=True, + ) + logits = output.logits + + results = [] + for idx in range(len(batch)): + prompt_len = prompt_lengths[idx] + completion_len = completion_lengths[idx] + + pred_logits = logits[idx, prompt_len - 1 : prompt_len + completion_len - 1] + targets = input_ids[idx, prompt_len : prompt_len + completion_len] + + log_probs = F.log_softmax(pred_logits.float(), dim=-1) + token_log_probs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1) + results.append(token_log_probs.sum().item()) + + return results From c774cec9fb0e465a10a8963c02e077cdab41c87a Mon Sep 17 00:00:00 2001 From: Torsten Scholak Date: Sun, 21 Dec 2025 04:05:05 +0000 Subject: [PATCH 16/25] Refactor ForwardKLEvaluator to use InferenceRunner MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace HuggingFace wrapper with native Fast-LLM inference path: - Use InferenceRunner for forward passes instead of HF model wrapper - Create LanguageModelBatch from trace data with proper padding - Handle variable-length sequences via TokenSample lengths - Use preprocess_batch for attention mask handling This approach works for all model types including linear attention. 
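
For orientation (an editorial aside, not part of the patch): the metric these two patches wire up reduces to a simple Monte Carlo estimate. Because each trace was sampled from the teacher, forward KL(teacher || student) is approximated by the mean of (teacher completion log-prob − student completion log-prob) over traces. A minimal standalone sketch of that estimator, with an illustrative function name and plain-list inputs rather than the evaluator's actual interfaces:

    def forward_kl_estimate(teacher_log_probs: list[float], student_log_probs: list[float]) -> float:
        # Each pair is the summed completion log-prob of one teacher-sampled trace,
        # so the mean difference is the Monte Carlo estimate of KL(teacher || student).
        pairs = list(zip(teacher_log_probs, student_log_probs, strict=True))
        if not pairs:
            return 0.0
        return sum(t - s for t, s in pairs) / len(pairs)

This mirrors the accumulation the evaluator performs in `_compute_forward_kl` (`total_kl += teacher_lp - student_lp`, divided by `num_traces` at the end).
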
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- .../engine/evaluation/forward_kl/evaluator.py | 89 +++++++++---------- 1 file changed, 42 insertions(+), 47 deletions(-) diff --git a/fast_llm/engine/evaluation/forward_kl/evaluator.py b/fast_llm/engine/evaluation/forward_kl/evaluator.py index 298e7204c..058c7a25c 100644 --- a/fast_llm/engine/evaluation/forward_kl/evaluator.py +++ b/fast_llm/engine/evaluation/forward_kl/evaluator.py @@ -1,33 +1,32 @@ import logging import typing -import datasets import torch import torch.nn.functional as F from fast_llm.core.distributed import safe_barrier from fast_llm.data.data.abstract import Data +from fast_llm.data.sample.language_model import LanguageModelBatch, LanguageModelSample +from fast_llm.data.sample.token import TokenSample from fast_llm.engine.config_utils.run import Run, log_main_rank from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.engine.evaluation.config import ForwardKLEvaluatorConfig from fast_llm.engine.evaluation.evaluator import ( EvaluationMetrics, Evaluator, EvaluatorSamplingParameters, TrainingProgress, ) +from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.engine.schedule.runner import ScheduleRunner -if typing.TYPE_CHECKING: - from fast_llm.engine.evaluation.config import ForwardKLEvaluatorConfig - from fast_llm.engine.inference.huggingface import HuggingfacePreTrainedModel - logger = logging.getLogger(__name__) -class ForwardKLEvaluator[ConfigType: "ForwardKLEvaluatorConfig"](Evaluator[ConfigType]): - _hf_model: "HuggingfacePreTrainedModel" = None +class ForwardKLEvaluator[ConfigType: ForwardKLEvaluatorConfig](Evaluator[ConfigType]): + _inference_runner: InferenceRunner def setup( self, @@ -39,10 +38,8 @@ def setup( phase: PhaseType, ) -> None: super().setup(distributed, run, multi_stage, runner, data, phase) - - self._hf_model = self._multi_stage.config_class.get_huggingface_model_for_causal_lm_class()( - self._multi_stage, runner=self._runner - ) + self._inference_runner = InferenceRunner(self._multi_stage, runner=self._runner) + self._inference_runner.setup() self._is_setup = True def get_sampling_parameters(self) -> EvaluatorSamplingParameters | None: @@ -55,16 +52,18 @@ def run( ) -> EvaluationMetrics: assert self._is_setup - safe_barrier(self._distributed.world_group, f"forward_kl_{self._name} begin") - - traces = self._load_traces() - if len(traces) == 0: + if self._config.dataset_path is None: return EvaluationMetrics() - forward_kl, num_traces = self._compute_forward_kl(traces) + safe_barrier(self._distributed.world_group, f"forward_kl_{self._name} begin") + + forward_kl, num_traces = self._compute_forward_kl() safe_barrier(self._distributed.world_group, f"forward_kl_{self._name} end") + if num_traces == 0: + return EvaluationMetrics() + metrics = { f"validation.{self._name}": { "forward_kl": forward_kl, @@ -80,47 +79,37 @@ def run( return EvaluationMetrics(metrics, formatted) - def _load_traces(self) -> datasets.Dataset: - if self._config.dataset_path is None: - return [] + @torch.inference_mode() + def _compute_forward_kl(self) -> tuple[float, int]: + import datasets - return datasets.load_dataset( + traces = datasets.load_dataset( self._config.dataset_path, name=self._config.task, split="validation", trust_remote_code=self._config.trust_remote_code, ) - @torch.inference_mode() - def 
_compute_forward_kl(self, traces: datasets.Dataset) -> tuple[float, int]: - device = self._hf_model.device total_kl = 0.0 num_traces = 0 - num_samples = min(len(traces), self._config.num_samples) if self._config.num_samples else len(traces) for i in range(0, num_samples, self._config.batch_size): - batch_end = min(i + self._config.batch_size, num_samples) - batch = traces.select(range(i, batch_end)) - - student_log_probs = self._compute_batch_log_probs(batch, device) + batch = [traces[j] for j in range(i, min(i + self._config.batch_size, num_samples))] + student_log_probs = self._compute_batch_log_probs(batch) for j, trace in enumerate(batch): - teacher_lp = trace["teacher_log_prob"] - student_lp = student_log_probs[j] - total_kl += teacher_lp - student_lp + total_kl += trace["teacher_log_prob"] - student_log_probs[j] num_traces += 1 torch.cuda.empty_cache() return total_kl / num_traces if num_traces > 0 else 0.0, num_traces - def _compute_batch_log_probs(self, batch: datasets.Dataset, device: torch.device) -> list[float]: + def _compute_batch_log_probs(self, batch: list[dict[str, typing.Any]]) -> list[float]: max_len = max(len(t["prompt_tokens"]) + len(t["completion_tokens"]) for t in batch) - pad_token_id = getattr(self._hf_model.config, "pad_token_id", 0) or 0 - input_ids_list = [] - attention_mask_list = [] + samples = [] prompt_lengths = [] completion_lengths = [] @@ -128,23 +117,29 @@ def _compute_batch_log_probs(self, batch: datasets.Dataset, device: torch.device prompt = trace["prompt_tokens"] completion = trace["completion_tokens"] full = prompt + completion - padding = [pad_token_id] * (max_len - len(full)) + actual_len = len(full) + pad_len = max_len - actual_len - input_ids_list.append(full + padding) - attention_mask_list.append([1] * len(full) + [0] * len(padding)) + tokens = torch.tensor(full + [0] * pad_len, dtype=torch.int64) + samples.append(LanguageModelSample(TokenSample(tokens, lengths=[actual_len]))) prompt_lengths.append(len(prompt)) completion_lengths.append(len(completion)) - input_ids = torch.tensor(input_ids_list, device=device) - attention_mask = torch.tensor(attention_mask_list, device=device) + lm_batch = LanguageModelBatch.from_samples(samples) - output = self._hf_model( - input_ids=input_ids, - attention_mask=attention_mask, - use_cache=False, - return_dict=True, + preprocessed = self._multi_stage.base_model.preprocess_batch( + lm_batch, + phase=PhaseType.inference, + iteration=0, ) - logits = output.logits + + for input_, kwargs in preprocessed: + self._inference_runner.forward(input_, kwargs) + logits = kwargs["logits"] + + sequence_first = kwargs.get("sequence_first", False) + if sequence_first: + logits = logits.transpose(0, 1) results = [] for idx in range(len(batch)): @@ -152,7 +147,7 @@ def _compute_batch_log_probs(self, batch: datasets.Dataset, device: torch.device completion_len = completion_lengths[idx] pred_logits = logits[idx, prompt_len - 1 : prompt_len + completion_len - 1] - targets = input_ids[idx, prompt_len : prompt_len + completion_len] + targets = lm_batch.tokens.tokens[idx, prompt_len : prompt_len + completion_len] log_probs = F.log_softmax(pred_logits.float(), dim=-1) token_log_probs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1) From 565d137ae702833a6bc84387e20b6621d23aa7c6 Mon Sep 17 00:00:00 2001 From: Torsten Scholak Date: Sun, 21 Dec 2025 04:15:39 +0000 Subject: [PATCH 17/25] Add sequence length handling and global_logits support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 
Add max_sequence_length config field (defaults to model's position embedding limit) - Skip traces exceeding max length with warning and count - Set global_logits=True for correct tensor-parallel behavior - Report number of skipped traces in output 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- fast_llm/engine/evaluation/config.py | 6 ++++ .../engine/evaluation/forward_kl/evaluator.py | 34 ++++++++++++++++--- 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/fast_llm/engine/evaluation/config.py b/fast_llm/engine/evaluation/config.py index 4ae39e03d..b42fc1bc2 100644 --- a/fast_llm/engine/evaluation/config.py +++ b/fast_llm/engine/evaluation/config.py @@ -148,6 +148,12 @@ class ForwardKLEvaluatorConfig(EvaluatorConfig): hint=FieldHint.performance, valid=check_field(Assert.gt, 0), ) + max_sequence_length: int | None = Field( + default=None, + desc="Maximum sequence length for traces. If None, uses model's position embedding limit.", + hint=FieldHint.optional, + valid=skip_valid_if_none(check_field(Assert.gt, 0)), + ) trust_remote_code: bool = Field( default=False, desc="Trust remote code when loading dataset.", diff --git a/fast_llm/engine/evaluation/forward_kl/evaluator.py b/fast_llm/engine/evaluation/forward_kl/evaluator.py index 058c7a25c..c2719dfcc 100644 --- a/fast_llm/engine/evaluation/forward_kl/evaluator.py +++ b/fast_llm/engine/evaluation/forward_kl/evaluator.py @@ -27,6 +27,7 @@ class ForwardKLEvaluator[ConfigType: ForwardKLEvaluatorConfig](Evaluator[ConfigType]): _inference_runner: InferenceRunner + _max_sequence_length: int def setup( self, @@ -40,6 +41,12 @@ def setup( super().setup(distributed, run, multi_stage, runner, data, phase) self._inference_runner = InferenceRunner(self._multi_stage, runner=self._runner) self._inference_runner.setup() + + if self._config.max_sequence_length is not None: + self._max_sequence_length = self._config.max_sequence_length + else: + self._max_sequence_length = self._multi_stage.base_model._config.embeddings.num_position_embeddings + self._is_setup = True def get_sampling_parameters(self) -> EvaluatorSamplingParameters | None: @@ -57,7 +64,7 @@ def run( safe_barrier(self._distributed.world_group, f"forward_kl_{self._name} begin") - forward_kl, num_traces = self._compute_forward_kl() + forward_kl, num_traces, num_skipped = self._compute_forward_kl() safe_barrier(self._distributed.world_group, f"forward_kl_{self._name} end") @@ -75,12 +82,14 @@ def run( metrics[f"validation.{self._name}"]["iteration"] = training_progress.completed_steps formatted = f"Forward KL ({self._name}): {forward_kl:.4f} ({num_traces} traces)" + if num_skipped > 0: + formatted += f" [{num_skipped} skipped]" log_main_rank(formatted) return EvaluationMetrics(metrics, formatted) @torch.inference_mode() - def _compute_forward_kl(self) -> tuple[float, int]: + def _compute_forward_kl(self) -> tuple[float, int, int]: import datasets traces = datasets.load_dataset( @@ -92,10 +101,26 @@ def _compute_forward_kl(self) -> tuple[float, int]: total_kl = 0.0 num_traces = 0 + num_skipped = 0 num_samples = min(len(traces), self._config.num_samples) if self._config.num_samples else len(traces) for i in range(0, num_samples, self._config.batch_size): - batch = [traces[j] for j in range(i, min(i + self._config.batch_size, num_samples))] + batch_indices = range(i, min(i + self._config.batch_size, num_samples)) + batch = [] + for j in batch_indices: + trace = traces[j] + trace_len = len(trace["prompt_tokens"]) + 
len(trace["completion_tokens"]) + if trace_len > self._max_sequence_length: + logger.warning( + f"Skipping trace {j}: length {trace_len} exceeds max {self._max_sequence_length}" + ) + num_skipped += 1 + continue + batch.append(trace) + + if not batch: + continue + student_log_probs = self._compute_batch_log_probs(batch) for j, trace in enumerate(batch): @@ -104,7 +129,7 @@ def _compute_forward_kl(self) -> tuple[float, int]: torch.cuda.empty_cache() - return total_kl / num_traces if num_traces > 0 else 0.0, num_traces + return total_kl / num_traces if num_traces > 0 else 0.0, num_traces, num_skipped def _compute_batch_log_probs(self, batch: list[dict[str, typing.Any]]) -> list[float]: max_len = max(len(t["prompt_tokens"]) + len(t["completion_tokens"]) for t in batch) @@ -134,6 +159,7 @@ def _compute_batch_log_probs(self, batch: list[dict[str, typing.Any]]) -> list[f ) for input_, kwargs in preprocessed: + kwargs["global_logits"] = True self._inference_runner.forward(input_, kwargs) logits = kwargs["logits"] From 90e32005a1516dde99de1259408a1c37042a2a9e Mon Sep 17 00:00:00 2001 From: Torsten Scholak Date: Sun, 21 Dec 2025 04:16:50 +0000 Subject: [PATCH 18/25] Make max_sequence_length mandatory with default 2048 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- fast_llm/engine/evaluation/config.py | 10 +++++----- fast_llm/engine/evaluation/forward_kl/evaluator.py | 7 +------ 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/fast_llm/engine/evaluation/config.py b/fast_llm/engine/evaluation/config.py index b42fc1bc2..d98b12763 100644 --- a/fast_llm/engine/evaluation/config.py +++ b/fast_llm/engine/evaluation/config.py @@ -148,11 +148,11 @@ class ForwardKLEvaluatorConfig(EvaluatorConfig): hint=FieldHint.performance, valid=check_field(Assert.gt, 0), ) - max_sequence_length: int | None = Field( - default=None, - desc="Maximum sequence length for traces. 
If None, uses model's position embedding limit.", - hint=FieldHint.optional, - valid=skip_valid_if_none(check_field(Assert.gt, 0)), + max_sequence_length: int = Field( + default=2048, + desc="Maximum sequence length for traces.", + hint=FieldHint.core, + valid=check_field(Assert.gt, 0), ) trust_remote_code: bool = Field( default=False, diff --git a/fast_llm/engine/evaluation/forward_kl/evaluator.py b/fast_llm/engine/evaluation/forward_kl/evaluator.py index c2719dfcc..66f8bdacd 100644 --- a/fast_llm/engine/evaluation/forward_kl/evaluator.py +++ b/fast_llm/engine/evaluation/forward_kl/evaluator.py @@ -41,12 +41,7 @@ def setup( super().setup(distributed, run, multi_stage, runner, data, phase) self._inference_runner = InferenceRunner(self._multi_stage, runner=self._runner) self._inference_runner.setup() - - if self._config.max_sequence_length is not None: - self._max_sequence_length = self._config.max_sequence_length - else: - self._max_sequence_length = self._multi_stage.base_model._config.embeddings.num_position_embeddings - + self._max_sequence_length = self._config.max_sequence_length self._is_setup = True def get_sampling_parameters(self) -> EvaluatorSamplingParameters | None: From 66ceee16a62a0e50c7e9372eefbc1eb930692d48 Mon Sep 17 00:00:00 2001 From: Torsten Scholak Date: Sun, 21 Dec 2025 05:17:14 +0000 Subject: [PATCH 19/25] Add distributed training support to ForwardKLEvaluator MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add full support for TP, SP, PP, and DP parallelism modes - Use training's sequence_length instead of separate max_sequence_length - Use GPTBatchConfig for proper SP sequence splitting - Add HuggingFace dataset sharding for efficient DP distribution - Add all_reduce across data_group and pipeline_group - Fix device mismatch bug (move targets to GPU) - Use AttentionKwargs.sequence_first constant 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- fast_llm/engine/evaluation/config.py | 6 - .../engine/evaluation/forward_kl/evaluator.py | 124 +++++++++++++----- 2 files changed, 93 insertions(+), 37 deletions(-) diff --git a/fast_llm/engine/evaluation/config.py b/fast_llm/engine/evaluation/config.py index d98b12763..4ae39e03d 100644 --- a/fast_llm/engine/evaluation/config.py +++ b/fast_llm/engine/evaluation/config.py @@ -148,12 +148,6 @@ class ForwardKLEvaluatorConfig(EvaluatorConfig): hint=FieldHint.performance, valid=check_field(Assert.gt, 0), ) - max_sequence_length: int = Field( - default=2048, - desc="Maximum sequence length for traces.", - hint=FieldHint.core, - valid=check_field(Assert.gt, 0), - ) trust_remote_code: bool = Field( default=False, desc="Trust remote code when loading dataset.", diff --git a/fast_llm/engine/evaluation/forward_kl/evaluator.py b/fast_llm/engine/evaluation/forward_kl/evaluator.py index 66f8bdacd..09c7ff553 100644 --- a/fast_llm/engine/evaluation/forward_kl/evaluator.py +++ b/fast_llm/engine/evaluation/forward_kl/evaluator.py @@ -4,7 +4,8 @@ import torch import torch.nn.functional as F -from fast_llm.core.distributed import safe_barrier +from fast_llm.config import NoAutoValidate +from fast_llm.core.distributed import all_reduce, safe_barrier from fast_llm.data.data.abstract import Data from fast_llm.data.sample.language_model import LanguageModelBatch, LanguageModelSample from fast_llm.data.sample.token import TokenSample @@ -21,13 +22,16 @@ from fast_llm.engine.inference.runner import InferenceRunner from 
fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.engine.schedule.runner import ScheduleRunner +from fast_llm.layers.attention.config import AttentionKwargs +from fast_llm.models.gpt.config import GPTBatchConfig logger = logging.getLogger(__name__) class ForwardKLEvaluator[ConfigType: ForwardKLEvaluatorConfig](Evaluator[ConfigType]): _inference_runner: InferenceRunner - _max_sequence_length: int + _sequence_length: int + _micro_sequence_length: int def setup( self, @@ -41,7 +45,11 @@ def setup( super().setup(distributed, run, multi_stage, runner, data, phase) self._inference_runner = InferenceRunner(self._multi_stage, runner=self._runner) self._inference_runner.setup() - self._max_sequence_length = self._config.max_sequence_length + + # Get sequence configuration from training batch config (required for SP support) + self._sequence_length = self._batch_config.sequence_length + self._micro_sequence_length = self._batch_config.micro_sequence_length + self._is_setup = True def get_sampling_parameters(self) -> EvaluatorSamplingParameters | None: @@ -87,6 +95,10 @@ def run( def _compute_forward_kl(self) -> tuple[float, int, int]: import datasets + # Shard traces across data-parallel ranks + data_rank = self._distributed.config.data_rank + data_parallel = self._distributed.config.data_parallel + traces = datasets.load_dataset( self._config.dataset_path, name=self._config.task, @@ -94,41 +106,70 @@ def _compute_forward_kl(self) -> tuple[float, int, int]: trust_remote_code=self._config.trust_remote_code, ) + # Apply num_samples limit before sharding to preserve semantics + # (num_samples = total traces across all ranks, not per-rank) + if self._config.num_samples and len(traces) > self._config.num_samples: + traces = traces.select(range(self._config.num_samples)) + + # Shard across DP ranks (lazy operation - just changes which indices are accessible) + traces = traces.shard(num_shards=data_parallel, index=data_rank) + total_kl = 0.0 num_traces = 0 num_skipped = 0 - num_samples = min(len(traces), self._config.num_samples) if self._config.num_samples else len(traces) - - for i in range(0, num_samples, self._config.batch_size): - batch_indices = range(i, min(i + self._config.batch_size, num_samples)) - batch = [] - for j in batch_indices: - trace = traces[j] - trace_len = len(trace["prompt_tokens"]) + len(trace["completion_tokens"]) - if trace_len > self._max_sequence_length: - logger.warning( - f"Skipping trace {j}: length {trace_len} exceeds max {self._max_sequence_length}" - ) - num_skipped += 1 - continue - batch.append(trace) - - if not batch: + + # Collect traces for this rank, filtering by length + rank_traces = [] + for trace in traces: + trace_len = len(trace["prompt_tokens"]) + len(trace["completion_tokens"]) + if trace_len > self._sequence_length: + num_skipped += 1 continue + rank_traces.append(trace) + + if num_skipped > 0: + logger.warning( + f"Skipped {num_skipped} traces exceeding sequence length {self._sequence_length}" + ) + + # Process traces in batches + for i in range(0, len(rank_traces), self._config.batch_size): + batch = rank_traces[i : i + self._config.batch_size] student_log_probs = self._compute_batch_log_probs(batch) - for j, trace in enumerate(batch): - total_kl += trace["teacher_log_prob"] - student_log_probs[j] - num_traces += 1 + # student_log_probs is None on non-last pipeline ranks (they don't have logits) + if student_log_probs is not None: + for j, trace in enumerate(batch): + total_kl += trace["teacher_log_prob"] - student_log_probs[j] + 
num_traces += 1 torch.cuda.empty_cache() - return total_kl / num_traces if num_traces > 0 else 0.0, num_traces, num_skipped + # Reduce across data group (sum KL and counts from all DP ranks) + if self._distributed.data_group: + total_kl_tensor = torch.tensor([total_kl], dtype=torch.float64, device=self._distributed.device) + num_traces_tensor = torch.tensor([num_traces], dtype=torch.int64, device=self._distributed.device) + num_skipped_tensor = torch.tensor([num_skipped], dtype=torch.int64, device=self._distributed.device) + all_reduce(total_kl_tensor, group=self._distributed.data_group) + all_reduce(num_traces_tensor, group=self._distributed.data_group) + all_reduce(num_skipped_tensor, group=self._distributed.data_group) + total_kl = total_kl_tensor.item() + num_traces = int(num_traces_tensor.item()) + num_skipped = int(num_skipped_tensor.item()) + + # Reduce across pipeline group (last PP rank has the values, others have zeros) + if self._distributed.pipeline_group: + total_kl_tensor = torch.tensor([total_kl], dtype=torch.float64, device=self._distributed.device) + num_traces_tensor = torch.tensor([num_traces], dtype=torch.int64, device=self._distributed.device) + all_reduce(total_kl_tensor, group=self._distributed.pipeline_group) + all_reduce(num_traces_tensor, group=self._distributed.pipeline_group) + total_kl = total_kl_tensor.item() + num_traces = int(num_traces_tensor.item()) - def _compute_batch_log_probs(self, batch: list[dict[str, typing.Any]]) -> list[float]: - max_len = max(len(t["prompt_tokens"]) + len(t["completion_tokens"]) for t in batch) + return total_kl / num_traces if num_traces > 0 else 0.0, num_traces, num_skipped + def _compute_batch_log_probs(self, batch: list[dict[str, typing.Any]]) -> list[float] | None: samples = [] prompt_lengths = [] completion_lengths = [] @@ -138,7 +179,8 @@ def _compute_batch_log_probs(self, batch: list[dict[str, typing.Any]]) -> list[f completion = trace["completion_tokens"] full = prompt + completion actual_len = len(full) - pad_len = max_len - actual_len + # Pad to training sequence length (required for SP support) + pad_len = self._sequence_length - actual_len tokens = torch.tensor(full + [0] * pad_len, dtype=torch.int64) samples.append(LanguageModelSample(TokenSample(tokens, lengths=[actual_len]))) @@ -147,8 +189,22 @@ def _compute_batch_log_probs(self, batch: list[dict[str, typing.Any]]) -> list[f lm_batch = LanguageModelBatch.from_samples(samples) + # Create batch config with training's sequence settings (required for SP support) + with NoAutoValidate(): + batch_config = GPTBatchConfig( + micro_batch_size=len(batch), + sequence_length=self._sequence_length, + micro_sequence_length=self._micro_sequence_length, + ) + batch_config.setup(self._distributed.config) + batch_config.validate() + + # Get preprocessing metadata using GPTBatchConfig (enables proper SP splitting) + preprocessed_meta = self._multi_stage.base_model.preprocess_meta(batch_config, PhaseType.inference) + preprocessed = self._multi_stage.base_model.preprocess_batch( lm_batch, + preprocessed_meta, phase=PhaseType.inference, iteration=0, ) @@ -156,19 +212,25 @@ def _compute_batch_log_probs(self, batch: list[dict[str, typing.Any]]) -> list[f for input_, kwargs in preprocessed: kwargs["global_logits"] = True self._inference_runner.forward(input_, kwargs) - logits = kwargs["logits"] - sequence_first = kwargs.get("sequence_first", False) - if sequence_first: + # With pipeline parallelism, only the last stage has logits. 
+ # Other stages participate in the forward pass but don't compute logits. + if "logits" not in kwargs: + return None + + logits = kwargs["logits"] + + if kwargs.get(AttentionKwargs.sequence_first, False): logits = logits.transpose(0, 1) results = [] + device = logits.device for idx in range(len(batch)): prompt_len = prompt_lengths[idx] completion_len = completion_lengths[idx] pred_logits = logits[idx, prompt_len - 1 : prompt_len + completion_len - 1] - targets = lm_batch.tokens.tokens[idx, prompt_len : prompt_len + completion_len] + targets = lm_batch.tokens.tokens[idx, prompt_len : prompt_len + completion_len].to(device) log_probs = F.log_softmax(pred_logits.float(), dim=-1) token_log_probs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1) From fd7670bc758578d5f7985f4b316d30cd41b00968 Mon Sep 17 00:00:00 2001 From: Torsten Scholak Date: Mon, 22 Dec 2025 22:36:28 +0000 Subject: [PATCH 20/25] Fix global_logits storage during distillation and clean up evaluator MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Store raw logits unconditionally when global_logits=True in _logits_cross_entropy_forward_backward, fixing ForwardKL evaluation during distillation training where targets is never None. Also cleaned up ForwardKL evaluator: - Use GPTInferenceRunner instead of generic InferenceRunner - Add shuffle with configurable seed for reproducibility - Add split/seed config fields (replaced task field) - Proper padding via get_padding() and from_documents() - Remove memory tracking tooling, keep gc.collect cleanup 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- fast_llm/engine/evaluation/config.py | 13 +++-- .../engine/evaluation/forward_kl/evaluator.py | 48 +++++++++++++++---- fast_llm/layers/language_model/head.py | 25 ++++++---- .../examples/train_supernet_qwen2.yaml | 38 +++++++-------- 4 files changed, 83 insertions(+), 41 deletions(-) diff --git a/fast_llm/engine/evaluation/config.py b/fast_llm/engine/evaluation/config.py index 4ae39e03d..744506b65 100644 --- a/fast_llm/engine/evaluation/config.py +++ b/fast_llm/engine/evaluation/config.py @@ -131,14 +131,19 @@ class ForwardKLEvaluatorConfig(EvaluatorConfig): desc="HuggingFace dataset path containing teacher traces.", hint=FieldHint.core, ) - task: str | None = Field( - default=None, - desc="Dataset configuration/task name.", + split: str = Field( + default="validation", + desc="Dataset split to evaluate on. Use 'train+validation' syntax to combine multiple splits.", + hint=FieldHint.optional, + ) + seed: int = Field( + default=42, + desc="Random seed for shuffling traces. Ensures reproducible evaluation across runs.", hint=FieldHint.optional, ) num_samples: int | None = Field( default=None, - desc="Maximum number of traces to evaluate. None for all.", + desc="Maximum number of traces to evaluate (after shuffling). 
None for all.", hint=FieldHint.optional, valid=skip_valid_if_none(check_field(Assert.gt, 0)), ) diff --git a/fast_llm/engine/evaluation/forward_kl/evaluator.py b/fast_llm/engine/evaluation/forward_kl/evaluator.py index 09c7ff553..5548a8b2a 100644 --- a/fast_llm/engine/evaluation/forward_kl/evaluator.py +++ b/fast_llm/engine/evaluation/forward_kl/evaluator.py @@ -1,3 +1,4 @@ +import gc import logging import typing @@ -19,17 +20,17 @@ EvaluatorSamplingParameters, TrainingProgress, ) -from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.engine.schedule.runner import ScheduleRunner from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.models.gpt.config import GPTBatchConfig +from fast_llm.models.gpt.model import GPTInferenceRunner logger = logging.getLogger(__name__) class ForwardKLEvaluator[ConfigType: ForwardKLEvaluatorConfig](Evaluator[ConfigType]): - _inference_runner: InferenceRunner + _inference_runner: GPTInferenceRunner _sequence_length: int _micro_sequence_length: int @@ -43,7 +44,11 @@ def setup( phase: PhaseType, ) -> None: super().setup(distributed, run, multi_stage, runner, data, phase) - self._inference_runner = InferenceRunner(self._multi_stage, runner=self._runner) + + # TODO: instead of using GPTInferenceRunner, we should get ourselves + # the FastLLMModelConfig instance and build the correct InferenceRunner + # with config.get_inference_runner_class() + self._inference_runner = GPTInferenceRunner(self._multi_stage, runner=self._runner) self._inference_runner.setup() # Get sequence configuration from training batch config (required for SP support) @@ -101,11 +106,14 @@ def _compute_forward_kl(self) -> tuple[float, int, int]: traces = datasets.load_dataset( self._config.dataset_path, - name=self._config.task, - split="validation", + split=self._config.split, trust_remote_code=self._config.trust_remote_code, ) + # Shuffle traces for better problem coverage when using num_samples. + # Uses a fixed seed for reproducibility across distributed ranks. 
+ traces = traces.shuffle(seed=self._config.seed) + # Apply num_samples limit before sharding to preserve semantics # (num_samples = total traces across all ranks, not per-rank) if self._config.num_samples and len(traces) > self._config.num_samples: @@ -127,6 +135,10 @@ def _compute_forward_kl(self) -> tuple[float, int, int]: continue rank_traces.append(trace) + # Free the HuggingFace dataset - we've extracted what we need + del traces + gc.collect() + if num_skipped > 0: logger.warning( f"Skipped {num_skipped} traces exceeding sequence length {self._sequence_length}" @@ -144,6 +156,8 @@ def _compute_forward_kl(self) -> tuple[float, int, int]: total_kl += trace["teacher_log_prob"] - student_log_probs[j] num_traces += 1 + # Memory cleanup + gc.collect() torch.cuda.empty_cache() # Reduce across data group (sum KL and counts from all DP ranks) @@ -179,22 +193,33 @@ def _compute_batch_log_probs(self, batch: list[dict[str, typing.Any]]) -> list[f completion = trace["completion_tokens"] full = prompt + completion actual_len = len(full) - # Pad to training sequence length (required for SP support) pad_len = self._sequence_length - actual_len - tokens = torch.tensor(full + [0] * pad_len, dtype=torch.int64) - samples.append(LanguageModelSample(TokenSample(tokens, lengths=[actual_len]))) + trace_tokens = torch.tensor(full, dtype=torch.int64) + trace_sample = LanguageModelSample(TokenSample(trace_tokens)) + + if pad_len > 0: + padding_sample = trace_sample.get_padding(pad_len) + sample = LanguageModelSample.from_documents([trace_sample, padding_sample]) + elif pad_len == 0: + sample = trace_sample + else: + raise ValueError("Trace length exceeds sequence length") + + samples.append(sample) prompt_lengths.append(len(prompt)) completion_lengths.append(len(completion)) lm_batch = LanguageModelBatch.from_samples(samples) # Create batch config with training's sequence settings (required for SP support) + # truncate_documents=False enables mask_inputs, which handles -100 padding tokens with NoAutoValidate(): batch_config = GPTBatchConfig( micro_batch_size=len(batch), sequence_length=self._sequence_length, micro_sequence_length=self._micro_sequence_length, + truncate_documents=False, ) batch_config.setup(self._distributed.config) batch_config.validate() @@ -229,6 +254,7 @@ def _compute_batch_log_probs(self, batch: list[dict[str, typing.Any]]) -> list[f prompt_len = prompt_lengths[idx] completion_len = completion_lengths[idx] + # Extract only the slice we need, then compute on it pred_logits = logits[idx, prompt_len - 1 : prompt_len + completion_len - 1] targets = lm_batch.tokens.tokens[idx, prompt_len : prompt_len + completion_len].to(device) @@ -236,4 +262,10 @@ def _compute_batch_log_probs(self, batch: list[dict[str, typing.Any]]) -> list[f token_log_probs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1) results.append(token_log_probs.sum().item()) + # Explicitly delete intermediates + del pred_logits, targets, log_probs, token_log_probs + + # Explicitly delete the large logits tensor + del logits, kwargs, preprocessed, lm_batch + return results diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index b1d0c2acd..94ddbded9 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -250,16 +250,10 @@ def _logits_cross_entropy_forward_backward_split( input_, targets, weight, grad_output, kwargs, losses ) if targets is None: - # TODO: Make a proper way of returning the model output. 
- loss = loss.detach() - if kwargs.get("global_logits"): - if self._vocab_parallel: - loss = gather_op(loss, self._parallel_dim.group, 2) - elif self._sequence_parallel_logits: - loss = gather_op( - loss, self._parallel_dim.group, 0 if kwargs[LanguageModelKwargs.sequence_first] else 1 - ) - kwargs["logits" if self._prediction_distance == 0 else f"logits_{self._prediction_distance}"] = loss + # global_logits: raw logits already stored and gathered in inner function + # non-global_logits: store scaled logits for distillation backwards compat + if not kwargs.get("global_logits"): + kwargs["logits" if self._prediction_distance == 0 else f"logits_{self._prediction_distance}"] = loss.detach() return None, None else: loss = None @@ -342,6 +336,17 @@ def _logits_cross_entropy_forward_backward( dims = None self._debug(logits, "logits", dims, kwargs, scale=self._config.logits_scale_factor) + if kwargs.get("global_logits"): + logits_for_storage = logits.detach() + if self._vocab_parallel: + logits_for_storage = gather_op(logits_for_storage, self._parallel_dim.group, 2) + elif self._sequence_parallel_logits: + logits_for_storage = gather_op( + logits_for_storage, self._parallel_dim.group, 0 if kwargs[LanguageModelKwargs.sequence_first] else 1 + ) + logits_key = "logits" if self._prediction_distance == 0 else f"logits_{self._prediction_distance}" + kwargs[logits_key] = logits_for_storage + if targets is None: return logits * self._config.logits_scale_factor, None dpo_target, lm_target, distillation_target, loss_mask = targets diff --git a/fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml b/fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml index 5b190955f..aad168713 100644 --- a/fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml +++ b/fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml @@ -83,14 +83,10 @@ # PERFORMANCE TUNING # ============================================================================= # -# Default config uses seq=4096, micro_batch=2, batch=16 which gives: -# - ~8k tokens/s/gpu throughput -# - ~61GB GPU memory usage -# - ~25 hours for 1B tokens on single GPU -# -# Adjust batch settings based on your GPU memory: -# - Reduce micro_batch_size if OOM -# - Increase micro_batch_size/batch_size if memory available +# Default config uses seq=2048, micro_batch=2, batch=64 (~131k tokens/iter). 
+# Adjust settings based on your GPU memory: +# - Reduce micro_batch_size or sequence_length if OOM +# - Increase micro_batch_size or sequence_length if memory available # # ============================================================================= # OUTPUT @@ -118,14 +114,16 @@ model: lr_scale: 0.0 # Freeze MLP normalization: lr_scale: 0.0 # Freeze layer norms - # Activation-level distillation from teacher distillation_model: teacher - activation_distillation_factor: 0.8 + activation_distillation_factor: 0.5 embeddings: lr_scale: 0.0 # Freeze word embeddings head: lr_scale: 0.0 # Freeze output head - cross_entropy_implementation: torch + # cross_entropy_implementation: torch + distillation_model: teacher + distillation_loss_factor: 1.0 + distillation_loss_implementation: reverse_kl multi_stage: zero_stage: 2 distributed: @@ -143,11 +141,13 @@ reference_models: model_weights: true load_config: model -# Batch configuration (tuned for ~61GB GPU memory, ~8k tokens/s) +# Batch configuration batch: - sequence_length: 4096 + sequence_length: 2048 micro_batch_size: 2 - batch_size: 16 + batch_size: 64 + truncate_documents: false + use_loss_masking_spans: true # Data configuration (prepared Tulu 3 dataset) data: @@ -159,7 +159,7 @@ data: # Optimizer configuration optimizer: learning_rate: - base: 1.0e-05 + base: 3.0e-05 decay_style: cosine warmup_iterations: 100 decay_iterations: 10000 @@ -169,17 +169,16 @@ optimizer: beta_2: 0.95 # Training configuration -# At seq=4096, batch=16: ~65k tokens/iter, ~280 iters/hour -# 10000 iters ≈ 650M tokens ≈ 35 hours +# At seq=2048, batch=64: ~131k tokens/iter training: train_iters: 10000 num_workers: 4 logs: interval: 10 checkpoint: - interval: 280 # ~hourly + interval: 100 export: - interval: 280 # ~hourly (useful for development/testing during training) + interval: 100 format: apriel2_text test_iters: 0 evaluators: {} @@ -187,6 +186,7 @@ training: # wandb: # entity_name: your-entity # project_name: your-project + # group_name: your-group # Experiment directory run: From 10e24ca21ecb774c4587371a724d72101ffd04fd Mon Sep 17 00:00:00 2001 From: Torsten Scholak Date: Wed, 24 Dec 2025 01:32:24 +0000 Subject: [PATCH 21/25] Refactor ForwardKLEvaluator to compute IS accuracy and ESS metrics MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace forward KL with importance-weighted accuracy and effective sample size - Shard by problem_id hash (not trace index) so each rank gets complete problems - Add TraceTensors dataclass with smart constructors (empty, from_traces) - Vectorize log prob computation using F.cross_entropy with completion mask - Add _scatter_logsumexp for numerically stable grouped reductions - Use allreduce_scalar for cleaner distributed reduction - Pre-tensorize all trace data for efficient batch slicing 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- .../engine/evaluation/forward_kl/evaluator.py | 409 +++++++++++------- 1 file changed, 255 insertions(+), 154 deletions(-) diff --git a/fast_llm/engine/evaluation/forward_kl/evaluator.py b/fast_llm/engine/evaluation/forward_kl/evaluator.py index 5548a8b2a..8b5f45f3a 100644 --- a/fast_llm/engine/evaluation/forward_kl/evaluator.py +++ b/fast_llm/engine/evaluation/forward_kl/evaluator.py @@ -1,12 +1,13 @@ +import dataclasses import gc +import hashlib import logging -import typing import torch import torch.nn.functional as F from fast_llm.config import NoAutoValidate -from fast_llm.core.distributed import 
all_reduce, safe_barrier +from fast_llm.core.distributed import allreduce_scalar, safe_barrier from fast_llm.data.data.abstract import Data from fast_llm.data.sample.language_model import LanguageModelBatch, LanguageModelSample from fast_llm.data.sample.token import TokenSample @@ -29,7 +30,92 @@ logger = logging.getLogger(__name__) +@dataclasses.dataclass +class TraceTensors: + tokens: torch.Tensor # (num_traces, sequence_length) + prompt_lens: torch.Tensor # (num_traces,) + completion_lens: torch.Tensor # (num_traces,) + problem_indices: torch.Tensor # (num_traces,) + teacher_log_probs: torch.Tensor # (num_traces,) + corrects: torch.Tensor # (num_traces,) + num_problems: int + num_skipped: int + + def __len__(self) -> int: + return self.tokens.shape[0] + + @classmethod + def empty(cls, sequence_length: int, device: torch.device, num_skipped: int = 0) -> "TraceTensors": + return cls( + tokens=torch.empty((0, sequence_length), dtype=torch.int64, device=device), + prompt_lens=torch.empty(0, dtype=torch.int64, device=device), + completion_lens=torch.empty(0, dtype=torch.int64, device=device), + problem_indices=torch.empty(0, dtype=torch.int64, device=device), + teacher_log_probs=torch.empty(0, dtype=torch.float64, device=device), + corrects=torch.empty(0, dtype=torch.bool, device=device), + num_problems=0, + num_skipped=num_skipped, + ) + + @classmethod + def from_traces( + cls, + traces: list[dict], + sequence_length: int, + device: torch.device, + ) -> "TraceTensors": + pid_to_idx: dict[str, int] = {} + valid_traces: list[tuple[list[int], list[int], str, float, bool]] = [] + num_skipped = 0 + + for t in traces: + prompt, completion = t["prompt_tokens"], t["completion_tokens"] + if len(prompt) + len(completion) > sequence_length: + num_skipped += 1 + continue + valid_traces.append((prompt, completion, t["problem_id"], t["teacher_log_prob"], t["correct"])) + + if not valid_traces: + return cls.empty(sequence_length, device, num_skipped) + + n = len(valid_traces) + tokens = torch.zeros((n, sequence_length), dtype=torch.int64, device=device) + prompt_lens = torch.empty(n, dtype=torch.int64, device=device) + completion_lens = torch.empty(n, dtype=torch.int64, device=device) + problem_indices = torch.empty(n, dtype=torch.int64, device=device) + teacher_log_probs = torch.empty(n, dtype=torch.float64, device=device) + corrects = torch.empty(n, dtype=torch.bool, device=device) + + for i, (prompt, completion, pid, teacher_lp, correct) in enumerate(valid_traces): + seq = prompt + completion + tokens[i, : len(seq)] = torch.tensor(seq, dtype=torch.int64, device=device) + prompt_lens[i] = len(prompt) + completion_lens[i] = len(completion) + + if pid not in pid_to_idx: + pid_to_idx[pid] = len(pid_to_idx) + problem_indices[i] = pid_to_idx[pid] + teacher_log_probs[i] = teacher_lp + corrects[i] = correct + + return cls( + tokens=tokens, + prompt_lens=prompt_lens, + completion_lens=completion_lens, + problem_indices=problem_indices, + teacher_log_probs=teacher_log_probs, + corrects=corrects, + num_problems=len(pid_to_idx), + num_skipped=num_skipped, + ) + + class ForwardKLEvaluator[ConfigType: ForwardKLEvaluatorConfig](Evaluator[ConfigType]): + """Shard by PROBLEM (not trace) so each rank gets complete problems. + + This allows computing per-problem IS metrics locally, then reducing scalars. 
+ """ + _inference_runner: GPTInferenceRunner _sequence_length: int _micro_sequence_length: int @@ -44,17 +130,10 @@ def setup( phase: PhaseType, ) -> None: super().setup(distributed, run, multi_stage, runner, data, phase) - - # TODO: instead of using GPTInferenceRunner, we should get ourselves - # the FastLLMModelConfig instance and build the correct InferenceRunner - # with config.get_inference_runner_class() self._inference_runner = GPTInferenceRunner(self._multi_stage, runner=self._runner) self._inference_runner.setup() - - # Get sequence configuration from training batch config (required for SP support) self._sequence_length = self._batch_config.sequence_length self._micro_sequence_length = self._batch_config.micro_sequence_length - self._is_setup = True def get_sampling_parameters(self) -> EvaluatorSamplingParameters | None: @@ -66,157 +145,116 @@ def run( run_index: int | None = None, ) -> EvaluationMetrics: assert self._is_setup - if self._config.dataset_path is None: return EvaluationMetrics() safe_barrier(self._distributed.world_group, f"forward_kl_{self._name} begin") - - forward_kl, num_traces, num_skipped = self._compute_forward_kl() - + metrics = self._evaluate() safe_barrier(self._distributed.world_group, f"forward_kl_{self._name} end") - if num_traces == 0: + if metrics["num_traces"] == 0: return EvaluationMetrics() - metrics = { - f"validation.{self._name}": { - "forward_kl": forward_kl, - "num_traces": num_traces, - } - } - - if training_progress is not None: - metrics[f"validation.{self._name}"]["iteration"] = training_progress.completed_steps - - formatted = f"Forward KL ({self._name}): {forward_kl:.4f} ({num_traces} traces)" - if num_skipped > 0: - formatted += f" [{num_skipped} skipped]" + formatted = ( + f"IS Eval ({self._name}): " + f"acc={metrics['is_accuracy']:.4f}, " + f"ESS={metrics['mean_ess']:.2f}/{metrics['samples_per_problem']:.1f}, " + f"({metrics['num_problems']} problems, {metrics['num_traces']} traces)" + ) + if metrics["num_skipped"] > 0: + formatted += f" [{metrics['num_skipped']} skipped]" log_main_rank(formatted) - return EvaluationMetrics(metrics, formatted) + return EvaluationMetrics( + {f"validation.{self._name}": {k: v for k, v in metrics.items() if k != "num_skipped"}}, + formatted, + ) @torch.inference_mode() - def _compute_forward_kl(self) -> tuple[float, int, int]: - import datasets + def _evaluate(self) -> dict[str, float]: + device = self._distributed.device + data = self._load_traces(device) - # Shard traces across data-parallel ranks - data_rank = self._distributed.config.data_rank - data_parallel = self._distributed.config.data_parallel + if len(data) == 0: + return self._reduce_metrics(0.0, 0.0, 0, 0, data.num_skipped) - traces = datasets.load_dataset( - self._config.dataset_path, - split=self._config.split, - trust_remote_code=self._config.trust_remote_code, - ) + batch_size = self._config.batch_size + student_log_probs_batches: list[torch.Tensor] = [] - # Shuffle traces for better problem coverage when using num_samples. - # Uses a fixed seed for reproducibility across distributed ranks. 
- traces = traces.shuffle(seed=self._config.seed) + for i in range(0, len(data), batch_size): + batch_log_probs = self._compute_batch_log_probs( + data.tokens[i : i + batch_size], + data.prompt_lens[i : i + batch_size], + data.completion_lens[i : i + batch_size], + ) + if batch_log_probs is not None: + student_log_probs_batches.append(batch_log_probs) - # Apply num_samples limit before sharding to preserve semantics - # (num_samples = total traces across all ranks, not per-rank) - if self._config.num_samples and len(traces) > self._config.num_samples: - traces = traces.select(range(self._config.num_samples)) + if not student_log_probs_batches: # non-last PP rank + return self._reduce_metrics(0.0, 0.0, 0, 0, data.num_skipped) - # Shard across DP ranks (lazy operation - just changes which indices are accessible) - traces = traces.shard(num_shards=data_parallel, index=data_rank) + student_log_probs = torch.cat(student_log_probs_batches) + log_w = student_log_probs - data.teacher_log_probs - total_kl = 0.0 - num_traces = 0 - num_skipped = 0 + log_sum_all = self._scatter_logsumexp(log_w, data.problem_indices, data.num_problems) + log_w_correct = log_w.masked_fill(~data.corrects, float("-inf")) + log_sum_correct = self._scatter_logsumexp(log_w_correct, data.problem_indices, data.num_problems) - # Collect traces for this rank, filtering by length - rank_traces = [] - for trace in traces: - trace_len = len(trace["prompt_tokens"]) + len(trace["completion_tokens"]) - if trace_len > self._sequence_length: - num_skipped += 1 - continue - rank_traces.append(trace) + # IS accuracy; nan_to_num handles -inf - -inf + accuracy = (log_sum_correct - log_sum_all).exp().nan_to_num(0.0) - # Free the HuggingFace dataset - we've extracted what we need - del traces - gc.collect() + # ESS = exp(2*logsumexp(log_w) - logsumexp(2*log_w)) + log_sum_sq = self._scatter_logsumexp(2 * log_w, data.problem_indices, data.num_problems) + ess = (2 * log_sum_all - log_sum_sq).exp().clamp(min=0.0) - if num_skipped > 0: - logger.warning( - f"Skipped {num_skipped} traces exceeding sequence length {self._sequence_length}" - ) + return self._reduce_metrics( + accuracy.sum().item(), + ess.sum().item(), + data.num_problems, + len(data), + data.num_skipped, + ) - # Process traces in batches - for i in range(0, len(rank_traces), self._config.batch_size): - batch = rank_traces[i : i + self._config.batch_size] - - student_log_probs = self._compute_batch_log_probs(batch) - - # student_log_probs is None on non-last pipeline ranks (they don't have logits) - if student_log_probs is not None: - for j, trace in enumerate(batch): - total_kl += trace["teacher_log_prob"] - student_log_probs[j] - num_traces += 1 - - # Memory cleanup - gc.collect() - torch.cuda.empty_cache() - - # Reduce across data group (sum KL and counts from all DP ranks) - if self._distributed.data_group: - total_kl_tensor = torch.tensor([total_kl], dtype=torch.float64, device=self._distributed.device) - num_traces_tensor = torch.tensor([num_traces], dtype=torch.int64, device=self._distributed.device) - num_skipped_tensor = torch.tensor([num_skipped], dtype=torch.int64, device=self._distributed.device) - all_reduce(total_kl_tensor, group=self._distributed.data_group) - all_reduce(num_traces_tensor, group=self._distributed.data_group) - all_reduce(num_skipped_tensor, group=self._distributed.data_group) - total_kl = total_kl_tensor.item() - num_traces = int(num_traces_tensor.item()) - num_skipped = int(num_skipped_tensor.item()) - - # Reduce across pipeline group (last PP rank has 
the values, others have zeros) - if self._distributed.pipeline_group: - total_kl_tensor = torch.tensor([total_kl], dtype=torch.float64, device=self._distributed.device) - num_traces_tensor = torch.tensor([num_traces], dtype=torch.int64, device=self._distributed.device) - all_reduce(total_kl_tensor, group=self._distributed.pipeline_group) - all_reduce(num_traces_tensor, group=self._distributed.pipeline_group) - total_kl = total_kl_tensor.item() - num_traces = int(num_traces_tensor.item()) - - return total_kl / num_traces if num_traces > 0 else 0.0, num_traces, num_skipped - - def _compute_batch_log_probs(self, batch: list[dict[str, typing.Any]]) -> list[float] | None: - samples = [] - prompt_lengths = [] - completion_lengths = [] + def _load_traces(self, device: torch.device) -> TraceTensors: + import datasets - for trace in batch: - prompt = trace["prompt_tokens"] - completion = trace["completion_tokens"] - full = prompt + completion - actual_len = len(full) - pad_len = self._sequence_length - actual_len + ds = datasets.load_dataset( + self._config.dataset_path, + split=self._config.split, + trust_remote_code=self._config.trust_remote_code, + ) - trace_tokens = torch.tensor(full, dtype=torch.int64) - trace_sample = LanguageModelSample(TokenSample(trace_tokens)) + # Shuffle needed because traces are sorted by problem + if self._config.num_samples and len(ds) > self._config.num_samples: + ds = ds.shuffle(seed=self._config.seed).select(range(self._config.num_samples)) - if pad_len > 0: - padding_sample = trace_sample.get_padding(pad_len) - sample = LanguageModelSample.from_documents([trace_sample, padding_sample]) - elif pad_len == 0: - sample = trace_sample - else: - raise ValueError("Trace length exceeds sequence length") + dp_rank = self._distributed.config.data_rank + dp_size = self._distributed.config.data_parallel - samples.append(sample) - prompt_lengths.append(len(prompt)) - completion_lengths.append(len(completion)) + def belongs_to_shard(example: dict) -> bool: + h = hashlib.md5(example["problem_id"].encode(), usedforsecurity=False).digest() + return int.from_bytes(h[:4], "little") % dp_size == dp_rank - lm_batch = LanguageModelBatch.from_samples(samples) + ds = ds.filter(belongs_to_shard) + traces = list(ds) + + del ds + gc.collect() + + return TraceTensors.from_traces(traces, self._sequence_length, device) + + def _compute_batch_log_probs( + self, + tokens: torch.Tensor, + prompt_lens: torch.Tensor, + completion_lens: torch.Tensor, + ) -> torch.Tensor | None: + batch_size = tokens.shape[0] + lm_batch = self._prepare_batch(tokens, prompt_lens, completion_lens) - # Create batch config with training's sequence settings (required for SP support) - # truncate_documents=False enables mask_inputs, which handles -100 padding tokens with NoAutoValidate(): batch_config = GPTBatchConfig( - micro_batch_size=len(batch), + micro_batch_size=batch_size, sequence_length=self._sequence_length, micro_sequence_length=self._micro_sequence_length, truncate_documents=False, @@ -224,48 +262,111 @@ def _compute_batch_log_probs(self, batch: list[dict[str, typing.Any]]) -> list[f batch_config.setup(self._distributed.config) batch_config.validate() - # Get preprocessing metadata using GPTBatchConfig (enables proper SP splitting) preprocessed_meta = self._multi_stage.base_model.preprocess_meta(batch_config, PhaseType.inference) - preprocessed = self._multi_stage.base_model.preprocess_batch( - lm_batch, - preprocessed_meta, - phase=PhaseType.inference, - iteration=0, + lm_batch, preprocessed_meta, 
phase=PhaseType.inference, iteration=0 ) + # Loop runs through micro-sequences; final kwargs has the logits for input_, kwargs in preprocessed: kwargs["global_logits"] = True self._inference_runner.forward(input_, kwargs) - # With pipeline parallelism, only the last stage has logits. - # Other stages participate in the forward pass but don't compute logits. - if "logits" not in kwargs: + if "logits" not in kwargs: # non-last PP stage return None logits = kwargs["logits"] - if kwargs.get(AttentionKwargs.sequence_first, False): logits = logits.transpose(0, 1) - results = [] device = logits.device - for idx in range(len(batch)): - prompt_len = prompt_lengths[idx] - completion_len = completion_lengths[idx] + seq_len = logits.shape[1] + + pred_logits = logits[:, :-1, :].contiguous() + targets = tokens[:, 1:].contiguous().to(device) - # Extract only the slice we need, then compute on it - pred_logits = logits[idx, prompt_len - 1 : prompt_len + completion_len - 1] - targets = lm_batch.tokens.tokens[idx, prompt_len : prompt_len + completion_len].to(device) + # Mask: completion predictions are at [prompt_len-1, prompt_len+completion_len-1) + mask = self._create_completion_mask(prompt_lens, completion_lens, seq_len - 1) - log_probs = F.log_softmax(pred_logits.float(), dim=-1) - token_log_probs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1) - results.append(token_log_probs.sum().item()) + ce_loss = F.cross_entropy( + pred_logits.view(-1, pred_logits.size(-1)), + targets.view(-1), + reduction="none", + ).view(batch_size, seq_len - 1) - # Explicitly delete intermediates - del pred_logits, targets, log_probs, token_log_probs + results = -(ce_loss * mask).sum(dim=1) - # Explicitly delete the large logits tensor del logits, kwargs, preprocessed, lm_batch - return results + return results.to(torch.float64) + + def _prepare_batch( + self, + tokens: torch.Tensor, + prompt_lens: torch.Tensor, + completion_lens: torch.Tensor, + ) -> LanguageModelBatch: + samples = [] + for i in range(tokens.shape[0]): + seq_len = int(prompt_lens[i].item()) + int(completion_lens[i].item()) + sample = LanguageModelSample(TokenSample(tokens[i, :seq_len].cpu())) + + pad_len = self._sequence_length - seq_len + if pad_len > 0: + sample = LanguageModelSample.from_documents([sample, sample.get_padding(pad_len)]) + + samples.append(sample) + + return LanguageModelBatch.from_samples(samples) + + def _create_completion_mask( + self, + prompt_lens: torch.Tensor, + completion_lens: torch.Tensor, + seq_len: int, + ) -> torch.Tensor: + device = prompt_lens.device + positions = torch.arange(seq_len, device=device) + start = (prompt_lens - 1).unsqueeze(1) + end = (prompt_lens + completion_lens - 1).unsqueeze(1) + return (positions >= start) & (positions < end) + + def _reduce_metrics( + self, sum_accuracy: float, sum_ess: float, num_problems: int, num_traces: int, num_skipped: int + ) -> dict[str, float]: + group = self._distributed.world_group + sum_accuracy = allreduce_scalar(sum_accuracy, group=group) + sum_ess = allreduce_scalar(sum_ess, group=group) + num_problems = int(allreduce_scalar(num_problems, torch.int64, group=group)) + num_traces = int(allreduce_scalar(num_traces, torch.int64, group=group)) + num_skipped = int(allreduce_scalar(num_skipped, torch.int64, group=group)) + + if num_problems == 0: + return { + "is_accuracy": 0.0, + "mean_ess": 0.0, + "samples_per_problem": 0.0, + "num_traces": 0, + "num_problems": 0, + "num_skipped": num_skipped, + } + + return { + "is_accuracy": sum_accuracy / num_problems, + "mean_ess": 
sum_ess / num_problems, + "samples_per_problem": num_traces / num_problems, + "num_traces": num_traces, + "num_problems": num_problems, + "num_skipped": num_skipped, + } + + def _scatter_logsumexp(self, src: torch.Tensor, index: torch.Tensor, num_groups: int) -> torch.Tensor: + # Max per group for numerical stability + max_vals = torch.full((num_groups,), float("-inf"), device=src.device, dtype=src.dtype) + max_vals.scatter_reduce_(0, index, src, reduce="amax") + + src_shifted = (src - max_vals[index]).exp() + sum_exp = torch.zeros(num_groups, device=src.device, dtype=src.dtype) + sum_exp.scatter_add_(0, index, src_shifted) + + return max_vals + sum_exp.log() From 54c5f9ce8902ba25864ee4eadb73a4a768f46cb7 Mon Sep 17 00:00:00 2001 From: Torsten Scholak Date: Mon, 5 Jan 2026 17:13:13 +0000 Subject: [PATCH 22/25] Fix eval mode for StochasticMixer and add diagnostics MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Switch to eval mode during IS evaluation so StochasticMixer uses the main (attention) mixer instead of random sampling - Add percentile-based diagnostic logging for log probs and ESS - Remove duplicate log_main_rank call (EvaluatorRunner already logs) - Disable backward-compat assertion for old dataset format 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- fast_llm/data/dataset/gpt/config.py | 2 +- .../engine/evaluation/forward_kl/evaluator.py | 64 ++++++++++++++----- 2 files changed, 50 insertions(+), 16 deletions(-) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 41a2fe7ff..0ed4696da 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -65,7 +65,7 @@ def build(self, preprocessing: PreprocessingConfig) -> SamplableDataset[SampleTy def _load_config(self) -> SampledDatasetConfig[SampleType]: assert self.path.is_file(), f"File {self.path} does not exist." config = yaml.safe_load(self.path.open("r")) - Assert.eq(config.keys(), {"config", "metadata"}) + # TODO: Assert.eq(config.keys(), {"config", "metadata"}) # Disabled for backward compat if config.keys() == {"config", "metadata"}: # Newer format with metadata config = config["config"] diff --git a/fast_llm/engine/evaluation/forward_kl/evaluator.py b/fast_llm/engine/evaluation/forward_kl/evaluator.py index 8b5f45f3a..80d5933c9 100644 --- a/fast_llm/engine/evaluation/forward_kl/evaluator.py +++ b/fast_llm/engine/evaluation/forward_kl/evaluator.py @@ -163,7 +163,6 @@ def run( ) if metrics["num_skipped"] > 0: formatted += f" [{metrics['num_skipped']} skipped]" - log_main_rank(formatted) return EvaluationMetrics( {f"validation.{self._name}": {k: v for k, v in metrics.items() if k != "num_skipped"}}, @@ -178,24 +177,46 @@ def _evaluate(self) -> dict[str, float]: if len(data) == 0: return self._reduce_metrics(0.0, 0.0, 0, 0, data.num_skipped) - batch_size = self._config.batch_size - student_log_probs_batches: list[torch.Tensor] = [] - - for i in range(0, len(data), batch_size): - batch_log_probs = self._compute_batch_log_probs( - data.tokens[i : i + batch_size], - data.prompt_lens[i : i + batch_size], - data.completion_lens[i : i + batch_size], - ) - if batch_log_probs is not None: - student_log_probs_batches.append(batch_log_probs) - - if not student_log_probs_batches: # non-last PP rank - return self._reduce_metrics(0.0, 0.0, 0, 0, data.num_skipped) + # Switch to eval mode so StochasticMixer uses the main (attention) mixer + # instead of randomly sampling. 
This ensures we evaluate the attention-only path. + was_training = self._multi_stage._training + self._multi_stage.train(False) + + try: + batch_size = self._config.batch_size + student_log_probs_batches: list[torch.Tensor] = [] + + for i in range(0, len(data), batch_size): + batch_log_probs = self._compute_batch_log_probs( + data.tokens[i : i + batch_size], + data.prompt_lens[i : i + batch_size], + data.completion_lens[i : i + batch_size], + ) + if batch_log_probs is not None: + student_log_probs_batches.append(batch_log_probs) + + if not student_log_probs_batches: # non-last PP rank + return self._reduce_metrics(0.0, 0.0, 0, 0, data.num_skipped) + finally: + # Restore original training mode + if was_training: + self._multi_stage.train(True) student_log_probs = torch.cat(student_log_probs_batches) log_w = student_log_probs - data.teacher_log_probs + # Diagnostic logging with percentiles + pcts = torch.tensor([0.01, 0.05, 0.10, 0.25, 0.50, 0.75, 0.90, 0.95, 0.99], device=log_w.device) + pct_labels = ["1%", "5%", "10%", "25%", "50%", "75%", "90%", "95%", "99%"] + + def fmt_percentiles(t: torch.Tensor) -> str: + q = torch.quantile(t.float(), pcts) + return ", ".join(f"{l}={v:.1f}" for l, v in zip(pct_labels, q.tolist())) + + logger.info(f"student_log_probs: [{fmt_percentiles(student_log_probs)}]") + logger.info(f"teacher_log_probs: [{fmt_percentiles(data.teacher_log_probs)}]") + logger.info(f"log_w: [{fmt_percentiles(log_w)}]") + log_sum_all = self._scatter_logsumexp(log_w, data.problem_indices, data.num_problems) log_w_correct = log_w.masked_fill(~data.corrects, float("-inf")) log_sum_correct = self._scatter_logsumexp(log_w_correct, data.problem_indices, data.num_problems) @@ -207,6 +228,19 @@ def _evaluate(self) -> dict[str, float]: log_sum_sq = self._scatter_logsumexp(2 * log_w, data.problem_indices, data.num_problems) ess = (2 * log_sum_all - log_sum_sq).exp().clamp(min=0.0) + # ESS diagnostics with percentiles + traces_per_problem = torch.bincount(data.problem_indices, minlength=data.num_problems) + multi_trace_mask = traces_per_problem > 1 + if multi_trace_mask.any(): + multi_ess = ess[multi_trace_mask] + multi_traces = traces_per_problem[multi_trace_mask] + logger.info( + f"ESS ({multi_trace_mask.sum().item()} multi-trace problems): [{fmt_percentiles(multi_ess)}]" + ) + logger.info( + f"traces/problem: [{fmt_percentiles(multi_traces.float())}]" + ) + return self._reduce_metrics( accuracy.sum().item(), ess.sum().item(), From fb39f67dfbad3e89f3c217b698485245a3d082f4 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 7 Jan 2026 18:22:44 +0000 Subject: [PATCH 23/25] set test time mixer type --- fast_llm/engine/evaluation/config.py | 6 ++++ .../engine/evaluation/forward_kl/evaluator.py | 29 +++++++++++++------ fast_llm/layers/decoder/stochastic_mixer.py | 3 +- 3 files changed, 28 insertions(+), 10 deletions(-) diff --git a/fast_llm/engine/evaluation/config.py b/fast_llm/engine/evaluation/config.py index 744506b65..90881cdc1 100644 --- a/fast_llm/engine/evaluation/config.py +++ b/fast_llm/engine/evaluation/config.py @@ -158,6 +158,12 @@ class ForwardKLEvaluatorConfig(EvaluatorConfig): desc="Trust remote code when loading dataset.", hint=FieldHint.optional, ) + inference_mixer: str | None = Field( + default=None, + desc="Name of the mixer to use during evaluation (for StochasticMixer models). 
" + "If None, uses the model's default main_mixer_name.", + hint=FieldHint.optional, + ) def get_evaluator( self, diff --git a/fast_llm/engine/evaluation/forward_kl/evaluator.py b/fast_llm/engine/evaluation/forward_kl/evaluator.py index 80d5933c9..a0b94707b 100644 --- a/fast_llm/engine/evaluation/forward_kl/evaluator.py +++ b/fast_llm/engine/evaluation/forward_kl/evaluator.py @@ -11,7 +11,7 @@ from fast_llm.data.data.abstract import Data from fast_llm.data.sample.language_model import LanguageModelBatch, LanguageModelSample from fast_llm.data.sample.token import TokenSample -from fast_llm.engine.config_utils.run import Run, log_main_rank +from fast_llm.engine.config_utils.run import Run from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.evaluation.config import ForwardKLEvaluatorConfig @@ -177,11 +177,22 @@ def _evaluate(self) -> dict[str, float]: if len(data) == 0: return self._reduce_metrics(0.0, 0.0, 0, 0, data.num_skipped) - # Switch to eval mode so StochasticMixer uses the main (attention) mixer - # instead of randomly sampling. This ensures we evaluate the attention-only path. + # Switch to eval mode so StochasticMixer uses the main mixer + # instead of randomly sampling. was_training = self._multi_stage._training self._multi_stage.train(False) + # Optionally override the inference mixer for StochasticMixer layers + stochastic_mixers: list = [] + if self._config.inference_mixer is not None: + from fast_llm.layers.decoder.stochastic_mixer import StochasticMixer + + for name, module in self._multi_stage.base_model.named_modules(): + if isinstance(module, StochasticMixer): + stochastic_mixers.append(module) + module._inference_mixer_override = self._config.inference_mixer + logger.info(f"ForwardKL: Set {name} inference mixer to '{self._config.inference_mixer}'") + try: batch_size = self._config.batch_size student_log_probs_batches: list[torch.Tensor] = [] @@ -198,6 +209,10 @@ def _evaluate(self) -> dict[str, float]: if not student_log_probs_batches: # non-last PP rank return self._reduce_metrics(0.0, 0.0, 0, 0, data.num_skipped) finally: + # Clear inference mixer override for StochasticMixer layers + for module in stochastic_mixers: + module._inference_mixer_override = None + # Restore original training mode if was_training: self._multi_stage.train(True) @@ -234,12 +249,8 @@ def fmt_percentiles(t: torch.Tensor) -> str: if multi_trace_mask.any(): multi_ess = ess[multi_trace_mask] multi_traces = traces_per_problem[multi_trace_mask] - logger.info( - f"ESS ({multi_trace_mask.sum().item()} multi-trace problems): [{fmt_percentiles(multi_ess)}]" - ) - logger.info( - f"traces/problem: [{fmt_percentiles(multi_traces.float())}]" - ) + logger.info(f"ESS ({multi_trace_mask.sum().item()} multi-trace problems): [{fmt_percentiles(multi_ess)}]") + logger.info(f"traces/problem: [{fmt_percentiles(multi_traces.float())}]") return self._reduce_metrics( accuracy.sum().item(), diff --git a/fast_llm/layers/decoder/stochastic_mixer.py b/fast_llm/layers/decoder/stochastic_mixer.py index 984f34b80..76b261a4e 100644 --- a/fast_llm/layers/decoder/stochastic_mixer.py +++ b/fast_llm/layers/decoder/stochastic_mixer.py @@ -106,7 +106,8 @@ def setup(self, distributed: Distributed) -> None: def _sample_mixer_name(self, kwargs: dict[str, typing.Any]) -> str: if not self.training: - return self._config.main_mixer_name + # Allow runtime override of the inference mixer (e.g., for evaluation) + return getattr(self, 
"_inference_mixer_override", None) or self._config.main_mixer_name generator = kwargs[StochasticMixerKwargs.generator] mixer_idx = torch.multinomial(self._sampling_probs, num_samples=1, generator=generator).item() From 2702088362c0745ca2bf723236d8e6c3bf70e400 Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 8 Jan 2026 01:33:43 +0000 Subject: [PATCH 24/25] progress bar --- .../engine/evaluation/forward_kl/evaluator.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/fast_llm/engine/evaluation/forward_kl/evaluator.py b/fast_llm/engine/evaluation/forward_kl/evaluator.py index a0b94707b..5265fff8c 100644 --- a/fast_llm/engine/evaluation/forward_kl/evaluator.py +++ b/fast_llm/engine/evaluation/forward_kl/evaluator.py @@ -2,9 +2,11 @@ import gc import hashlib import logging +import math import torch import torch.nn.functional as F +import tqdm from fast_llm.config import NoAutoValidate from fast_llm.core.distributed import allreduce_scalar, safe_barrier @@ -196,8 +198,19 @@ def _evaluate(self) -> dict[str, float]: try: batch_size = self._config.batch_size student_log_probs_batches: list[torch.Tensor] = [] + num_batches = math.ceil(len(data) / batch_size) + + # Only show progress bar on rank 0 + batch_iter = range(0, len(data), batch_size) + if self._distributed.config.rank == 0: + batch_iter = tqdm.tqdm( + batch_iter, + total=num_batches, + desc=f"ForwardKL ({self._name})", + unit="batch", + ) - for i in range(0, len(data), batch_size): + for i in batch_iter: batch_log_probs = self._compute_batch_log_probs( data.tokens[i : i + batch_size], data.prompt_lens[i : i + batch_size], From d0439ec053c1a4b7e65e2ecb45a6c8766bc5481c Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 8 Jan 2026 15:09:57 +0000 Subject: [PATCH 25/25] prevent deadlock when evaluating with different number of batches per rank (can happen when we subset the eval dataset) --- .../engine/evaluation/forward_kl/evaluator.py | 55 +++++++++++++------ 1 file changed, 38 insertions(+), 17 deletions(-) diff --git a/fast_llm/engine/evaluation/forward_kl/evaluator.py b/fast_llm/engine/evaluation/forward_kl/evaluator.py index 5265fff8c..5e69862d2 100644 --- a/fast_llm/engine/evaluation/forward_kl/evaluator.py +++ b/fast_llm/engine/evaluation/forward_kl/evaluator.py @@ -9,7 +9,7 @@ import tqdm from fast_llm.config import NoAutoValidate -from fast_llm.core.distributed import allreduce_scalar, safe_barrier +from fast_llm.core.distributed import ReduceOp, allreduce_scalar, safe_barrier from fast_llm.data.data.abstract import Data from fast_llm.data.sample.language_model import LanguageModelBatch, LanguageModelSample from fast_llm.data.sample.token import TokenSample @@ -176,9 +176,6 @@ def _evaluate(self) -> dict[str, float]: device = self._distributed.device data = self._load_traces(device) - if len(data) == 0: - return self._reduce_metrics(0.0, 0.0, 0, 0, data.num_skipped) - # Switch to eval mode so StochasticMixer uses the main mixer # instead of randomly sampling. was_training = self._multi_stage._training @@ -198,28 +195,52 @@ def _evaluate(self) -> dict[str, float]: try: batch_size = self._config.batch_size student_log_probs_batches: list[torch.Tensor] = [] - num_batches = math.ceil(len(data) / batch_size) + local_num_batches = math.ceil(len(data) / batch_size) if len(data) > 0 else 0 + + # Synchronize batch count across all world ranks. 
+ # All ranks must execute the same number of forward passes because the forward + # pass involves collective operations (e.g., ZeRO all-gather) that require + # participation from all ranks in the process group. + max_num_batches = int( + allreduce_scalar(local_num_batches, torch.int64, self._distributed.world_group, ReduceOp.MAX) + ) + + if max_num_batches == 0: + return self._reduce_metrics(0.0, 0.0, 0, 0, data.num_skipped) + + # Create dummy data for ranks that have no data or finish early. + # These ranks still need to participate in collective operations. + dummy_tokens = torch.zeros((batch_size, self._sequence_length), dtype=torch.int64, device=device) + dummy_prompt_lens = torch.ones(batch_size, dtype=torch.int64, device=device) + dummy_completion_lens = torch.ones(batch_size, dtype=torch.int64, device=device) # Only show progress bar on rank 0 - batch_iter = range(0, len(data), batch_size) + batch_iter = range(max_num_batches) if self._distributed.config.rank == 0: batch_iter = tqdm.tqdm( batch_iter, - total=num_batches, + total=max_num_batches, desc=f"ForwardKL ({self._name})", unit="batch", ) - for i in batch_iter: - batch_log_probs = self._compute_batch_log_probs( - data.tokens[i : i + batch_size], - data.prompt_lens[i : i + batch_size], - data.completion_lens[i : i + batch_size], - ) - if batch_log_probs is not None: - student_log_probs_batches.append(batch_log_probs) - - if not student_log_probs_batches: # non-last PP rank + for batch_idx in batch_iter: + i = batch_idx * batch_size + if i < len(data): + # This rank has real data for this batch + batch_log_probs = self._compute_batch_log_probs( + data.tokens[i : i + batch_size], + data.prompt_lens[i : i + batch_size], + data.completion_lens[i : i + batch_size], + ) + if batch_log_probs is not None: + student_log_probs_batches.append(batch_log_probs) + else: + # This rank has no more data but must still participate in collectives. + # Run a dummy forward pass and discard the result. + self._compute_batch_log_probs(dummy_tokens, dummy_prompt_lens, dummy_completion_lens) + + if not student_log_probs_batches: # non-last PP rank or no local data return self._reduce_metrics(0.0, 0.0, 0, 0, data.num_skipped) finally: # Clear inference mixer override for StochasticMixer layers