From ed6b793b19fe37ade07a55f70f801cf09183036b Mon Sep 17 00:00:00 2001 From: Torsten Scholak Date: Sun, 14 Dec 2025 02:06:30 +0000 Subject: [PATCH 01/12] Refactor Apriel2 cache and add Qwen2 converter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Cache improvements: - Add methods to _AttentionCache and _SSMCache (reset, reorder, crop, batch_repeat, batch_select, is_initialized, batch_size) - Add _iter_caches() helper to flatten stochastic layer dicts - Simplify Apriel2Cache methods using new abstractions - Fix sliding window attention mask sizes (cumulative_length tracking) - Localize KDA tuple handling in _SSMCache Test improvements: - Split tests into contract tests (vs HuggingFace) and Apriel2-specific - Add shared fixtures to conftest.py - Add edge case tests for SSM tuple operations - Remove duplicated fixture definitions Qwen2 converter: - Add Qwen2/Qwen2.5 to Apriel2 config conversion - Add weight mapping plan for Qwen2 models πŸ€– Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- fast_llm_external_models/apriel2/cache.py | 311 ++-- .../apriel2/conversion/qwen2/__init__.py | 6 + .../apriel2/conversion/qwen2/config.py | 81 ++ .../apriel2/conversion/qwen2/plan.py | 113 ++ fast_llm_external_models/apriel2/convert.py | 9 +- .../tests/test_apriel2/conftest.py | 187 +++ .../tests/test_apriel2/test_cache.py | 1258 ----------------- .../test_cache_apriel2_specific.py | 342 +++++ .../test_apriel2/test_cache_contracts.py | 592 ++++++++ 9 files changed, 1499 insertions(+), 1400 deletions(-) create mode 100644 fast_llm_external_models/apriel2/conversion/qwen2/__init__.py create mode 100644 fast_llm_external_models/apriel2/conversion/qwen2/config.py create mode 100644 fast_llm_external_models/apriel2/conversion/qwen2/plan.py delete mode 100644 fast_llm_external_models/tests/test_apriel2/test_cache.py create mode 100644 fast_llm_external_models/tests/test_apriel2/test_cache_apriel2_specific.py create mode 100644 fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py diff --git a/fast_llm_external_models/apriel2/cache.py b/fast_llm_external_models/apriel2/cache.py index 86c67a085..32db547b9 100644 --- a/fast_llm_external_models/apriel2/cache.py +++ b/fast_llm_external_models/apriel2/cache.py @@ -4,14 +4,18 @@ class _AttentionCache: - __slots__ = ["key", "value", "window"] + __slots__ = ["key", "value", "window", "cumulative_length"] def __init__(self, window=None): self.key = None self.value = None self.window = window + self.cumulative_length = 0 def update(self, key, value): + new_tokens = key.shape[-2] + self.cumulative_length += new_tokens + if self.key is None: if self.window and key.shape[-2] > self.window: self.key = key[..., -self.window :, :].contiguous() @@ -35,6 +39,40 @@ def _window(self, cache, new): return cache return torch.cat([cache, new], -2)[..., -self.window :, :].contiguous() + def reset(self): + self.key = None + self.value = None + self.cumulative_length = 0 + + def reorder(self, beam_idx): + if self.key is not None: + self.key = self.key.index_select(0, beam_idx.to(self.key.device)) + self.value = self.value.index_select(0, beam_idx.to(self.value.device)) + + def crop(self, max_length): + if self.key is not None: + self.key = self.key[..., :max_length, :] + self.value = self.value[..., :max_length, :] + self.cumulative_length = self.key.shape[-2] + + def batch_repeat(self, repeats): + if self.key is not None: + self.key = self.key.repeat_interleave(repeats, dim=0) + 
self.value = self.value.repeat_interleave(repeats, dim=0) + + def batch_select(self, indices): + if self.key is not None: + self.key = self.key.index_select(0, indices.to(self.key.device)) + self.value = self.value.index_select(0, indices.to(self.value.device)) + + @property + def is_initialized(self): + return self.key is not None + + @property + def batch_size(self): + return self.key.shape[0] if self.key is not None else None + class _SSMCache: __slots__ = ["conv", "recurrent"] @@ -43,6 +81,52 @@ def __init__(self): self.conv = None self.recurrent = None + def reset(self): + self.conv = None + self.recurrent = None + + def reorder(self, beam_idx): + if self.conv is not None: + if isinstance(self.conv, tuple): + self.conv = tuple(c.index_select(0, beam_idx.to(c.device)) for c in self.conv) + else: + self.conv = self.conv.index_select(0, beam_idx.to(self.conv.device)) + if self.recurrent is not None: + self.recurrent = self.recurrent.index_select(0, beam_idx.to(self.recurrent.device)) + + def crop(self, max_length): + pass # SSM caches don't have sequence dimension to crop + + def batch_repeat(self, repeats): + if self.conv is not None: + if isinstance(self.conv, tuple): + self.conv = tuple(c.repeat_interleave(repeats, dim=0) for c in self.conv) + else: + self.conv = self.conv.repeat_interleave(repeats, dim=0) + if self.recurrent is not None: + self.recurrent = self.recurrent.repeat_interleave(repeats, dim=0) + + def batch_select(self, indices): + if self.conv is not None: + if isinstance(self.conv, tuple): + self.conv = tuple(c.index_select(0, indices.to(c.device)) for c in self.conv) + else: + self.conv = self.conv.index_select(0, indices.to(self.conv.device)) + if self.recurrent is not None: + self.recurrent = self.recurrent.index_select(0, indices.to(self.recurrent.device)) + + @property + def is_initialized(self): + return self.conv is not None + + @property + def batch_size(self): + if self.conv is None: + return None + if isinstance(self.conv, tuple): + return self.conv[0].shape[0] + return self.conv.shape[0] + class _DummyCacheLayer: pass @@ -93,14 +177,19 @@ def set_active_mixer(self, layer_idx, mixer_name): self.active_mixers[layer_idx] = mixer_name def get_seq_length(self, layer_idx=0): + """Returns the cumulative sequence length of tokens seen by the cache. + + For sliding window caches, this returns the total tokens seen (not just cached). + This matches HuggingFace's DynamicSlidingWindowLayer behavior. + """ layer = self.layers[layer_idx] if isinstance(layer, dict): mixer = self.active_mixers[layer_idx] if mixer and isinstance(layer[mixer], _AttentionCache): - return layer[mixer].key.shape[-2] if layer[mixer].key is not None else 0 + return layer[mixer].cumulative_length return 0 if isinstance(layer, _AttentionCache): - return layer.key.shape[-2] if layer.key is not None else 0 + return layer.cumulative_length return 0 def get_max_cache_shape(self, layer_idx=0): @@ -114,22 +203,61 @@ def get_max_cache_shape(self, layer_idx=0): return None def get_mask_sizes(self, cache_position, layer_idx): + """Return the length and offset of the cache, used to generate the attention mask. 
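+
+    For example, with a sliding window of 8: after 10 tokens have been seen, a
+    1-token decode step gets kv_offset = max(10 - 8 + 1, 0) = 3 and
+    kv_length = (8 - 1) + 1 = 8 (see the sliding window case below).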
+ + For standard (non-sliding) attention: + kv_offset = 0 (KV[0] corresponds to sequence position 0) + kv_length = cumulative_length + query_length + + For sliding window attention: + kv_offset = max(cumulative_length - window + 1, 0) + kv_length = min(cumulative_length, window - 1) + query_length + + For SSM/linear layers: + kv_offset = 0, kv_length = query_length (no KV cache to attend to) + """ query_length = cache_position.shape[0] - past_seen_tokens = self.get_seq_length(layer_idx) - kv_length = query_length + past_seen_tokens - kv_offset = past_seen_tokens - return kv_length, kv_offset + layer = self.layers[layer_idx] + + # Handle stochastic layers by getting the active mixer's cache + if isinstance(layer, dict): + mixer = self.active_mixers[layer_idx] + if mixer is None: + # No active mixer set, return defaults + return query_length, 0 + cache = layer[mixer] + else: + cache = layer + + # SSM layers don't have KV cache for attention mask purposes + if isinstance(cache, _SSMCache): + return query_length, 0 + + # Attention cache - check if sliding window + if isinstance(cache, _AttentionCache): + cumulative = cache.cumulative_length + window = cache.window + + if window is not None: + # Sliding window attention + kv_offset = max(cumulative - window + 1, 0) + if cumulative >= window: + kv_length = window - 1 + query_length + else: + kv_length = cumulative + query_length + else: + # Full attention + kv_offset = 0 + kv_length = cumulative + query_length + + return kv_length, kv_offset + + # Fallback + return query_length, 0 @property def has_previous_state(self): - for layer in self.layers: - if isinstance(layer, dict): - for cache in layer.values(): - if isinstance(cache, _SSMCache) and cache.conv is not None: - return True - elif isinstance(layer, _SSMCache) and layer.conv is not None: - return True - return False + return any(isinstance(cache, _SSMCache) and cache.conv is not None for cache in self._iter_caches()) @property def key_cache(self): @@ -147,101 +275,33 @@ def conv_states(self): def recurrent_states(self): return _LayerListAccessor(self, "recurrent") - def reorder_cache(self, beam_idx): - for i, layer in enumerate(self.layers): + def _iter_caches(self): + """Iterate over all leaf cache objects (flattening stochastic layer dicts).""" + for layer in self.layers: if isinstance(layer, dict): - for cache in layer.values(): - self._reorder_cache_obj(cache, beam_idx) + yield from layer.values() else: - self._reorder_cache_obj(layer, beam_idx) + yield layer - def _reorder_cache_obj(self, cache, beam_idx): - if isinstance(cache, _AttentionCache): - if cache.key is not None: - cache.key = cache.key.index_select(0, beam_idx.to(cache.key.device)) - cache.value = cache.value.index_select(0, beam_idx.to(cache.value.device)) - elif isinstance(cache, _SSMCache): - if cache.conv is not None: - # Handle both single tensor (GDN/Mamba) and tuple (KDA) conv states - if isinstance(cache.conv, tuple): - cache.conv = tuple(c.index_select(0, beam_idx.to(c.device)) for c in cache.conv) - else: - cache.conv = cache.conv.index_select(0, beam_idx.to(cache.conv.device)) - if cache.recurrent is not None: - cache.recurrent = cache.recurrent.index_select(0, beam_idx.to(cache.recurrent.device)) + def reorder_cache(self, beam_idx): + for cache in self._iter_caches(): + cache.reorder(beam_idx) def reset(self): - for layer in self.layers: - if isinstance(layer, dict): - for cache in layer.values(): - self._reset_cache_obj(cache) - else: - self._reset_cache_obj(layer) - - def _reset_cache_obj(self, cache): - if 
isinstance(cache, _AttentionCache): - cache.key = None - cache.value = None - elif isinstance(cache, _SSMCache): - cache.conv = None - cache.recurrent = None + for cache in self._iter_caches(): + cache.reset() def crop(self, max_length): - for layer in self.layers: - if isinstance(layer, dict): - for cache in layer.values(): - if isinstance(cache, _AttentionCache) and cache.key is not None: - cache.key = cache.key[..., :max_length, :] - cache.value = cache.value[..., :max_length, :] - elif isinstance(layer, _AttentionCache) and layer.key is not None: - layer.key = layer.key[..., :max_length, :] - layer.value = layer.value[..., :max_length, :] + for cache in self._iter_caches(): + cache.crop(max_length) def batch_repeat_interleave(self, repeats): - for layer in self.layers: - if isinstance(layer, dict): - for cache in layer.values(): - self._batch_repeat_cache_obj(cache, repeats) - else: - self._batch_repeat_cache_obj(layer, repeats) - - def _batch_repeat_cache_obj(self, cache, repeats): - if isinstance(cache, _AttentionCache): - if cache.key is not None: - cache.key = cache.key.repeat_interleave(repeats, dim=0) - cache.value = cache.value.repeat_interleave(repeats, dim=0) - elif isinstance(cache, _SSMCache): - if cache.conv is not None: - # Handle both single tensor (GDN/Mamba) and tuple (KDA) conv states - if isinstance(cache.conv, tuple): - cache.conv = tuple(c.repeat_interleave(repeats, dim=0) for c in cache.conv) - else: - cache.conv = cache.conv.repeat_interleave(repeats, dim=0) - if cache.recurrent is not None: - cache.recurrent = cache.recurrent.repeat_interleave(repeats, dim=0) + for cache in self._iter_caches(): + cache.batch_repeat(repeats) def batch_select_indices(self, indices): - for layer in self.layers: - if isinstance(layer, dict): - for cache in layer.values(): - self._batch_select_cache_obj(cache, indices) - else: - self._batch_select_cache_obj(layer, indices) - - def _batch_select_cache_obj(self, cache, indices): - if isinstance(cache, _AttentionCache): - if cache.key is not None: - cache.key = cache.key.index_select(0, indices.to(cache.key.device)) - cache.value = cache.value.index_select(0, indices.to(cache.value.device)) - elif isinstance(cache, _SSMCache): - if cache.conv is not None: - # Handle both single tensor (GDN/Mamba) and tuple (KDA) conv states - if isinstance(cache.conv, tuple): - cache.conv = tuple(c.index_select(0, indices.to(c.device)) for c in cache.conv) - else: - cache.conv = cache.conv.index_select(0, indices.to(cache.conv.device)) - if cache.recurrent is not None: - cache.recurrent = cache.recurrent.index_select(0, indices.to(cache.recurrent.device)) + for cache in self._iter_caches(): + cache.batch_select(indices) @property def is_compileable(self): @@ -249,19 +309,7 @@ def is_compileable(self): @property def is_initialized(self): - for layer in self.layers: - if isinstance(layer, dict): - for cache in layer.values(): - if isinstance(cache, _AttentionCache) and cache.key is not None: - return True - if isinstance(cache, _SSMCache) and cache.conv is not None: - return True - else: - if isinstance(layer, _AttentionCache) and layer.key is not None: - return True - if isinstance(layer, _SSMCache) and layer.conv is not None: - return True - return False + return any(cache.is_initialized for cache in self._iter_caches()) @property def is_sliding(self): @@ -280,39 +328,20 @@ def is_sliding(self): @property def max_batch_size(self): - for layer in self.layers: - if isinstance(layer, dict): - for cache in layer.values(): - if isinstance(cache, 
_AttentionCache) and cache.key is not None: - return cache.key.shape[0] - if isinstance(cache, _SSMCache) and cache.conv is not None: - # Handle both single tensor and tuple conv states - if isinstance(cache.conv, tuple): - return cache.conv[0].shape[0] - return cache.conv.shape[0] - else: - if isinstance(layer, _AttentionCache) and layer.key is not None: - return layer.key.shape[0] - if isinstance(layer, _SSMCache) and layer.conv is not None: - # Handle both single tensor and tuple conv states - if isinstance(layer.conv, tuple): - return layer.conv[0].shape[0] - return layer.conv.shape[0] + for cache in self._iter_caches(): + bs = cache.batch_size + if bs is not None: + return bs return None @property def max_cache_len(self): - max_len = None - for layer in self.layers: - if isinstance(layer, dict): - for cache in layer.values(): - if isinstance(cache, _AttentionCache): - if cache.window is not None: - max_len = cache.window if max_len is None else min(max_len, cache.window) - elif isinstance(layer, _AttentionCache): - if layer.window is not None: - max_len = layer.window if max_len is None else min(max_len, layer.window) - return max_len + windows = [ + cache.window + for cache in self._iter_caches() + if isinstance(cache, _AttentionCache) and cache.window is not None + ] + return min(windows) if windows else None def __len__(self): return len(self.layers) diff --git a/fast_llm_external_models/apriel2/conversion/qwen2/__init__.py b/fast_llm_external_models/apriel2/conversion/qwen2/__init__.py new file mode 100644 index 000000000..d0a0b8e6e --- /dev/null +++ b/fast_llm_external_models/apriel2/conversion/qwen2/__init__.py @@ -0,0 +1,6 @@ +"""Qwen2/Qwen2.5 to Apriel2 conversion module.""" + +from fast_llm_external_models.apriel2.conversion.qwen2.config import convert_config +from fast_llm_external_models.apriel2.conversion.qwen2.plan import plan_qwen2_to_apriel2 + +__all__ = ["convert_config", "plan_qwen2_to_apriel2"] diff --git a/fast_llm_external_models/apriel2/conversion/qwen2/config.py b/fast_llm_external_models/apriel2/conversion/qwen2/config.py new file mode 100644 index 000000000..36df744c0 --- /dev/null +++ b/fast_llm_external_models/apriel2/conversion/qwen2/config.py @@ -0,0 +1,81 @@ +"""Qwen2/Qwen2.5 to Apriel2 config conversion.""" + + +def convert_config(qwen2_config: dict) -> dict: + """Convert Qwen2/Qwen2.5 config to Apriel2TextConfig format. 
+ + Qwen2.5 architecture: + - Standard transformer with GQA (grouped query attention) + - QKV bias enabled, O bias disabled + - MLP bias disabled + - Gated SwiGLU MLP + - RMSNorm + - RoPE embeddings + + Args: + qwen2_config: HuggingFace Qwen2Config as dict + + Returns: + Apriel2TextConfig-compatible dict + """ + hidden_size = qwen2_config["hidden_size"] + num_attention_heads = qwen2_config["num_attention_heads"] + num_key_value_heads = qwen2_config.get("num_key_value_heads", num_attention_heads) + head_dim = hidden_size // num_attention_heads + + # Qwen2 uses QKV bias but not O bias + # The add_linear_biases in Apriel2 attention config controls all biases uniformly, + # but we can set it to True and the o_proj bias will just be missing from weights + # (handled by strict=False loading or explicit handling in the plan) + + return { + "model_type": "apriel2_text", + "architectures": ["Apriel2ForCausalLM"], + "auto_map": { + "AutoConfig": "configuration_apriel2.Apriel2TextConfig", + "AutoModel": "modeling_apriel2.Apriel2TextModel", + "AutoModelForCausalLM": "modeling_apriel2.Apriel2ForCausalLM", + }, + "hidden_size": hidden_size, + "vocab_size": qwen2_config["vocab_size"], + "tie_word_embeddings": qwen2_config.get("tie_word_embeddings", False), + "decoder": { + "type": "fixed", + "num_blocks": qwen2_config["num_hidden_layers"], + "block": { + "mixer": { + "type": "attention", + "heads": num_attention_heads, + "head_groups": num_key_value_heads, + "head_size": head_dim, + # Qwen2 has QKV bias but not O bias + # We set True and handle O bias separately + "add_linear_biases": True, + "rotary": { + "type": "mistral_1d", + "theta": qwen2_config.get("rope_theta", 1000000.0), + }, + }, + "mlp": { + "type": "mlp", + "intermediate_size": qwen2_config["intermediate_size"], + "activation": qwen2_config.get("hidden_act", "silu"), + "gated": True, + "add_linear_biases": False, + }, + "normalization": { + "type": "rms_norm", + "epsilon": qwen2_config.get("rms_norm_eps", 1e-6), + }, + }, + }, + "head": { + "normalization": { + "type": "rms_norm", + "epsilon": qwen2_config.get("rms_norm_eps", 1e-6), + } + }, + "embeddings": { + "max_position_embeddings": qwen2_config.get("max_position_embeddings", 32768), + }, + } diff --git a/fast_llm_external_models/apriel2/conversion/qwen2/plan.py b/fast_llm_external_models/apriel2/conversion/qwen2/plan.py new file mode 100644 index 000000000..e5ae3e9d8 --- /dev/null +++ b/fast_llm_external_models/apriel2/conversion/qwen2/plan.py @@ -0,0 +1,113 @@ +"""Qwen2/Qwen2.5 to Apriel2 weight conversion plan.""" + +from fast_llm_external_models.apriel2.conversion.expr import ( + Expr, + ExprPlan, + Init, + Ref, + W, +) + + +def plan_qwen2_to_apriel2(qwen2_config: dict) -> ExprPlan: + """Build an expression plan for Qwen2/Qwen2.5 to Apriel2 conversion. + + This is a pure mapping (all Ref expressions) since Qwen2β†’Apriel2 + is just renaming keys. The weight tensors are identical. + + Key mapping (source keys have "model." 
prefix in safetensors):
+        Qwen2 (safetensor key)                                Apriel2
+        ----------------------                                -------
+        model.embed_tokens.weight                          -> model.embed_tokens.weight
+        model.norm.weight                                  -> model.norm.weight
+        model.layers.{i}.input_layernorm.weight            -> model.decoder.blocks.{i}.input_layernorm.weight
+        model.layers.{i}.post_attention_layernorm.weight   -> model.decoder.blocks.{i}.post_attention_layernorm.weight
+        model.layers.{i}.self_attn.q_proj.weight           -> model.decoder.blocks.{i}.mixer.q_proj.weight
+        model.layers.{i}.self_attn.k_proj.weight           -> model.decoder.blocks.{i}.mixer.k_proj.weight
+        model.layers.{i}.self_attn.v_proj.weight           -> model.decoder.blocks.{i}.mixer.v_proj.weight
+        model.layers.{i}.self_attn.o_proj.weight           -> model.decoder.blocks.{i}.mixer.o_proj.weight
+        model.layers.{i}.mlp.gate_proj.weight              -> model.decoder.blocks.{i}.mlp.gate_proj.weight
+        model.layers.{i}.mlp.up_proj.weight                -> model.decoder.blocks.{i}.mlp.up_proj.weight
+        model.layers.{i}.mlp.down_proj.weight              -> model.decoder.blocks.{i}.mlp.down_proj.weight
+
+    Note: Qwen2 has QKV biases but no O bias. The QKV biases are mapped through
+    as-is, and the missing o_proj bias is zero-initialized, since Apriel2 is
+    configured with add_linear_biases=True and expects a bias on every projection.
+
+    Args:
+        qwen2_config: HuggingFace Qwen2Config as dict
+
+    Returns:
+        ExprPlan with Ref mappings
+    """
+    mappings: dict[str, Expr] = {}
+
+    num_layers = qwen2_config["num_hidden_layers"]
+    hidden_size = qwen2_config["hidden_size"]
+
+    # Static mappings (embeddings and final norm)
+    # Note: Qwen2 safetensor keys have "model." prefix
+    static_mappings = [
+        (W("model", "embed_tokens", "weight"), W("model", "embed_tokens", "weight")),
+        (W("model", "norm", "weight"), W("model", "norm", "weight")),
+    ]
+
+    # lm_head - only if not tied
+    if not qwen2_config.get("tie_word_embeddings", False):
+        static_mappings.append(
+            (W("lm_head", "weight"), W("lm_head", "weight"))
+        )
+
+    for src, tgt in static_mappings:
+        mappings[tgt] = Ref(key=src)
+
+    # Layer mappings
+    for layer in range(num_layers):
+        # Source has "model.layers.{i}" prefix
+        qwen_layer = W("model", "layers", layer)
+        apriel_layer = W("model", "decoder", "blocks", layer)
+
+        # Attention projection weights
+        # Qwen2 has QKV bias but no O bias
+        for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]:
+            src = qwen_layer / "self_attn" / proj / "weight"
+            tgt = apriel_layer / "mixer" / proj / "weight"
+            mappings[tgt] = Ref(key=src)
+
+        # QKV biases (Qwen2 has these, but not O bias)
+        for proj in ["q_proj", "k_proj", "v_proj"]:
+            src = qwen_layer / "self_attn" / proj / "bias"
+            tgt = apriel_layer / "mixer" / proj / "bias"
+            mappings[tgt] = Ref(key=src)
+
+        # O bias - Qwen2 doesn't have this, so initialize to zeros
+        # Shape is hidden_size (d_model)
+        mappings[apriel_layer / "mixer" / "o_proj" / "bias"] = Init(
+            shape=(hidden_size,),
+            init_type="zeros",
+        )
+
+        # MLP projections
+        for proj in ["gate_proj", "up_proj", "down_proj"]:
+            src = qwen_layer / "mlp" / proj / "weight"
+            tgt = apriel_layer / "mlp" / proj / "weight"
+            mappings[tgt] = Ref(key=src)
+
+        # Layer norms
+        mappings[apriel_layer / "input_layernorm" / "weight"] = Ref(
+            key=qwen_layer / "input_layernorm" / "weight"
+        )
+        mappings[apriel_layer / "post_attention_layernorm" / "weight"] = Ref(
+            key=qwen_layer / "post_attention_layernorm" / "weight"
+        )
+
+    return ExprPlan(
+        mappings=mappings,
+        source_format="qwen2",
+        target_format="apriel2",
+        metadata={
+            "num_layers": num_layers,
+            "hidden_size": qwen2_config["hidden_size"],
+            "num_attention_heads": qwen2_config["num_attention_heads"],
+            
"num_key_value_heads": qwen2_config.get("num_key_value_heads", qwen2_config["num_attention_heads"]), + }, + ) diff --git a/fast_llm_external_models/apriel2/convert.py b/fast_llm_external_models/apriel2/convert.py index cbf921b31..05c38c7ce 100644 --- a/fast_llm_external_models/apriel2/convert.py +++ b/fast_llm_external_models/apriel2/convert.py @@ -15,6 +15,7 @@ Supported source formats: - llava: Llava/Pixtral models +- qwen2: Qwen2/Qwen2.5 models - apriel2: Apriel2 models (surgery-only mode - no conversion, just apply surgeries) """ @@ -46,6 +47,7 @@ # Import source-specific converters from fast_llm_external_models.apriel2.conversion import llava as llava_converter +from fast_llm_external_models.apriel2.conversion import qwen2 as qwen2_converter logger = logging.getLogger(__name__) @@ -73,6 +75,7 @@ def _identity_plan(config: dict) -> ExprPlan: # Each entry maps format name to (config_converter, plan_builder) SOURCE_FORMATS: dict[str, tuple[Callable[[dict], dict], Callable[[dict], ExprPlan]]] = { "llava": (llava_converter.convert_config, llava_converter.plan_llava_to_apriel2), + "qwen2": (qwen2_converter.convert_config, qwen2_converter.plan_qwen2_to_apriel2), "apriel2": (_identity_config, _identity_plan), } @@ -88,8 +91,12 @@ def detect_source_format(config: dict) -> str | None: if model_type in ("llava", "pixtral") or "text_config" in config: return "llava" + # Qwen2/Qwen2.5 detection + if model_type == "qwen2": + return "qwen2" + # Apriel2 detection - check for Apriel2-specific structure - if model_type == "apriel2" or "decoder" in config: + if model_type in ("apriel2", "apriel2_text") or "decoder" in config: return "apriel2" return None diff --git a/fast_llm_external_models/tests/test_apriel2/conftest.py b/fast_llm_external_models/tests/test_apriel2/conftest.py index 8585aec65..5c127d97e 100644 --- a/fast_llm_external_models/tests/test_apriel2/conftest.py +++ b/fast_llm_external_models/tests/test_apriel2/conftest.py @@ -7,6 +7,8 @@ import torch from transformers import LlavaConfig, LlavaForConditionalGeneration, MistralConfig +from fast_llm_external_models.apriel2.cache import _AttentionCache, _SSMCache + # Skip marker for tests that require CUDA for Mamba forward pass requires_cuda = pytest.mark.skipif( @@ -1532,3 +1534,188 @@ def torture_surgery_chain(): }, }, ] + + +# ============================================================================= +# Cache Test Fixtures - Tensor Dimensions +# ============================================================================= + + +@pytest.fixture +def batch_size(): + """Default batch size for cache tests.""" + return 2 + + +@pytest.fixture +def num_heads(): + """Default number of attention heads for cache tests.""" + return 4 + + +@pytest.fixture +def head_dim(): + """Default head dimension for cache tests.""" + return 16 + + +@pytest.fixture +def make_kv(batch_size, num_heads, head_dim): + """Factory fixture for creating KV tensors.""" + + def _make_kv(seq_len): + return ( + torch.randn(batch_size, num_heads, seq_len, head_dim), + torch.randn(batch_size, num_heads, seq_len, head_dim), + ) + + return _make_kv + + +# ============================================================================= +# Cache Test Fixtures - HuggingFace Cache Layers +# ============================================================================= + + +@pytest.fixture +def hf_dynamic_layer(): + """HuggingFace DynamicLayer for full attention contract testing.""" + from transformers.cache_utils import DynamicLayer + + return DynamicLayer() + + +@pytest.fixture +def 
hf_sliding_layer(window_size): + """HuggingFace DynamicSlidingWindowLayer for sliding window contract testing.""" + from transformers.cache_utils import DynamicSlidingWindowLayer + + return DynamicSlidingWindowLayer(sliding_window=window_size) + + +# ============================================================================= +# Cache Test Fixtures - Apriel2 Low-level Caches +# ============================================================================= + + +@pytest.fixture +def apriel_attention_cache(): + """Apriel2 attention cache without window (full attention).""" + return _AttentionCache(window=None) + + +@pytest.fixture +def apriel_sliding_cache(window_size): + """Apriel2 attention cache with sliding window.""" + return _AttentionCache(window=window_size) + + +@pytest.fixture +def ssm_cache(): + """Apriel2 SSM cache for Mamba/GDN/KDA layers.""" + return _SSMCache() + + +# ============================================================================= +# Cache Test Fixtures - Apriel2 Configs (Simple Versions) +# ============================================================================= + + +@pytest.fixture +def attention_config(): + """Pure attention config (2 layers, no sliding window).""" + from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config + + return Apriel2Config( + vocab_size=100, + hidden_size=64, + decoder={ + "type": "fixed", + "num_blocks": 2, + "block": { + "mixer": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16}, + "mlp": {"type": "mlp", "intermediate_size": 256}, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + ) + + +@pytest.fixture +def swa_config(): + """Sliding window attention config (2 layers, window=8).""" + from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config + + return Apriel2Config( + vocab_size=100, + hidden_size=64, + decoder={ + "type": "fixed", + "num_blocks": 2, + "block": { + "mixer": { + "type": "attention", + "heads": 4, + "head_groups": 2, + "head_size": 16, + "window_size": 8, + }, + "mlp": {"type": "mlp", "intermediate_size": 256}, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + ) + + +@pytest.fixture +def ssm_config(): + """Pure SSM config (2 layers).""" + from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config + + return Apriel2Config( + vocab_size=100, + hidden_size=64, + decoder={ + "type": "fixed", + "num_blocks": 2, + "block": { + "mixer": {"type": "mamba", "state_size": 16}, + "mlp": {"type": "mlp", "intermediate_size": 256}, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + ) + + +@pytest.fixture +def stochastic_config(): + """Stochastic mixer config with attention and mamba (2 layers).""" + from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config + + return Apriel2Config( + vocab_size=100, + hidden_size=64, + decoder={ + "type": "fixed", + "num_blocks": 2, + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16}, + "mamba": {"type": "mamba", "state_size": 16}, + }, + }, + "mlp": {"type": "mlp", "intermediate_size": 256}, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + ) + + +# Parameterized window size fixture (used by hf_sliding_layer and apriel_sliding_cache) +@pytest.fixture(params=[4, 8, 16, 32]) +def window_size(request): + """Parameterized window sizes for sliding window tests.""" + return request.param 
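
# Illustrative sketch (not part of this commit): how the fixtures above are
# expected to compose in a sliding-window contract test. It relies only on the
# _AttentionCache API added in cache.py (update, window, cumulative_length);
# the test name is hypothetical.
def test_sliding_cache_tracks_cumulative_length(apriel_sliding_cache, make_kv, window_size):
    # Prefill past the window: the cached tensors are truncated to the window...
    key, value = make_kv(window_size + 3)
    apriel_sliding_cache.update(key, value)
    assert apriel_sliding_cache.key.shape[-2] == window_size
    # ...but cumulative_length keeps counting every token ever seen.
    assert apriel_sliding_cache.cumulative_length == window_size + 3

    # A single-token decode keeps the cache at window size and bumps the count.
    key1, value1 = make_kv(1)
    apriel_sliding_cache.update(key1, value1)
    assert apriel_sliding_cache.key.shape[-2] == window_size
    assert apriel_sliding_cache.cumulative_length == window_size + 4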
diff --git a/fast_llm_external_models/tests/test_apriel2/test_cache.py b/fast_llm_external_models/tests/test_apriel2/test_cache.py deleted file mode 100644 index ca8158b4f..000000000 --- a/fast_llm_external_models/tests/test_apriel2/test_cache.py +++ /dev/null @@ -1,1258 +0,0 @@ -"""Comprehensive tests for Apriel2Cache. - -Architecture Overview -===================== -Apriel2Cache manages state for autoregressive generation across different mixer types: - -1. **Attention Cache** (_AttentionCache): Stores key/value states - - Supports sliding window (window_size) for SWA - - Efficient roll optimization for single-token decode - -2. **SSM Cache** (_SSMCache): Stores conv and recurrent states - - Used by Mamba, GDN, KDA - - KDA uses tuple conv states (q, k, v), others use single tensor - -3. **Stochastic Mixer Routing**: For layers with multiple mixer options - - Each mixer has independent cache (no sharing) - - active_mixer pointer routes operations to correct sub-cache - - Switching mixers preserves each mixer's independent state - -Cache Invalidation Semantics -============================ -When switching between mixers in a stochastic layer: -- Each mixer maintains its OWN independent history -- Switching does NOT invalidate the previous mixer's cache -- Switching does NOT copy state between mixers -- To invalidate: call reset() explicitly - -This is intentional for training with stochastic sampling where each mixer -should learn from its own history. For inference, main_mixer_name is fixed. - -Test Organization -================= -1. CREATION & PROPERTIES - Cache initialization, config parsing -2. ATTENTION CACHE - Updates, sliding window, concatenation -3. SSM CACHE - Conv states, recurrent states, KDA tuples -4. STOCHASTIC ROUTING - Active mixer, isolation, switching -5. CACHE INVALIDATION - Reset, per-mixer reset, coherence -6. BEAM SEARCH - batch_repeat, reorder, select -7. HF INTEGRATION - get_mask_sizes, indexing, properties -8. GENERATION PATTERNS - Prefillβ†’decode, cropβ†’continue -9. 
ERROR HANDLING - Guards, bounds, invalid operations -""" - -import pytest -import torch - -from fast_llm_external_models.apriel2.cache import ( - Apriel2Cache, - _AttentionCache, - _SSMCache, -) - - -# ============================================================================= -# FIXTURES - Configs and Sample Data -# ============================================================================= - - -@pytest.fixture -def tiny_attention_config(): - """Minimal config with pure attention layers.""" - from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config - - return Apriel2Config( - vocab_size=100, - hidden_size=64, - decoder={ - "type": "fixed", - "num_blocks": 2, - "block": { - "mixer": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16}, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - ) - - -@pytest.fixture -def swa_config(): - """Config with sliding window attention.""" - from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config - - return Apriel2Config( - vocab_size=100, - hidden_size=64, - decoder={ - "type": "fixed", - "num_blocks": 2, - "block": { - "mixer": { - "type": "attention", - "heads": 4, - "head_groups": 2, - "head_size": 16, - "window_size": 8, # Small for testing - }, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - ) - - -@pytest.fixture -def ssm_config(): - """Config with pure SSM layers (mamba).""" - from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config - - return Apriel2Config( - vocab_size=100, - hidden_size=64, - decoder={ - "type": "fixed", - "num_blocks": 2, - "block": { - "mixer": { - "type": "mamba", - "d_inner": 128, - "d_state": 16, - "dt_rank": 4, - "d_conv": 4, - }, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - ) - - -@pytest.fixture -def kda_config(): - """Config with pure KDA layers.""" - from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config - - return Apriel2Config( - vocab_size=100, - hidden_size=64, - decoder={ - "type": "fixed", - "num_blocks": 2, - "block": { - "mixer": { - "type": "kda", - "heads": 4, - "head_dim": 16, - "convolution_layer": {"kernel_size": 4}, - "normalization": {"epsilon": 1e-5}, - }, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - ) - - -@pytest.fixture -def stochastic_config(): - """Config with stochastic mixer (attention + mamba).""" - from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config - - return Apriel2Config( - vocab_size=100, - hidden_size=64, - decoder={ - "type": "pattern", - "num_blocks": 2, - "pattern": ["attn", "stochastic"], - "blocks": { - "attn": { - "mixer": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16}, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - "stochastic": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": { - "attention": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16}, - "mamba": {"type": "mamba", "d_inner": 128, "d_state": 16, "dt_rank": 4, "d_conv": 4}, - }, - }, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - }, - ) - - -@pytest.fixture -def 
all_mixers_config(): - """Config with stochastic mixer containing all 5 mixer types.""" - from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config - - return Apriel2Config( - vocab_size=100, - hidden_size=64, - decoder={ - "type": "pattern", - "num_blocks": 2, - "pattern": ["attn", "all_mixers"], - "blocks": { - "attn": { - "mixer": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16}, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - "all_mixers": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": { - "attention": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16}, - "swa": { - "type": "attention", - "heads": 4, - "head_groups": 2, - "head_size": 16, - "window_size": 1024, - }, - "mamba": {"type": "mamba", "d_inner": 128, "d_state": 16, "dt_rank": 4, "d_conv": 4}, - "gdn": { - "type": "gdn", - "value_heads": 4, - "key_heads": 2, - "key_head_dim": 16, - "value_head_dim": 16, - "convolution_layer": {"kernel_size": 4}, - }, - "kda": { - "type": "kda", - "heads": 4, - "head_dim": 16, - "convolution_layer": {"kernel_size": 4}, - "normalization": {"epsilon": 1e-5}, - }, - }, - }, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - }, - ) - - -@pytest.fixture -def multi_window_config(): - """Config with multiple different window sizes.""" - from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config - - return Apriel2Config( - vocab_size=100, - hidden_size=64, - decoder={ - "type": "pattern", - "num_blocks": 3, - "pattern": ["full", "small_window", "large_window"], - "blocks": { - "full": { - "mixer": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16}, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - "small_window": { - "mixer": { - "type": "attention", - "heads": 4, - "head_groups": 2, - "head_size": 16, - "window_size": 512, - }, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - "large_window": { - "mixer": { - "type": "attention", - "heads": 4, - "head_groups": 2, - "head_size": 16, - "window_size": 2048, - }, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - }, - ) - - -@pytest.fixture -def sample_kv(): - """Sample key/value tensors: [batch=2, heads=4, seq=10, head_dim=16].""" - return torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16) - - -@pytest.fixture -def sample_conv_single(): - """Sample single-tensor conv state: [batch=2, d_inner=128, kernel=4].""" - return torch.randn(2, 128, 4) - - -@pytest.fixture -def sample_conv_tuple(): - """Sample tuple conv state for KDA: (q, k, v) each [batch=2, d=64, kernel=3].""" - return (torch.randn(2, 64, 3), torch.randn(2, 64, 3), torch.randn(2, 64, 3)) - - -@pytest.fixture -def sample_recurrent(): - """Sample recurrent state: [batch=2, heads=4, head_dim=16, d_state=16].""" - return torch.randn(2, 4, 16, 16) - - -# ============================================================================= -# SECTION 1: CACHE CREATION & PROPERTIES -# ============================================================================= - - -class TestCacheCreation: - """Test cache initialization from config.""" - - def test_attention_cache_creation(self, tiny_attention_config): - """Create cache for pure 
attention config.""" - cache = Apriel2Cache(tiny_attention_config) - - assert len(cache) == 2 - assert cache.mixer_types == ["attention", "attention"] - assert all(isinstance(l, _AttentionCache) for l in cache.layers) - - def test_ssm_cache_creation(self, ssm_config): - """Create cache for pure SSM config.""" - cache = Apriel2Cache(ssm_config) - - assert len(cache) == 2 - assert cache.mixer_types == ["mamba", "mamba"] - assert all(isinstance(l, _SSMCache) for l in cache.layers) - - def test_kda_cache_creation(self, kda_config): - """Create cache for pure KDA config.""" - cache = Apriel2Cache(kda_config) - - assert len(cache) == 2 - assert cache.mixer_types == ["kda", "kda"] - assert all(isinstance(l, _SSMCache) for l in cache.layers) - - def test_stochastic_cache_creation(self, stochastic_config): - """Create cache for stochastic mixer config.""" - cache = Apriel2Cache(stochastic_config) - - assert len(cache) == 2 - # Layer 0: pure attention, Layer 1: stochastic (dict) - assert isinstance(cache.layers[0], _AttentionCache) - assert isinstance(cache.layers[1], dict) - assert set(cache.layers[1].keys()) == {"attention", "mamba"} - - def test_swa_window_captured(self, swa_config): - """Verify sliding window size is captured.""" - cache = Apriel2Cache(swa_config) - - assert cache.layers[0].window == 8 - assert cache.is_sliding == [True, True] - - def test_active_mixers_initialized_none(self, stochastic_config): - """Verify active_mixers starts as None for all layers.""" - cache = Apriel2Cache(stochastic_config) - - assert cache.active_mixers == [None, None] - - -class TestCacheProperties: - """Test cache property accessors.""" - - def test_empty_cache_properties(self, tiny_attention_config): - """Test properties of uninitialized cache.""" - cache = Apriel2Cache(tiny_attention_config) - - assert cache.is_initialized == False - assert cache.has_previous_state == False - assert cache.max_batch_size is None - assert cache.max_cache_len is None - assert cache.is_compileable == False - - def test_is_initialized_attention(self, tiny_attention_config, sample_kv): - """is_initialized detects attention cache.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - - assert cache.is_initialized == True - - def test_is_initialized_ssm(self, ssm_config, sample_conv_single): - """is_initialized detects SSM cache.""" - cache = Apriel2Cache(ssm_config) - cache.conv_states[0] = sample_conv_single - - assert cache.is_initialized == True - - def test_has_previous_state_ssm_only(self, ssm_config, sample_conv_single): - """has_previous_state only looks at SSM conv states.""" - cache = Apriel2Cache(ssm_config) - - assert cache.has_previous_state == False - cache.conv_states[0] = sample_conv_single - assert cache.has_previous_state == True - - def test_has_previous_state_ignores_attention(self, tiny_attention_config, sample_kv): - """has_previous_state ignores attention cache.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - - # Attention cache is set, but has_previous_state only checks SSM - assert cache.has_previous_state == False - - def test_max_batch_size_from_attention(self, tiny_attention_config, sample_kv): - """max_batch_size from attention cache.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - - assert cache.max_batch_size == 2 - - def test_max_batch_size_from_ssm(self, ssm_config, sample_conv_single): - """max_batch_size from SSM cache.""" - cache = Apriel2Cache(ssm_config) - 
cache.conv_states[0] = sample_conv_single - - assert cache.max_batch_size == 2 - - def test_max_batch_size_from_kda_tuple(self, kda_config, sample_conv_tuple): - """max_batch_size from KDA tuple conv state.""" - cache = Apriel2Cache(kda_config) - cache.conv_states[0] = sample_conv_tuple - - assert cache.max_batch_size == 2 - - def test_max_cache_len_single_window(self, swa_config): - """max_cache_len with single window size.""" - cache = Apriel2Cache(swa_config) - assert cache.max_cache_len == 8 - - def test_max_cache_len_multiple_windows(self, multi_window_config): - """max_cache_len returns minimum window.""" - cache = Apriel2Cache(multi_window_config) - assert cache.max_cache_len == 512 # min(512, 2048) - - def test_max_cache_len_no_windows(self, tiny_attention_config): - """max_cache_len is None when no windows.""" - cache = Apriel2Cache(tiny_attention_config) - assert cache.max_cache_len is None - - def test_is_sliding_mixed(self, multi_window_config): - """is_sliding reflects per-layer window presence.""" - cache = Apriel2Cache(multi_window_config) - assert cache.is_sliding == [False, True, True] - - -# ============================================================================= -# SECTION 2: ATTENTION CACHE OPERATIONS -# ============================================================================= - - -class TestAttentionCacheBasics: - """Test basic attention cache operations.""" - - def test_update_stores_kv(self, tiny_attention_config, sample_kv): - """update() stores key/value states.""" - cache = Apriel2Cache(tiny_attention_config) - key, value = sample_kv - - k_out, v_out = cache.update(key, value, layer_idx=0) - - torch.testing.assert_close(k_out, key) - torch.testing.assert_close(v_out, value) - assert cache.get_seq_length(0) == 10 - - def test_update_concatenates(self, tiny_attention_config, sample_kv): - """Subsequent updates concatenate.""" - cache = Apriel2Cache(tiny_attention_config) - key, value = sample_kv - - cache.update(key, value, layer_idx=0) - k_out, v_out = cache.update(key, value, layer_idx=0) - - assert k_out.shape[-2] == 20 - assert cache.get_seq_length(0) == 20 - - def test_key_value_cache_accessors(self, tiny_attention_config, sample_kv): - """Test key_cache and value_cache accessors.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - - assert cache.key_cache[0] is not None - assert cache.value_cache[0] is not None - torch.testing.assert_close(cache.key_cache[0], sample_kv[0]) - - -class TestSlidingWindowAttention: - """Test sliding window attention behavior.""" - - def test_initial_within_window(self, swa_config): - """Initial sequence within window is kept.""" - cache = Apriel2Cache(swa_config) - key = torch.randn(2, 4, 5, 16) # seq=5 < window=8 - value = torch.randn(2, 4, 5, 16) - - cache.update(key, value, layer_idx=0) - - assert cache.get_seq_length(0) == 5 - - def test_initial_exceeds_window(self, swa_config): - """Initial sequence > window is truncated to last window tokens.""" - cache = Apriel2Cache(swa_config) - key = torch.arange(12).float().view(1, 1, 12, 1).expand(2, 4, 12, 16) - value = key.clone() - - k_out, v_out = cache.update(key, value, layer_idx=0) - - assert cache.get_seq_length(0) == 8 - # Should keep tokens 4-11 (last 8) - assert k_out[0, 0, 0, 0].item() == 4.0 - - def test_single_token_roll_path(self, swa_config): - """Single token decode with full window uses efficient roll.""" - cache = Apriel2Cache(swa_config) - - # Fill window exactly - key1 = torch.arange(8).float().view(1, 1, 8, 
1).expand(2, 4, 8, 16) - cache.update(key1, key1.clone(), layer_idx=0) - - # Decode single token - key2 = torch.full((2, 4, 1, 16), 8.0) - k_out, _ = cache.update(key2, key2.clone(), layer_idx=0) - - assert cache.get_seq_length(0) == 8 - assert k_out[0, 0, 0, 0].item() == 1.0 # Token 0 rolled out - assert k_out[0, 0, 7, 0].item() == 8.0 # New token at end - - def test_multi_token_cat_slice_path(self, swa_config): - """Multiple tokens use cat+slice path.""" - cache = Apriel2Cache(swa_config) - - # Fill window - key1 = torch.randn(2, 4, 8, 16) - cache.update(key1, key1.clone(), layer_idx=0) - - # Add 3 tokens - key2 = torch.randn(2, 4, 3, 16) - k_out, _ = cache.update(key2, key2.clone(), layer_idx=0) - - assert cache.get_seq_length(0) == 8 - torch.testing.assert_close(k_out[..., -3:, :], key2) - - def test_partial_then_fill_then_overflow(self, swa_config): - """Progressive filling: partial β†’ full β†’ overflow.""" - cache = Apriel2Cache(swa_config) - - cache.update(torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16), layer_idx=0) - assert cache.get_seq_length(0) == 5 - - cache.update(torch.randn(2, 4, 3, 16), torch.randn(2, 4, 3, 16), layer_idx=0) - assert cache.get_seq_length(0) == 8 - - cache.update(torch.randn(2, 4, 2, 16), torch.randn(2, 4, 2, 16), layer_idx=0) - assert cache.get_seq_length(0) == 8 - - def test_contiguous_output(self, swa_config): - """Outputs are contiguous after windowing.""" - cache = Apriel2Cache(swa_config) - - cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0) - cache.update(torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16), layer_idx=0) - - assert cache.layers[0].key.is_contiguous() - assert cache.layers[0].value.is_contiguous() - - -# ============================================================================= -# SECTION 3: SSM CACHE OPERATIONS -# ============================================================================= - - -class TestSSMCacheBasics: - """Test basic SSM cache operations.""" - - def test_conv_states_accessor(self, ssm_config, sample_conv_single): - """Test conv_states accessor.""" - cache = Apriel2Cache(ssm_config) - - cache.conv_states[0] = sample_conv_single - torch.testing.assert_close(cache.conv_states[0], sample_conv_single) - - def test_recurrent_states_accessor(self, ssm_config, sample_recurrent): - """Test recurrent_states accessor.""" - cache = Apriel2Cache(ssm_config) - - cache.recurrent_states[0] = sample_recurrent - torch.testing.assert_close(cache.recurrent_states[0], sample_recurrent) - - def test_ssm_seq_length_always_zero(self, ssm_config, sample_conv_single): - """get_seq_length returns 0 for SSM (no KV cache).""" - cache = Apriel2Cache(ssm_config) - cache.conv_states[0] = sample_conv_single - - assert cache.get_seq_length(0) == 0 - - -class TestKDACache: - """Test KDA-specific cache operations with tuple conv states.""" - - def test_tuple_conv_storage(self, kda_config, sample_conv_tuple): - """KDA stores tuple conv states.""" - cache = Apriel2Cache(kda_config) - - cache.conv_states[0] = sample_conv_tuple - - assert isinstance(cache.conv_states[0], tuple) - assert len(cache.conv_states[0]) == 3 - for i in range(3): - torch.testing.assert_close(cache.conv_states[0][i], sample_conv_tuple[i]) - - def test_tuple_with_recurrent(self, kda_config, sample_conv_tuple, sample_recurrent): - """KDA can have both tuple conv and recurrent states.""" - cache = Apriel2Cache(kda_config) - - cache.conv_states[0] = sample_conv_tuple - cache.recurrent_states[0] = sample_recurrent - - assert 
isinstance(cache.conv_states[0], tuple) - assert cache.recurrent_states[0] is not None - - def test_has_previous_state_detects_tuple(self, kda_config, sample_conv_tuple): - """has_previous_state works with tuple conv states.""" - cache = Apriel2Cache(kda_config) - - assert cache.has_previous_state == False - cache.conv_states[0] = sample_conv_tuple - assert cache.has_previous_state == True - - -# ============================================================================= -# SECTION 4: STOCHASTIC ROUTING -# ============================================================================= - - -class TestStochasticRouting: - """Test stochastic mixer cache routing.""" - - def test_set_active_mixer(self, stochastic_config): - """set_active_mixer sets the pointer.""" - cache = Apriel2Cache(stochastic_config) - - cache.set_active_mixer(1, "attention") - assert cache.active_mixers[1] == "attention" - - cache.set_active_mixer(1, "mamba") - assert cache.active_mixers[1] == "mamba" - - def test_operations_route_to_active(self, stochastic_config, sample_kv): - """Operations route to currently active mixer.""" - cache = Apriel2Cache(stochastic_config) - - cache.set_active_mixer(1, "attention") - cache.update(*sample_kv, layer_idx=1) - attn_len = cache.get_seq_length(1) - - cache.set_active_mixer(1, "mamba") - mamba_len = cache.get_seq_length(1) - - assert attn_len == 10 - assert mamba_len == 0 # Mamba cache is separate and empty - - def test_each_mixer_independent_cache(self, stochastic_config, sample_kv, sample_conv_single): - """Each mixer maintains independent cache.""" - cache = Apriel2Cache(stochastic_config) - - # Fill attention cache - cache.set_active_mixer(1, "attention") - cache.update(*sample_kv, layer_idx=1) - - # Fill mamba cache - cache.set_active_mixer(1, "mamba") - cache.conv_states[1] = sample_conv_single - - # Both preserved - cache.set_active_mixer(1, "attention") - assert cache.get_seq_length(1) == 10 - - cache.set_active_mixer(1, "mamba") - torch.testing.assert_close(cache.conv_states[1], sample_conv_single) - - -class TestMixerSwitching: - """Test behavior when switching between mixers mid-generation.""" - - def test_switch_preserves_previous_state(self, stochastic_config, sample_kv): - """Switching mixers preserves previous mixer's state.""" - cache = Apriel2Cache(stochastic_config) - - cache.set_active_mixer(1, "attention") - cache.update(*sample_kv, layer_idx=1) - original_key = cache.layers[1]["attention"].key.clone() - - # Switch to mamba, do something - cache.set_active_mixer(1, "mamba") - cache.conv_states[1] = torch.randn(2, 128, 4) - - # Switch back - attention unchanged - cache.set_active_mixer(1, "attention") - torch.testing.assert_close(cache.layers[1]["attention"].key, original_key) - - def test_switch_does_not_copy_state(self, stochastic_config, sample_kv): - """Switching does NOT copy state between mixers.""" - cache = Apriel2Cache(stochastic_config) - - # Fill attention with 10 tokens - cache.set_active_mixer(1, "attention") - cache.update(*sample_kv, layer_idx=1) - - # Switch to mamba - it has NO history from attention - cache.set_active_mixer(1, "mamba") - assert cache.conv_states[1] is None - assert cache.recurrent_states[1] is None - - def test_has_previous_state_checks_all_sub_caches(self, stochastic_config): - """has_previous_state checks ALL sub-caches, not just active.""" - cache = Apriel2Cache(stochastic_config) - - cache.set_active_mixer(1, "mamba") - cache.conv_states[1] = torch.randn(2, 128, 4) - - # Even if we switch away, has_previous_state still 
detects it - cache.set_active_mixer(1, "attention") - assert cache.has_previous_state == True - - -class TestAllMixerTypes: - """Test cache isolation across all 5 mixer types.""" - - def test_all_five_mixer_types_isolated(self, all_mixers_config): - """All 5 mixer types maintain isolated caches.""" - cache = Apriel2Cache(all_mixers_config) - layer_idx = 1 # Stochastic layer - - # Fill each mixer's cache - cache.set_active_mixer(layer_idx, "attention") - attn_kv = (torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16)) - cache.update(*attn_kv, layer_idx=layer_idx) - - cache.set_active_mixer(layer_idx, "swa") - swa_kv = (torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16)) - cache.update(*swa_kv, layer_idx=layer_idx) - - cache.set_active_mixer(layer_idx, "mamba") - mamba_conv = torch.randn(2, 128, 4) - cache.conv_states[layer_idx] = mamba_conv - - cache.set_active_mixer(layer_idx, "gdn") - gdn_conv = torch.randn(2, 64, 3) - cache.conv_states[layer_idx] = gdn_conv - - cache.set_active_mixer(layer_idx, "kda") - kda_conv = (torch.randn(2, 64, 3), torch.randn(2, 64, 3), torch.randn(2, 64, 3)) - cache.conv_states[layer_idx] = kda_conv - - # Verify all preserved - cache.set_active_mixer(layer_idx, "attention") - assert cache.get_seq_length(layer_idx) == 10 - - cache.set_active_mixer(layer_idx, "swa") - assert cache.get_seq_length(layer_idx) == 5 - - cache.set_active_mixer(layer_idx, "mamba") - torch.testing.assert_close(cache.conv_states[layer_idx], mamba_conv) - - cache.set_active_mixer(layer_idx, "gdn") - torch.testing.assert_close(cache.conv_states[layer_idx], gdn_conv) - - cache.set_active_mixer(layer_idx, "kda") - assert isinstance(cache.conv_states[layer_idx], tuple) - - -# ============================================================================= -# SECTION 5: CACHE INVALIDATION -# ============================================================================= - - -class TestCacheInvalidation: - """Test cache invalidation and reset semantics. - - Key principle: Each mixer maintains independent state. 
To invalidate: - - reset() clears ALL caches across ALL layers and mixers - - There is no per-mixer reset (by design - each mixer is independent) - """ - - def test_reset_clears_attention(self, tiny_attention_config, sample_kv): - """reset() clears attention cache.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - - cache.reset() - - assert cache.is_initialized == False - assert cache.get_seq_length(0) == 0 - - def test_reset_clears_ssm(self, ssm_config, sample_conv_single, sample_recurrent): - """reset() clears SSM cache.""" - cache = Apriel2Cache(ssm_config) - cache.conv_states[0] = sample_conv_single - cache.recurrent_states[0] = sample_recurrent - - cache.reset() - - assert cache.has_previous_state == False - assert cache.conv_states[0] is None - assert cache.recurrent_states[0] is None - - def test_reset_clears_kda_tuple(self, kda_config, sample_conv_tuple): - """reset() clears KDA tuple conv states.""" - cache = Apriel2Cache(kda_config) - cache.conv_states[0] = sample_conv_tuple - - cache.reset() - - assert cache.conv_states[0] is None - - def test_reset_clears_all_stochastic_mixers(self, all_mixers_config): - """reset() clears ALL mixer caches in stochastic layer.""" - cache = Apriel2Cache(all_mixers_config) - layer_idx = 1 - - # Fill all mixers - cache.set_active_mixer(layer_idx, "attention") - cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=layer_idx) - - cache.set_active_mixer(layer_idx, "mamba") - cache.conv_states[layer_idx] = torch.randn(2, 128, 4) - - cache.set_active_mixer(layer_idx, "kda") - cache.conv_states[layer_idx] = (torch.randn(2, 64, 3),) * 3 - - cache.reset() - - # All cleared - assert cache.layers[layer_idx]["attention"].key is None - assert cache.layers[layer_idx]["mamba"].conv is None - assert cache.layers[layer_idx]["kda"].conv is None - - def test_crop_truncates_attention(self, tiny_attention_config, sample_kv): - """crop() truncates attention cache to max_length.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - - cache.crop(5) - - assert cache.get_seq_length(0) == 5 - - def test_crop_affects_all_layers(self, tiny_attention_config, sample_kv): - """crop() affects all layers.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - cache.update(*sample_kv, layer_idx=1) - - cache.crop(3) - - assert cache.get_seq_length(0) == 3 - assert cache.get_seq_length(1) == 3 - - def test_crop_ignores_ssm(self, ssm_config, sample_conv_single): - """crop() only affects attention, not SSM.""" - cache = Apriel2Cache(ssm_config) - cache.conv_states[0] = sample_conv_single - - cache.crop(5) # Should not crash - - # Conv state unchanged - torch.testing.assert_close(cache.conv_states[0], sample_conv_single) - - -# ============================================================================= -# SECTION 6: BEAM SEARCH OPERATIONS -# ============================================================================= - - -class TestBatchRepeatInterleave: - """Test batch_repeat_interleave for beam search expansion.""" - - def test_repeat_attention(self, tiny_attention_config, sample_kv): - """Repeat attention cache for beam search.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - - cache.batch_repeat_interleave(3) - - assert cache.max_batch_size == 6 # 2 * 3 - - def test_repeat_ssm(self, ssm_config, sample_conv_single, sample_recurrent): - """Repeat SSM cache for beam search.""" - cache = 
Apriel2Cache(ssm_config) - cache.conv_states[0] = sample_conv_single - cache.recurrent_states[0] = sample_recurrent - - cache.batch_repeat_interleave(4) - - assert cache.conv_states[0].shape[0] == 8 # 2 * 4 - assert cache.recurrent_states[0].shape[0] == 8 - - def test_repeat_kda_tuple(self, kda_config, sample_conv_tuple): - """Repeat KDA tuple conv states.""" - cache = Apriel2Cache(kda_config) - cache.conv_states[0] = sample_conv_tuple - - cache.batch_repeat_interleave(3) - - for c in cache.conv_states[0]: - assert c.shape[0] == 6 - - def test_repeat_stochastic_all_mixers(self, all_mixers_config): - """Repeat all mixer caches in stochastic layer.""" - cache = Apriel2Cache(all_mixers_config) - layer_idx = 1 - - cache.set_active_mixer(layer_idx, "attention") - cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=layer_idx) - - cache.set_active_mixer(layer_idx, "mamba") - cache.conv_states[layer_idx] = torch.randn(2, 128, 4) - - cache.batch_repeat_interleave(2) - - cache.set_active_mixer(layer_idx, "attention") - assert cache.layers[layer_idx]["attention"].key.shape[0] == 4 - - cache.set_active_mixer(layer_idx, "mamba") - assert cache.conv_states[layer_idx].shape[0] == 4 - - def test_repeat_skips_none(self, tiny_attention_config): - """Repeat gracefully skips None caches.""" - cache = Apriel2Cache(tiny_attention_config) - # Don't fill anything - - cache.batch_repeat_interleave(3) # Should not crash - - assert cache.max_batch_size is None - - -class TestReorderCache: - """Test reorder_cache for beam search hypothesis selection.""" - - def test_reorder_attention(self, tiny_attention_config, sample_kv): - """Reorder attention cache.""" - cache = Apriel2Cache(tiny_attention_config) - key, value = sample_kv - # Make batches distinguishable - key = torch.arange(2).float().view(2, 1, 1, 1).expand(2, 4, 10, 16) - cache.update(key, key.clone(), layer_idx=0) - - beam_idx = torch.tensor([1, 0]) - cache.reorder_cache(beam_idx) - - assert cache.layers[0].key[0, 0, 0, 0].item() == 1.0 - assert cache.layers[0].key[1, 0, 0, 0].item() == 0.0 - - def test_reorder_ssm(self, ssm_config): - """Reorder SSM cache.""" - cache = Apriel2Cache(ssm_config) - conv = torch.arange(2).float().view(2, 1, 1).expand(2, 128, 4) - cache.conv_states[0] = conv.clone() - - beam_idx = torch.tensor([1, 0]) - cache.reorder_cache(beam_idx) - - assert cache.conv_states[0][0, 0, 0].item() == 1.0 - - def test_reorder_kda_tuple(self, kda_config): - """Reorder KDA tuple conv states.""" - cache = Apriel2Cache(kda_config) - conv_q = torch.arange(2).float().view(2, 1, 1).expand(2, 64, 3) - cache.conv_states[0] = (conv_q.clone(), conv_q.clone(), conv_q.clone()) - - beam_idx = torch.tensor([1, 0]) - cache.reorder_cache(beam_idx) - - for c in cache.conv_states[0]: - assert c[0, 0, 0].item() == 1.0 - - -class TestBatchSelectIndices: - """Test batch_select_indices for beam selection.""" - - def test_select_attention(self, tiny_attention_config, sample_kv): - """Select subset of attention cache.""" - cache = Apriel2Cache(tiny_attention_config) - key = torch.arange(4).float().view(4, 1, 1, 1).expand(4, 4, 10, 16) - cache.update(key, key.clone(), layer_idx=0) - - indices = torch.tensor([0, 3]) - cache.batch_select_indices(indices) - - assert cache.max_batch_size == 2 - assert cache.layers[0].key[0, 0, 0, 0].item() == 0.0 - assert cache.layers[0].key[1, 0, 0, 0].item() == 3.0 - - def test_select_kda_tuple(self, kda_config): - """Select subset of KDA tuple conv states.""" - cache = Apriel2Cache(kda_config) - conv = 
tuple(torch.arange(4).float().view(4, 1, 1).expand(4, 64, 3).clone() for _ in range(3)) - cache.conv_states[0] = conv - - indices = torch.tensor([1, 2]) - cache.batch_select_indices(indices) - - for c in cache.conv_states[0]: - assert c.shape[0] == 2 - assert c[0, 0, 0].item() == 1.0 - - -# ============================================================================= -# SECTION 7: HUGGINGFACE INTEGRATION -# ============================================================================= - - -class TestGetMaskSizes: - """Test get_mask_sizes() for attention mask computation.""" - - def test_empty_cache(self, tiny_attention_config): - """Mask sizes with empty cache.""" - cache = Apriel2Cache(tiny_attention_config) - cache_position = torch.arange(10) - - kv_length, kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0) - - assert kv_length == 10 - assert kv_offset == 0 - - def test_with_cached_tokens(self, tiny_attention_config, sample_kv): - """Mask sizes with cached tokens.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) # 10 tokens - - cache_position = torch.arange(5) - kv_length, kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0) - - assert kv_length == 15 # 10 + 5 - assert kv_offset == 10 - - def test_single_token_decode(self, tiny_attention_config, sample_kv): - """Mask sizes for single token decode.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - - cache_position = torch.arange(1) - kv_length, kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0) - - assert kv_length == 11 - assert kv_offset == 10 - - def test_ssm_returns_query_only(self, ssm_config, sample_conv_single): - """SSM layers return query_length (no KV cache).""" - cache = Apriel2Cache(ssm_config) - cache.conv_states[0] = sample_conv_single - - cache_position = torch.arange(5) - kv_length, kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0) - - assert kv_length == 5 - assert kv_offset == 0 - - -class TestCacheIndexing: - """Test cache[idx] indexing.""" - - def test_attention_returns_kv(self, tiny_attention_config, sample_kv): - """Indexing attention layer returns (key, value).""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - - result = cache[0] - - assert isinstance(result, tuple) - torch.testing.assert_close(result[0], sample_kv[0]) - - def test_empty_returns_empty_tensors(self, tiny_attention_config): - """Indexing empty layer returns empty tensors.""" - cache = Apriel2Cache(tiny_attention_config) - - result = cache[0] - - assert result[0].numel() == 0 - assert result[1].numel() == 0 - - def test_ssm_returns_empty(self, ssm_config, sample_conv_single): - """Indexing SSM layer returns empty (no KV).""" - cache = Apriel2Cache(ssm_config) - cache.conv_states[0] = sample_conv_single - - result = cache[0] - - assert result[0].numel() == 0 - - def test_stochastic_attention_returns_kv(self, stochastic_config, sample_kv): - """Indexing stochastic with attention active returns KV.""" - cache = Apriel2Cache(stochastic_config) - cache.set_active_mixer(1, "attention") - cache.update(*sample_kv, layer_idx=1) - - result = cache[1] - - torch.testing.assert_close(result[0], sample_kv[0]) - - -# ============================================================================= -# SECTION 8: GENERATION PATTERNS -# ============================================================================= - - -class TestGenerationPatterns: - """Test real-world generation patterns.""" - - def 
test_prefill_then_decode(self, tiny_attention_config, sample_kv): - """Prefill with long prompt, then decode token-by-token.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) # Prefill 10 tokens - - for _ in range(5): - new_kv = (torch.randn(2, 4, 1, 16), torch.randn(2, 4, 1, 16)) - cache.update(*new_kv, layer_idx=0) - - assert cache.get_seq_length(0) == 15 - - def test_crop_then_continue(self, tiny_attention_config, sample_kv): - """Crop old context, continue generation.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - cache.update(*sample_kv, layer_idx=0) # 20 tokens - - cache.crop(5) # Keep last 5 - cache.update(torch.randn(2, 4, 3, 16), torch.randn(2, 4, 3, 16), layer_idx=0) - - assert cache.get_seq_length(0) == 8 - - def test_reset_between_generations(self, tiny_attention_config, sample_kv): - """Reset between independent generations.""" - cache = Apriel2Cache(tiny_attention_config) - - # First generation - cache.update(*sample_kv, layer_idx=0) - assert cache.is_initialized == True - - # Reset - cache.reset() - assert cache.is_initialized == False - - # Second generation - cache.update(*sample_kv, layer_idx=0) - assert cache.get_seq_length(0) == 10 - - def test_multi_layer_consistency(self, tiny_attention_config, sample_kv): - """All layers updated consistently.""" - cache = Apriel2Cache(tiny_attention_config) - - for layer_idx in range(2): - cache.update(*sample_kv, layer_idx=layer_idx) - cache.update(torch.randn(2, 4, 1, 16), torch.randn(2, 4, 1, 16), layer_idx=layer_idx) - - for layer_idx in range(2): - assert cache.get_seq_length(layer_idx) == 11 - - -# ============================================================================= -# SECTION 9: ERROR HANDLING -# ============================================================================= - - -class TestErrorHandling: - """Test error conditions and guards.""" - - def test_stochastic_update_without_active_mixer(self, stochastic_config): - """update() on stochastic without active_mixer raises.""" - cache = Apriel2Cache(stochastic_config) - - with pytest.raises(RuntimeError, match="needs active_mixer set"): - cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=1) - - def test_stochastic_accessor_without_active_mixer(self, stochastic_config): - """Accessing stochastic cache without active_mixer raises.""" - cache = Apriel2Cache(stochastic_config) - - with pytest.raises(RuntimeError, match="requires set_active_mixer"): - _ = cache.conv_states[1] - - def test_accessor_error_lists_available_mixers(self, stochastic_config): - """Error message lists available mixers.""" - cache = Apriel2Cache(stochastic_config) - - with pytest.raises(RuntimeError, match="Available mixers:"): - _ = cache.key_cache[1] - - def test_invalid_mixer_name(self, stochastic_config): - """Invalid mixer name raises KeyError on access.""" - cache = Apriel2Cache(stochastic_config) - cache.set_active_mixer(1, "nonexistent") - - with pytest.raises(KeyError): - cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=1) - - def test_layer_idx_out_of_bounds(self, tiny_attention_config): - """Out-of-bounds layer_idx raises IndexError.""" - cache = Apriel2Cache(tiny_attention_config) - - with pytest.raises(IndexError): - cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=999) - - -# ============================================================================= -# SECTION 10: INTERNAL CLASSES -# 
============================================================================= - - -class TestAttentionCacheInternal: - """Test internal _AttentionCache class directly.""" - - def test_unbounded_growth(self): - """No window allows unbounded growth.""" - cache = _AttentionCache(window=None) - - for _ in range(10): - cache.update(torch.randn(2, 4, 100, 16), torch.randn(2, 4, 100, 16)) - - assert cache.key.shape[-2] == 1000 - - def test_window_enforced(self): - """Window caps cache size.""" - cache = _AttentionCache(window=50) - - for _ in range(10): - cache.update(torch.randn(2, 4, 100, 16), torch.randn(2, 4, 100, 16)) - - assert cache.key.shape[-2] == 50 - - -class TestSSMCacheInternal: - """Test internal _SSMCache class directly.""" - - def test_initial_none(self): - """Initial states are None.""" - cache = _SSMCache() - - assert cache.conv is None - assert cache.recurrent is None - - def test_stores_tuple(self): - """Can store tuple (for KDA).""" - cache = _SSMCache() - cache.conv = (torch.randn(2, 64, 3),) * 3 - - assert isinstance(cache.conv, tuple) diff --git a/fast_llm_external_models/tests/test_apriel2/test_cache_apriel2_specific.py b/fast_llm_external_models/tests/test_apriel2/test_cache_apriel2_specific.py new file mode 100644 index 000000000..e0e4db2d3 --- /dev/null +++ b/fast_llm_external_models/tests/test_apriel2/test_cache_apriel2_specific.py @@ -0,0 +1,342 @@ +"""Tests for Apriel2-specific cache behaviors with no HuggingFace equivalent. + +This module tests features unique to Apriel2Cache that cannot be validated +against upstream HF implementations: + +1. Stochastic mixer routing (switching between attention/SSM per layer) +2. Multi-mixer layer support +3. Error handling and guard rails +4. Beam search operations (batch_repeat, reorder, select) +5. 
Crop operation + +Fixtures used from conftest.py: + - stochastic_config: Stochastic mixer config with attention and mamba + - attention_config: Pure attention config + - ssm_config: Pure SSM config +""" + +import pytest +import torch + +from fast_llm_external_models.apriel2.cache import Apriel2Cache, _AttentionCache, _SSMCache + + +# ============================================================================= +# STOCHASTIC MIXER ROUTING +# ============================================================================= + + +class TestStochasticMixerRouting: + """Test routing operations to correct sub-cache in stochastic layers.""" + + def test_set_active_mixer(self, stochastic_config): + """set_active_mixer updates routing for layer.""" + cache = Apriel2Cache(stochastic_config) + + cache.set_active_mixer(0, "attention") + assert cache.active_mixers[0] == "attention" + + cache.set_active_mixer(0, "mamba") + assert cache.active_mixers[0] == "mamba" + + def test_update_routes_to_active_mixer(self, stochastic_config): + """update() stores in correct sub-cache based on active_mixer.""" + cache = Apriel2Cache(stochastic_config) + + # Route to attention + cache.set_active_mixer(0, "attention") + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0) + + # Attention sub-cache should have data + assert cache.layers[0]["attention"].key is not None + # Mamba sub-cache should be empty + assert cache.layers[0]["mamba"].conv is None + + def test_each_mixer_has_independent_cache(self, stochastic_config): + """Each mixer in a stochastic layer has its own independent state.""" + cache = Apriel2Cache(stochastic_config) + + # Store in attention + cache.set_active_mixer(0, "attention") + cache.update(torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16), layer_idx=0) + + # Switch to mamba and store + cache.set_active_mixer(0, "mamba") + cache.layers[0]["mamba"].conv = torch.randn(2, 64, 4) + + # Attention data should be unchanged + assert cache.layers[0]["attention"].cumulative_length == 5 + + def test_switching_preserves_all_states(self, stochastic_config): + """Switching active_mixer doesn't clear other mixer's state.""" + cache = Apriel2Cache(stochastic_config) + + # Build up attention state + cache.set_active_mixer(0, "attention") + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0) + attn_key = cache.layers[0]["attention"].key.clone() + + # Switch to mamba + cache.set_active_mixer(0, "mamba") + + # Attention state preserved + torch.testing.assert_close(cache.layers[0]["attention"].key, attn_key) + + +# ============================================================================= +# ERROR HANDLING +# ============================================================================= + + +class TestErrorHandling: + """Test guard rails and error messages.""" + + def test_update_without_active_mixer_raises(self, stochastic_config): + """update() on stochastic layer without active_mixer raises RuntimeError.""" + cache = Apriel2Cache(stochastic_config) + + with pytest.raises(RuntimeError, match="needs active_mixer set"): + cache.update(torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16), layer_idx=0) + + def test_accessor_without_active_mixer_raises(self, stochastic_config): + """Accessing key_cache/value_cache without active_mixer raises RuntimeError.""" + cache = Apriel2Cache(stochastic_config) + + with pytest.raises(RuntimeError, match="requires set_active_mixer"): + _ = cache.key_cache[0] + + def test_error_message_lists_available_mixers(self, stochastic_config): + 
"""Error message includes list of available mixers.""" + cache = Apriel2Cache(stochastic_config) + + with pytest.raises(RuntimeError, match="attention.*mamba|mamba.*attention"): + _ = cache.key_cache[0] + + +# ============================================================================= +# BEAM SEARCH OPERATIONS +# ============================================================================= + + +class TestBeamSearchOperations: + """Test batch manipulation for beam search.""" + + def test_batch_repeat_interleave_attention(self, attention_config): + """batch_repeat_interleave expands batch dimension.""" + cache = Apriel2Cache(attention_config) + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0) + + cache.batch_repeat_interleave(3) + + assert cache.layers[0].key.shape[0] == 6 # 2 * 3 + + def test_batch_repeat_interleave_ssm(self, ssm_config): + """batch_repeat_interleave works for SSM caches.""" + cache = Apriel2Cache(ssm_config) + cache.layers[0].conv = torch.randn(2, 64, 4) + + cache.batch_repeat_interleave(3) + + assert cache.layers[0].conv.shape[0] == 6 + + def test_batch_repeat_interleave_kda_tuple(self, ssm_config): + """batch_repeat_interleave handles KDA tuple conv states.""" + cache = Apriel2Cache(ssm_config) + cache.layers[0].conv = (torch.randn(2, 64, 4),) * 3 + + cache.batch_repeat_interleave(3) + + assert cache.layers[0].conv[0].shape[0] == 6 + + def test_reorder_cache_attention(self, attention_config): + """reorder_cache reorders batch dimension.""" + cache = Apriel2Cache(attention_config) + k = torch.arange(4).float().view(4, 1, 1, 1).expand(4, 4, 10, 16) + cache.update(k, k.clone(), layer_idx=0) + + beam_idx = torch.tensor([3, 2, 1, 0]) + cache.reorder_cache(beam_idx) + + # Check reordering + assert cache.layers[0].key[0, 0, 0, 0].item() == 3.0 + assert cache.layers[0].key[3, 0, 0, 0].item() == 0.0 + + def test_batch_select_indices(self, attention_config): + """batch_select_indices selects subset of batch.""" + cache = Apriel2Cache(attention_config) + cache.update(torch.randn(4, 4, 10, 16), torch.randn(4, 4, 10, 16), layer_idx=0) + + indices = torch.tensor([0, 2]) + cache.batch_select_indices(indices) + + assert cache.layers[0].key.shape[0] == 2 + + def test_reorder_cache_ssm_tuple(self, ssm_config): + """reorder_cache handles KDA tuple conv states.""" + cache = Apriel2Cache(ssm_config) + # Create distinguishable tensors for each batch position + conv0 = torch.full((1, 64, 4), 0.0) + conv1 = torch.full((1, 64, 4), 1.0) + conv2 = torch.full((1, 64, 4), 2.0) + cache.layers[0].conv = ( + torch.cat([conv0, conv1, conv2], dim=0), + torch.cat([conv0, conv1, conv2], dim=0), + torch.cat([conv0, conv1, conv2], dim=0), + ) + + beam_idx = torch.tensor([2, 1, 0]) + cache.reorder_cache(beam_idx) + + # Check reordering: batch[0] should now have value 2.0 + assert cache.layers[0].conv[0][0, 0, 0].item() == 2.0 + assert cache.layers[0].conv[0][2, 0, 0].item() == 0.0 + + def test_batch_select_indices_ssm_tuple(self, ssm_config): + """batch_select_indices handles KDA tuple conv states.""" + cache = Apriel2Cache(ssm_config) + cache.layers[0].conv = (torch.randn(4, 64, 4),) * 3 + + indices = torch.tensor([0, 2]) + cache.batch_select_indices(indices) + + assert cache.layers[0].conv[0].shape[0] == 2 + assert cache.layers[0].conv[1].shape[0] == 2 + assert cache.layers[0].conv[2].shape[0] == 2 + + +# ============================================================================= +# CROP OPERATION +# 
============================================================================= + + +class TestCropOperation: + """Test cache truncation.""" + + def test_crop_truncates_attention(self, attention_config): + """crop() truncates attention cache.""" + cache = Apriel2Cache(attention_config) + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0) + + cache.crop(5) + + assert cache.layers[0].key.shape[-2] == 5 + assert cache.get_seq_length(0) == 5 + + def test_crop_affects_all_layers(self, attention_config): + """crop() affects all layers.""" + cache = Apriel2Cache(attention_config) + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0) + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=1) + + cache.crop(3) + + assert cache.layers[0].key.shape[-2] == 3 + assert cache.layers[1].key.shape[-2] == 3 + + def test_crop_ignores_ssm(self, ssm_config): + """crop() doesn't affect SSM caches (they don't have seq dimension).""" + cache = Apriel2Cache(ssm_config) + cache.layers[0].conv = torch.randn(2, 64, 4) + + # Should not raise + cache.crop(5) + + # SSM state unchanged + assert cache.layers[0].conv.shape == (2, 64, 4) + + +# ============================================================================= +# CACHE PROPERTIES +# ============================================================================= + + +class TestCacheProperties: + """Test cache property methods.""" + + def test_is_initialized_attention(self, attention_config): + """is_initialized True after update.""" + cache = Apriel2Cache(attention_config) + assert not cache.is_initialized + + cache.update(torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16), layer_idx=0) + assert cache.is_initialized + + def test_is_initialized_ssm(self, ssm_config): + """is_initialized True after setting conv state.""" + cache = Apriel2Cache(ssm_config) + assert not cache.is_initialized + + cache.layers[0].conv = torch.randn(2, 64, 4) + assert cache.is_initialized + + def test_has_previous_state_ssm_only(self, ssm_config): + """has_previous_state checks SSM conv states.""" + cache = Apriel2Cache(ssm_config) + assert not cache.has_previous_state + + cache.layers[0].conv = torch.randn(2, 64, 4) + assert cache.has_previous_state + + def test_has_previous_state_ignores_attention(self, attention_config): + """has_previous_state ignores attention caches.""" + cache = Apriel2Cache(attention_config) + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0) + + # Attention-only cache returns False for has_previous_state + assert not cache.has_previous_state + + def test_reset_clears_ssm_states(self, ssm_config): + """reset() clears SSM conv and recurrent states.""" + cache = Apriel2Cache(ssm_config) + cache.layers[0].conv = torch.randn(2, 64, 4) + cache.layers[0].recurrent = torch.randn(2, 64, 16) + + cache.reset() + + assert cache.layers[0].conv is None + assert cache.layers[0].recurrent is None + + def test_max_batch_size_from_ssm_tuple(self, ssm_config): + """max_batch_size works with KDA tuple conv states.""" + cache = Apriel2Cache(ssm_config) + cache.layers[0].conv = (torch.randn(3, 64, 4),) * 3 + + assert cache.max_batch_size == 3 + + def test_max_batch_size(self, attention_config): + """max_batch_size returns batch dimension.""" + cache = Apriel2Cache(attention_config) + cache.update(torch.randn(3, 4, 10, 16), torch.randn(3, 4, 10, 16), layer_idx=0) + + assert cache.max_batch_size == 3 + + def test_len_returns_num_layers(self, attention_config): + """__len__ returns 
number of layers.""" + cache = Apriel2Cache(attention_config) + assert len(cache) == 2 + + +# ============================================================================= +# INDEXING +# ============================================================================= + + +class TestCacheIndexing: + """Test __getitem__ for HF compatibility.""" + + def test_getitem_returns_kv_tuple(self, attention_config): + """cache[idx] returns (key, value) tuple.""" + cache = Apriel2Cache(attention_config) + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0) + + k, v = cache[0] + assert k.shape == (2, 4, 10, 16) + assert v.shape == (2, 4, 10, 16) + + def test_getitem_empty_returns_empty_tensors(self, attention_config): + """cache[idx] on empty cache returns empty tensors.""" + cache = Apriel2Cache(attention_config) + + k, v = cache[0] + assert k.numel() == 0 + assert v.numel() == 0 diff --git a/fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py b/fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py new file mode 100644 index 000000000..7c38f75b7 --- /dev/null +++ b/fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py @@ -0,0 +1,592 @@ +"""Contract tests for Apriel2Cache against HuggingFace cache implementations. + +This module tests that Apriel2Cache components behave equivalently to their +HuggingFace counterparts. This ensures compatibility with HF's generation +infrastructure (mask creation, beam search, etc.). + +Mapping: + Apriel2 Component HuggingFace Equivalent + ----------------- ---------------------- + _AttentionCache (no window) -> DynamicLayer + _AttentionCache (window) -> DynamicSlidingWindowLayer + _SSMCache -> MambaCache (different interface, same concept) + +Apriel2-specific features (stochastic routing, multi-mixer layers) are tested +separately in test_cache_apriel2_specific.py since they have no HF equivalent. + +Fixtures used from conftest.py: + - batch_size, num_heads, head_dim: Tensor dimensions + - hf_dynamic_layer: HuggingFace DynamicLayer + - hf_sliding_layer: HuggingFace DynamicSlidingWindowLayer (parameterized by window_size) + - apriel_attention_cache: Apriel2 _AttentionCache (no window) + - apriel_sliding_cache: Apriel2 _AttentionCache (with window, parameterized) + - window_size: Parameterized window sizes [4, 8, 16, 32] + - attention_config, swa_config: Apriel2 configs +""" + +import pytest +import torch + +from fast_llm_external_models.apriel2.cache import _AttentionCache, _SSMCache, Apriel2Cache + + +# ============================================================================= +# SECTION 1: FULL ATTENTION - _AttentionCache vs DynamicLayer +# ============================================================================= + + +class TestFullAttentionContract: + """Test _AttentionCache (no window) matches HuggingFace DynamicLayer. + + DynamicLayer is the standard cache for full causal attention. + We test that our cache produces identical mask parameters. 
+ """ + + # ------------------------------------------------------------------------- + # get_seq_length: Must match exactly for generation to work + # ------------------------------------------------------------------------- + + @pytest.mark.parametrize("seq_len", [1, 5, 10, 50, 100]) + def test_get_seq_length_after_prefill( + self, hf_dynamic_layer, apriel_attention_cache, batch_size, num_heads, head_dim, seq_len + ): + """After prefill, cumulative_length matches HF get_seq_length.""" + key = torch.randn(batch_size, num_heads, seq_len, head_dim) + value = torch.randn(batch_size, num_heads, seq_len, head_dim) + + hf_dynamic_layer.update(key.clone(), value.clone()) + apriel_attention_cache.update(key.clone(), value.clone()) + + assert apriel_attention_cache.cumulative_length == hf_dynamic_layer.get_seq_length() + + @pytest.mark.parametrize("prefill_len", [1, 5, 10]) + @pytest.mark.parametrize("decode_steps", [1, 5, 10, 20]) + def test_get_seq_length_during_decode( + self, hf_dynamic_layer, apriel_attention_cache, batch_size, num_heads, head_dim, prefill_len, decode_steps + ): + """During decode, cumulative_length tracks total tokens seen.""" + # Prefill + key = torch.randn(batch_size, num_heads, prefill_len, head_dim) + value = torch.randn(batch_size, num_heads, prefill_len, head_dim) + hf_dynamic_layer.update(key.clone(), value.clone()) + apriel_attention_cache.update(key.clone(), value.clone()) + + # Decode + for step in range(decode_steps): + key = torch.randn(batch_size, num_heads, 1, head_dim) + value = torch.randn(batch_size, num_heads, 1, head_dim) + hf_dynamic_layer.update(key.clone(), value.clone()) + apriel_attention_cache.update(key.clone(), value.clone()) + + assert apriel_attention_cache.cumulative_length == hf_dynamic_layer.get_seq_length(), ( + f"Mismatch at decode step {step}" + ) + + # ------------------------------------------------------------------------- + # get_mask_sizes: Verify HF behavior for documentation + # ------------------------------------------------------------------------- + + @pytest.mark.parametrize("prefill_len", [1, 5, 10]) + @pytest.mark.parametrize("decode_steps", [0, 1, 5, 10]) + def test_hf_mask_sizes_kv_length( + self, hf_dynamic_layer, apriel_attention_cache, batch_size, num_heads, head_dim, prefill_len, decode_steps + ): + """Document HF's kv_length behavior and verify cumulative_length tracks correctly. + + For full attention, kv_length = cumulative_length + query_length. + This test verifies our cache tracks tokens identically to HF. 
+ """ + # Prefill + key = torch.randn(batch_size, num_heads, prefill_len, head_dim) + value = torch.randn(batch_size, num_heads, prefill_len, head_dim) + hf_dynamic_layer.update(key.clone(), value.clone()) + apriel_attention_cache.update(key.clone(), value.clone()) + + # Decode + for _ in range(decode_steps): + key = torch.randn(batch_size, num_heads, 1, head_dim) + value = torch.randn(batch_size, num_heads, 1, head_dim) + hf_dynamic_layer.update(key.clone(), value.clone()) + apriel_attention_cache.update(key.clone(), value.clone()) + + # Verify cumulative_length matches HF + assert apriel_attention_cache.cumulative_length == hf_dynamic_layer.get_seq_length() + + # Verify HF's kv_length follows the expected formula + cache_position = torch.arange(1) # Single token decode + hf_kv_len, hf_kv_offset = hf_dynamic_layer.get_mask_sizes(cache_position) + expected_kv_len = hf_dynamic_layer.get_seq_length() + cache_position.shape[0] + assert hf_kv_len == expected_kv_len + + def test_hf_kv_offset_always_zero(self, hf_dynamic_layer, batch_size, num_heads, head_dim): + """Document that HF DynamicLayer always returns kv_offset=0. + + For full attention, all cached KV pairs map to absolute positions + starting from 0, so kv_offset is always 0. + """ + # Add many tokens + for _ in range(20): + key = torch.randn(batch_size, num_heads, 5, head_dim) + value = torch.randn(batch_size, num_heads, 5, head_dim) + hf_dynamic_layer.update(key.clone(), value.clone()) + + cache_position = torch.arange(1) + _, hf_kv_offset = hf_dynamic_layer.get_mask_sizes(cache_position) + + assert hf_kv_offset == 0, "DynamicLayer always returns kv_offset=0" + + # ------------------------------------------------------------------------- + # update: Output shape and values must match + # ------------------------------------------------------------------------- + + @pytest.mark.parametrize("seq_len", [1, 5, 10]) + def test_update_returns_same_shape( + self, hf_dynamic_layer, apriel_attention_cache, batch_size, num_heads, head_dim, seq_len + ): + """update() returns tensors with matching shapes.""" + key = torch.randn(batch_size, num_heads, seq_len, head_dim) + value = torch.randn(batch_size, num_heads, seq_len, head_dim) + + hf_k, hf_v = hf_dynamic_layer.update(key.clone(), value.clone()) + apr_k, apr_v = apriel_attention_cache.update(key.clone(), value.clone()) + + assert hf_k.shape == apr_k.shape + assert hf_v.shape == apr_v.shape + + def test_update_concatenates_identically( + self, hf_dynamic_layer, apriel_attention_cache, batch_size, num_heads, head_dim + ): + """Multiple updates produce identical concatenated states.""" + # Use deterministic values for comparison + k1 = torch.arange(10).float().view(1, 1, 10, 1).expand(batch_size, num_heads, 10, head_dim) + v1 = k1.clone() + + hf_dynamic_layer.update(k1.clone(), v1.clone()) + apriel_attention_cache.update(k1.clone(), v1.clone()) + + k2 = torch.arange(10, 15).float().view(1, 1, 5, 1).expand(batch_size, num_heads, 5, head_dim) + v2 = k2.clone() + + hf_k, hf_v = hf_dynamic_layer.update(k2.clone(), v2.clone()) + apr_k, apr_v = apriel_attention_cache.update(k2.clone(), v2.clone()) + + torch.testing.assert_close(hf_k, apr_k) + torch.testing.assert_close(hf_v, apr_v) + + +# ============================================================================= +# SECTION 2: SLIDING WINDOW - _AttentionCache vs DynamicSlidingWindowLayer +# ============================================================================= + + +class TestSlidingWindowContract: + """Test _AttentionCache (with window) 
matches HuggingFace DynamicSlidingWindowLayer. + + DynamicSlidingWindowLayer is used for sliding window attention (e.g., Mistral). + Critical behaviors: + - cumulative_length tracks ALL tokens seen (not just cached) + - kv_offset increases once window is exceeded + - kv_length is capped at window size + + Uses fixtures from conftest.py: + - window_size: parameterized [4, 8, 16, 32] + - hf_sliding_layer: DynamicSlidingWindowLayer + - apriel_sliding_cache: _AttentionCache with window + """ + + # ------------------------------------------------------------------------- + # cumulative_length: Must track total tokens, not cached tokens + # ------------------------------------------------------------------------- + + @pytest.mark.parametrize("prefill_len", [1, 3, 5, 10, 20]) + def test_cumulative_length_matches_after_prefill( + self, hf_sliding_layer, apriel_sliding_cache, batch_size, num_heads, head_dim, prefill_len + ): + """cumulative_length matches HF get_seq_length after prefill.""" + key = torch.randn(batch_size, num_heads, prefill_len, head_dim) + value = torch.randn(batch_size, num_heads, prefill_len, head_dim) + + hf_sliding_layer.update(key.clone(), value.clone()) + apriel_sliding_cache.update(key.clone(), value.clone()) + + assert apriel_sliding_cache.cumulative_length == hf_sliding_layer.get_seq_length() + + def test_cumulative_length_continues_past_window( + self, hf_sliding_layer, apriel_sliding_cache, window_size, batch_size, num_heads, head_dim + ): + """cumulative_length keeps growing even after window is full.""" + total_tokens = window_size * 3 # Way past window + + for i in range(total_tokens): + key = torch.randn(batch_size, num_heads, 1, head_dim) + value = torch.randn(batch_size, num_heads, 1, head_dim) + + hf_sliding_layer.update(key.clone(), value.clone()) + apriel_sliding_cache.update(key.clone(), value.clone()) + + expected = i + 1 + assert apriel_sliding_cache.cumulative_length == expected + assert hf_sliding_layer.get_seq_length() == expected + + # ------------------------------------------------------------------------- + # get_mask_sizes: kv_offset must increase once window is exceeded + # ------------------------------------------------------------------------- + + def test_kv_offset_zero_before_window_full( + self, hf_sliding_layer, apriel_sliding_cache, window_size, batch_size, num_heads, head_dim + ): + """kv_offset is 0 while cumulative < window. + + Before the window is full, kv_offset should be 0 because all cached tokens + correspond to absolute positions starting from 0. + """ + # Add tokens up to window-1 + for i in range(window_size - 1): + key = torch.randn(batch_size, num_heads, 1, head_dim) + value = torch.randn(batch_size, num_heads, 1, head_dim) + + hf_sliding_layer.update(key.clone(), value.clone()) + apriel_sliding_cache.update(key.clone(), value.clone()) + + cache_position = torch.arange(1) + hf_kv_len, hf_kv_offset = hf_sliding_layer.get_mask_sizes(cache_position) + + # Verify HF returns 0 offset before window full + assert hf_kv_offset == 0, f"HF offset should be 0 at step {i}" + # Verify Apriel cache tracks cumulative correctly + assert apriel_sliding_cache.cumulative_length == i + 1 + + def test_kv_offset_increases_after_window_full( + self, hf_sliding_layer, apriel_sliding_cache, window_size, batch_size, num_heads, head_dim + ): + """kv_offset increases once cumulative >= window. + + Once the window is full, the cache discards oldest tokens. kv_offset tracks + which absolute position KV[0] corresponds to. 
+ """ + # Fill to exactly window + for _ in range(window_size): + key = torch.randn(batch_size, num_heads, 1, head_dim) + value = torch.randn(batch_size, num_heads, 1, head_dim) + hf_sliding_layer.update(key.clone(), value.clone()) + apriel_sliding_cache.update(key.clone(), value.clone()) + + cache_position = torch.arange(1) + hf_kv_len, hf_kv_offset = hf_sliding_layer.get_mask_sizes(cache_position) + + # At window boundary, offset should be 1 + assert hf_kv_offset == 1, "HF offset should be 1 at window boundary" + assert apriel_sliding_cache.cumulative_length == window_size + + # Add more tokens and verify offset keeps increasing with HF + for i in range(5): + key = torch.randn(batch_size, num_heads, 1, head_dim) + value = torch.randn(batch_size, num_heads, 1, head_dim) + hf_sliding_layer.update(key.clone(), value.clone()) + apriel_sliding_cache.update(key.clone(), value.clone()) + + hf_kv_len, hf_kv_offset = hf_sliding_layer.get_mask_sizes(cache_position) + + expected_offset = i + 2 + assert hf_kv_offset == expected_offset + assert apriel_sliding_cache.cumulative_length == window_size + i + 1 + + def test_kv_length_capped_at_window( + self, hf_sliding_layer, apriel_sliding_cache, window_size, batch_size, num_heads, head_dim + ): + """kv_length is capped at window size once exceeded. + + For a query of length 1 after the window is full, kv_length = window + (window-1 cached tokens + 1 query token). + """ + # Way past window + for _ in range(window_size * 2): + key = torch.randn(batch_size, num_heads, 1, head_dim) + value = torch.randn(batch_size, num_heads, 1, head_dim) + hf_sliding_layer.update(key.clone(), value.clone()) + apriel_sliding_cache.update(key.clone(), value.clone()) + + cache_position = torch.arange(1) + hf_kv_len, _ = hf_sliding_layer.get_mask_sizes(cache_position) + + # HF returns window (window-1 cached + 1 query) + assert hf_kv_len == window_size + # Verify our cache tracked cumulative correctly + assert apriel_sliding_cache.cumulative_length == window_size * 2 + + # ------------------------------------------------------------------------- + # Full sequence length tracking through generation + # ------------------------------------------------------------------------- + + @pytest.mark.parametrize("prefill_len", [1, 3, 5, 10, 20]) + def test_cumulative_length_tracks_all_tokens( + self, hf_sliding_layer, apriel_sliding_cache, window_size, batch_size, num_heads, head_dim, prefill_len + ): + """cumulative_length tracks total tokens seen through prefill + decode. + + This is the foundation for correct mask size computation. We verify that + our _AttentionCache tracks tokens identically to HuggingFace's DynamicSlidingWindowLayer. + The actual get_mask_sizes computation is tested in TestApriel2CacheIntegration. 
+ """ + # Prefill + key = torch.randn(batch_size, num_heads, prefill_len, head_dim) + value = torch.randn(batch_size, num_heads, prefill_len, head_dim) + hf_sliding_layer.update(key.clone(), value.clone()) + apriel_sliding_cache.update(key.clone(), value.clone()) + + assert apriel_sliding_cache.cumulative_length == hf_sliding_layer.get_seq_length() + + # Decode past window + for i in range(window_size + 10): + key = torch.randn(batch_size, num_heads, 1, head_dim) + value = torch.randn(batch_size, num_heads, 1, head_dim) + hf_sliding_layer.update(key.clone(), value.clone()) + apriel_sliding_cache.update(key.clone(), value.clone()) + + assert apriel_sliding_cache.cumulative_length == hf_sliding_layer.get_seq_length(), ( + f"cumulative_length mismatch at step {i}" + ) + + +# ============================================================================= +# SECTION 3: SSM CACHE - _SSMCache vs MambaCache concept +# ============================================================================= + + +class TestSSMCacheContract: + """Document _SSMCache interface and verify basic contract. + + Unlike attention caches which have HF equivalents (DynamicLayer, DynamicSlidingWindowLayer), + SSM caches have no direct HF counterpart with matching interface. HF's MambaCache uses + different methods (update_conv_state, update_ssm_state), so we can't do direct comparison. + + These tests document the interface contract: + 1. `conv` and `recurrent` attributes for storing states + 2. Both support None (lazy initialization) + 3. `conv` can be tuple (for KDA which has separate q/k/v conv states) + + Higher-level operations (reorder, batch_repeat, reset) are tested in + TestBeamSearchOperations in test_cache_apriel2_specific.py. + """ + + def test_conv_state_storage(self, ssm_cache): + """conv attribute stores conv states (batch, intermediate, kernel_size).""" + conv = torch.randn(2, 64, 4) + ssm_cache.conv = conv + torch.testing.assert_close(ssm_cache.conv, conv) + + def test_recurrent_state_storage(self, ssm_cache): + """recurrent attribute stores SSM states (batch, intermediate, state_size).""" + recurrent = torch.randn(2, 64, 16) + ssm_cache.recurrent = recurrent + torch.testing.assert_close(ssm_cache.recurrent, recurrent) + + def test_conv_state_tuple_for_kda(self, ssm_cache): + """conv can be tuple for KDA's separate q/k/v convolutions.""" + conv_tuple = (torch.randn(2, 64, 4), torch.randn(2, 64, 4), torch.randn(2, 64, 4)) + ssm_cache.conv = conv_tuple + assert isinstance(ssm_cache.conv, tuple) + assert len(ssm_cache.conv) == 3 + + def test_initial_states_none(self, ssm_cache): + """States are None initially (lazy initialization pattern).""" + assert ssm_cache.conv is None + assert ssm_cache.recurrent is None + + def test_states_independent(self, ssm_cache): + """conv and recurrent states are independent.""" + ssm_cache.conv = torch.randn(2, 64, 4) + assert ssm_cache.recurrent is None # recurrent unchanged + + ssm_cache.recurrent = torch.randn(2, 64, 16) + assert ssm_cache.conv is not None # conv unchanged + + +# ============================================================================= +# SECTION 4: APRIEL2CACHE INTEGRATION +# ============================================================================= + + +class TestApriel2CacheIntegration: + """Test Apriel2Cache correctly delegates to underlying caches. 
+ + Uses fixtures from conftest.py: + - attention_config: Pure attention config + - swa_config: Sliding window attention config (window=8) + """ + + def test_get_seq_length_matches_dynamic_layer(self, attention_config): + """Apriel2Cache.get_seq_length matches DynamicLayer for full attention.""" + from transformers.cache_utils import DynamicLayer + + cache = Apriel2Cache(attention_config) + hf_layer = DynamicLayer() + + key = torch.randn(2, 4, 10, 16) + value = torch.randn(2, 4, 10, 16) + + cache.update(key.clone(), value.clone(), layer_idx=0) + hf_layer.update(key.clone(), value.clone()) + + assert cache.get_seq_length(0) == hf_layer.get_seq_length() + + def test_get_mask_sizes_matches_dynamic_layer(self, attention_config): + """Apriel2Cache.get_mask_sizes matches DynamicLayer.""" + from transformers.cache_utils import DynamicLayer + + cache = Apriel2Cache(attention_config) + hf_layer = DynamicLayer() + + key = torch.randn(2, 4, 10, 16) + value = torch.randn(2, 4, 10, 16) + + cache.update(key.clone(), value.clone(), layer_idx=0) + hf_layer.update(key.clone(), value.clone()) + + cache_position = torch.arange(1) + hf_kv_len, hf_kv_offset = hf_layer.get_mask_sizes(cache_position) + apr_kv_len, apr_kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0) + + assert apr_kv_len == hf_kv_len + assert apr_kv_offset == hf_kv_offset + + def test_get_mask_sizes_matches_sliding_layer(self, swa_config): + """Apriel2Cache.get_mask_sizes matches DynamicSlidingWindowLayer.""" + from transformers.cache_utils import DynamicSlidingWindowLayer + + cache = Apriel2Cache(swa_config) + hf_layer = DynamicSlidingWindowLayer(sliding_window=8) + + # Fill past window + for _ in range(15): + key = torch.randn(2, 4, 1, 16) + value = torch.randn(2, 4, 1, 16) + cache.update(key.clone(), value.clone(), layer_idx=0) + hf_layer.update(key.clone(), value.clone()) + + cache_position = torch.arange(1) + hf_kv_len, hf_kv_offset = hf_layer.get_mask_sizes(cache_position) + apr_kv_len, apr_kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0) + + assert apr_kv_len == hf_kv_len + assert apr_kv_offset == hf_kv_offset + + def test_reset_clears_cumulative_length(self, attention_config): + """reset() clears cumulative_length (matches DynamicLayer.reset).""" + cache = Apriel2Cache(attention_config) + + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0) + assert cache.get_seq_length(0) == 10 + + cache.reset() + assert cache.get_seq_length(0) == 0 + + +# ============================================================================= +# SECTION 5: MASK CORRECTNESS (SEMANTIC TESTS) +# ============================================================================= + + +class TestMaskCorrectness: + """Test that mask parameters produce semantically correct masks. + + These tests verify the END RESULT: masks created with our parameters + allow the correct attention patterns. 
+ """ + + def test_full_attention_decode_can_attend_to_all(self): + """During decode, query can attend to all cached positions.""" + from transformers.masking_utils import sdpa_mask, causal_mask_function + + cache = _AttentionCache(window=None) + + # Prefill + decode + for _ in range(10): + cache.update(torch.randn(1, 1, 1, 16), torch.randn(1, 1, 1, 16)) + + # Mask for decode step + cache_position = torch.tensor([10]) # Position of new token + kv_length = cache.cumulative_length + 1 + kv_offset = 0 + + mask = sdpa_mask( + batch_size=1, + cache_position=cache_position, + kv_length=kv_length, + kv_offset=kv_offset, + mask_function=causal_mask_function, + ) + + if mask is not None: + # Query at position 10 should attend to positions 0-10 + query_mask = mask[0, 0, 0, :] + for kv_idx in range(kv_length): + assert query_mask[kv_idx].item() == True, f"Should attend to position {kv_idx}" + + @pytest.mark.parametrize("window_size", [4, 8, 16]) + def test_sliding_window_decode_respects_window(self, window_size): + """During decode, query only attends within sliding window.""" + from transformers.masking_utils import sdpa_mask, sliding_window_causal_mask_function + + cache = _AttentionCache(window=window_size) + + # Fill way past window + total_tokens = window_size * 2 + for _ in range(total_tokens): + cache.update(torch.randn(1, 1, 1, 16), torch.randn(1, 1, 1, 16)) + + # Mask for decode step + cache_position = torch.tensor([total_tokens]) + cumulative = cache.cumulative_length + kv_offset = max(cumulative - window_size + 1, 0) + kv_length = window_size - 1 + 1 # cached + query + + mask = sdpa_mask( + batch_size=1, + cache_position=cache_position, + kv_length=kv_length, + kv_offset=kv_offset, + mask_function=sliding_window_causal_mask_function(window_size), + ) + + if mask is not None: + query_mask = mask[0, 0, 0, :] + query_pos = cache_position[0].item() + + for kv_idx in range(kv_length): + abs_pos = kv_offset + kv_idx + in_window = abs_pos > query_pos - window_size + causal = abs_pos <= query_pos + expected = in_window and causal + + assert query_mask[kv_idx].item() == expected, ( + f"Position {abs_pos}: expected {expected}, got {query_mask[kv_idx].item()}" + ) + + def test_prefill_has_causal_pattern(self): + """During prefill, mask has proper causal (lower triangular) pattern.""" + from transformers.masking_utils import sdpa_mask, causal_mask_function + + cache = _AttentionCache(window=None) + cache.update(torch.randn(1, 1, 5, 16), torch.randn(1, 1, 5, 16)) + + cache_position = torch.arange(5) + kv_length = cache.cumulative_length + kv_offset = 0 + + mask = sdpa_mask( + batch_size=1, + cache_position=cache_position, + kv_length=kv_length, + kv_offset=kv_offset, + mask_function=causal_mask_function, + allow_is_causal_skip=False, # Force mask creation + ) + + if mask is not None: + # Check causal pattern + for q_idx in range(5): + for kv_idx in range(5): + expected = kv_idx <= q_idx + actual = mask[0, 0, q_idx, kv_idx].item() + assert actual == expected, f"q={q_idx}, kv={kv_idx}: expected {expected}" From 843a355a6c37ea8f74783fef987b658b8549af51 Mon Sep 17 00:00:00 2001 From: bigximik Date: Fri, 28 Nov 2025 14:09:39 +0000 Subject: [PATCH 02/12] fix qwen converted to correctly load qkv biases --- fast_llm/models/gpt/conversion/qwen2.py | 32 +++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/fast_llm/models/gpt/conversion/qwen2.py b/fast_llm/models/gpt/conversion/qwen2.py index a8bc33454..57c9614bd 100644 --- a/fast_llm/models/gpt/conversion/qwen2.py +++ 
b/fast_llm/models/gpt/conversion/qwen2.py @@ -2,6 +2,7 @@ from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.layers.attention.config import AttentionConfig +from fast_llm.layers.decoder.mlp.config import MLPConfig from fast_llm.models.gpt.conversion.config import Qwen2CheckpointFormat from fast_llm.models.gpt.conversion.llama import ( LlamaAttentionConverter, @@ -10,6 +11,7 @@ LlamaDecoderConverter, LlamaHeadConverter, LlamaHuggingfaceCheckpointHandler, + LlamaMLPConverter, ) from fast_llm.utils import Assert @@ -17,6 +19,22 @@ class Qwen2AttentionConverter(LlamaAttentionConverter): # TODO: Support sliding window with max_window_layers (need 2 kinds of block?) + @classmethod + def import_config(cls, config: dict) -> dict: + config["attention_bias"] = True + out = super().import_config(config) + out["query_layer"] = {"bias": {"enabled": True}} + out["key_layer"] = {"bias": {"enabled": True}} + out["value_layer"] = {"bias": {"enabled": True}} + out["dense_layer"] = {"bias": {"enabled": False}} + return out + + @classmethod + def export_config(cls, config: AttentionConfig) -> dict: + out = super().export_config(config) + del out["attention_bias"] + return out + @classmethod def _check_config(cls, config: AttentionConfig) -> None: Assert.is_(type(config), AttentionConfig) @@ -33,8 +51,22 @@ def _check_config(cls, config: AttentionConfig) -> None: Assert.incl(config.dense_layer.bias.enabled, (None, False)) +class Qwen2MLPConverter(LlamaMLPConverter): + @classmethod + def import_config(cls, config: dict) -> dict: + config["mlp_bias"] = False + return super().import_config(config) + + @classmethod + def export_config(cls, config: MLPConfig) -> dict: + out = super().export_config(config) + del out["mlp_bias"] + return out + + class Qwen2BlockConverter(LlamaBlockConverter): mixer_converter_class: typing.ClassVar[type[Qwen2AttentionConverter]] = Qwen2AttentionConverter + mlp_converter_class: typing.ClassVar[type[Qwen2MLPConverter]] = Qwen2MLPConverter class Qwen2DecoderConverter(LlamaDecoderConverter): From 33b6d31dd842022812814655ba3ef2a6558ad010 Mon Sep 17 00:00:00 2001 From: bigximik Date: Tue, 2 Dec 2025 12:09:03 +0000 Subject: [PATCH 03/12] fix converters --- fast_llm/models/gpt/conversion/qwen2.py | 37 +++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/fast_llm/models/gpt/conversion/qwen2.py b/fast_llm/models/gpt/conversion/qwen2.py index 57c9614bd..4ebf18c3a 100644 --- a/fast_llm/models/gpt/conversion/qwen2.py +++ b/fast_llm/models/gpt/conversion/qwen2.py @@ -1,10 +1,12 @@ import typing from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm.engine.checkpoint.external import WeightConverter from fast_llm.layers.attention.config import AttentionConfig from fast_llm.layers.decoder.mlp.config import MLPConfig from fast_llm.models.gpt.conversion.config import Qwen2CheckpointFormat from fast_llm.models.gpt.conversion.llama import ( + KeyValueWeightConverter, LlamaAttentionConverter, LlamaBaseModelConverter, LlamaBlockConverter, @@ -12,6 +14,8 @@ LlamaHeadConverter, LlamaHuggingfaceCheckpointHandler, LlamaMLPConverter, + QueryWeightConverter, + get_weight_and_bias_converters, ) from fast_llm.utils import Assert @@ -50,6 +54,39 @@ def _check_config(cls, config: AttentionConfig) -> None: Assert.is_(config.value_layer.bias.enabled, True) Assert.incl(config.dense_layer.bias.enabled, (None, False)) + @classmethod + def get_converters( + cls, + config: AttentionConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = 
False, + ) -> list[WeightConverter]: + return [ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.query", + f"{hf_prefix}.q_proj", + True, + QueryWeightConverter, + config, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.key_value", + (f"{hf_prefix}.k_proj", f"{hf_prefix}.v_proj"), + True, + KeyValueWeightConverter, + config, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.dense", + f"{hf_prefix}.o_proj", + False, + drop_on_export=drop_on_export, + ), + ] + class Qwen2MLPConverter(LlamaMLPConverter): @classmethod From 78229757422ccd58256cdd5afe2067884d3a80f5 Mon Sep 17 00:00:00 2001 From: Torsten Scholak Date: Sun, 14 Dec 2025 20:51:29 +0000 Subject: [PATCH 04/12] Add per-layer bias support, surgery improvements, and integration tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit adds comprehensive support for per-layer bias configurations in Apriel2 conversions and improves the surgery/conversion infrastructure. Key changes: **Per-layer bias configuration:** - Support weight-specific bias settings (query_layer.bias.enabled, etc.) - Bias inheritance for stochastic mixer submixers - Proper handling of Qwen-style bias pattern (QKV bias, no O bias) **Surgery and conversion improvements:** - Document monoidal structure in compose_configs and plan_surgery - Fix non-gated MLP handling (gate_proj only when gated=True) - Fix vision_encoder=None handling in converters - Change to relative imports in apriel2 modules for portability **Test infrastructure:** - Add requires_fastllm decorator for Fast-LLM dependent tests - Fix autouse fixture scoping (module-scoped for proper ordering) - Add comprehensive integration tests with parameterized inputs - Test all conversion stages: Qwen2 -> Apriel2 -> Supernet -> Roundtrip - Parameterized test inputs for batch size, padding, and generation length **Integration test structure:** - TestConfigPreservation: Verify config correctness at each stage - TestNumericalEquivalence: Verify logits and generation match - 24 tests covering 3 stages Γ— 3 input variations Γ— 2 checks πŸ€– Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- fast_llm/models/gpt/conversion/apriel2.py | 125 +++++-- .../apriel2/conversion/__init__.py | 2 +- .../apriel2/conversion/config.py | 142 +++++--- .../apriel2/conversion/converters.py | 195 ++++++++-- .../apriel2/conversion/qwen2/config.py | 14 +- .../apriel2/conversion/qwen2/plan.py | 20 +- .../apriel2/modeling_apriel2.py | 93 ++++- .../tests/test_apriel2/conftest.py | 152 +++++++- .../test_apriel2/test_compose_configs.py | 157 ++++++++ .../tests/test_apriel2/test_expr_plan.py | 202 +++++++++++ .../tests/test_apriel2/test_integration.py | 335 ++++++++++++++++++ .../tests/test_apriel2/test_modeling.py | 3 +- .../test_plan_composition_torture.py | 148 ++++++++ 13 files changed, 1452 insertions(+), 136 deletions(-) create mode 100644 fast_llm_external_models/tests/test_apriel2/test_integration.py diff --git a/fast_llm/models/gpt/conversion/apriel2.py b/fast_llm/models/gpt/conversion/apriel2.py index 7682196c8..eb5641aea 100644 --- a/fast_llm/models/gpt/conversion/apriel2.py +++ b/fast_llm/models/gpt/conversion/apriel2.py @@ -39,8 +39,20 @@ def import_config(cls, config: dict) -> dict: "head_groups": config["head_groups"], "head_size": config["head_size"], "rotary": rotary, - "add_linear_biases": config["add_linear_biases"], } + # Per-layer bias 
configuration mirroring Fast-LLM structure + # If per-layer configs exist, use them; otherwise fall back to add_linear_biases + if "query_layer" in config: + result["query_layer"] = config["query_layer"] + if "key_layer" in config: + result["key_layer"] = config["key_layer"] + if "value_layer" in config: + result["value_layer"] = config["value_layer"] + if "dense_layer" in config: + result["dense_layer"] = config["dense_layer"] + # add_linear_biases serves as default for layers without explicit config + if "add_linear_biases" in config: + result["add_linear_biases"] = config["add_linear_biases"] if "window_size" in config: result["window_size"] = config["window_size"] return result @@ -58,18 +70,37 @@ def export_config(cls, config: AttentionConfig) -> dict: else: raise NotImplementedError(f"Unsupported rotary type: {type(config.rotary).__name__}") - return { + result = { "type": "attention", "heads": config.heads, "head_groups": config.head_groups, "head_size": config.head_size, - "add_linear_biases": config.add_linear_biases, "rotary": { "type": rotary_type, "theta": config.rotary.theta, }, "window_size": config.window_size, } + # Export per-layer bias configuration + # Only include if explicitly set (not None) + if config.query_layer.bias.enabled is not None: + result["query_layer"] = {"bias": {"enabled": config.query_layer.bias.enabled}} + if config.key_layer.bias.enabled is not None: + result["key_layer"] = {"bias": {"enabled": config.key_layer.bias.enabled}} + if config.value_layer.bias.enabled is not None: + result["value_layer"] = {"bias": {"enabled": config.value_layer.bias.enabled}} + if config.dense_layer.bias.enabled is not None: + result["dense_layer"] = {"bias": {"enabled": config.dense_layer.bias.enabled}} + # add_linear_biases as fallback default + result["add_linear_biases"] = config.add_linear_biases + return result + + @classmethod + def _get_effective_bias(cls, layer_config, default: bool) -> bool: + """Get effective bias setting: use layer-specific if set, else default.""" + if layer_config.bias.enabled is not None: + return layer_config.bias.enabled + return default @classmethod def get_converters( @@ -79,11 +110,20 @@ def get_converters( hf_prefix: str, drop_on_export: bool = False, ) -> list[WeightConverter]: + # Determine effective bias for each projection + q_bias = cls._get_effective_bias(config.query_layer, config.add_linear_biases) + k_bias = cls._get_effective_bias(config.key_layer, config.add_linear_biases) + v_bias = cls._get_effective_bias(config.value_layer, config.add_linear_biases) + o_bias = cls._get_effective_bias(config.dense_layer, config.add_linear_biases) + # For key_value, both k and v must have same bias setting + # (they're combined in Fast-LLM's key_value layer) + kv_bias = k_bias and v_bias + return [ *get_weight_and_bias_converters( f"{fast_llm_prefix}.query", f"{hf_prefix}.q_proj", - config.add_linear_biases, + q_bias, QueryWeightConverter, config, drop_on_export=drop_on_export, @@ -91,7 +131,7 @@ def get_converters( *get_weight_and_bias_converters( f"{fast_llm_prefix}.key_value", (f"{hf_prefix}.k_proj", f"{hf_prefix}.v_proj"), - config.add_linear_biases, + kv_bias, KeyValueWeightConverter, config, drop_on_export=drop_on_export, @@ -99,7 +139,7 @@ def get_converters( *get_weight_and_bias_converters( f"{fast_llm_prefix}.dense", f"{hf_prefix}.o_proj", - config.add_linear_biases, + o_bias, drop_on_export=drop_on_export, ), ] @@ -524,6 +564,12 @@ def import_config(cls, config: dict, block_config: dict) -> dict: "gated": mlp_config["gated"], 
"add_linear_biases": mlp_config["add_linear_biases"], } + # Import per-layer MLP bias settings (layer_1, layer_2) + for layer_name in ("layer_1", "layer_2"): + if layer_name in mlp_config: + layer_cfg = mlp_config[layer_name] + if "bias" in layer_cfg: + mlp[layer_name] = {"bias": layer_cfg["bias"]} normalization = block_config["normalization"] @@ -578,6 +624,11 @@ def export_config(cls, config: DecoderBlockConfig) -> dict: "gated": config.mlp.gated, "add_linear_biases": config.mlp.add_linear_biases, } + # Export per-layer MLP bias settings (layer_1, layer_2) + if config.mlp.layer_1.bias.enabled is not None: + mlp["layer_1"] = {"bias": {"enabled": config.mlp.layer_1.bias.enabled}} + if config.mlp.layer_2.bias.enabled is not None: + mlp["layer_2"] = {"bias": {"enabled": config.mlp.layer_2.bias.enabled}} normalization = {"type": norm_type_str, "epsilon": config.normalization.epsilon} @@ -624,22 +675,52 @@ def get_converters( ) ) - converters.extend([ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_1", - (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), - config.mlp.add_linear_biases, - SplitWeightConverter, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_2", - f"{hf_prefix}.mlp.down_proj", - config.mlp.add_linear_biases, - MLPLayer2Converter, - drop_on_export=drop_on_export, - ), - ]) + # Per-layer MLP bias: use layer-specific setting if set, else default + def get_mlp_layer_bias(layer_config, default: bool) -> bool: + if layer_config.bias.enabled is not None: + return layer_config.bias.enabled + return default + + layer_1_bias = get_mlp_layer_bias(config.mlp.layer_1, config.mlp.add_linear_biases) + layer_2_bias = get_mlp_layer_bias(config.mlp.layer_2, config.mlp.add_linear_biases) + + if config.mlp.gated: + # Gated MLP: gate_proj + up_proj -> layer_1 (split), down_proj -> layer_2 + converters.extend([ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_1", + (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), + layer_1_bias, + SplitWeightConverter, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_2", + f"{hf_prefix}.mlp.down_proj", + layer_2_bias, + MLPLayer2Converter, + drop_on_export=drop_on_export, + ), + ]) + else: + # Non-gated MLP: up_proj -> layer_1, down_proj -> layer_2 + # Note: layer_2 still needs MLPLayer2Converter for the transpose + converters.extend([ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_1", + f"{hf_prefix}.mlp.up_proj", + layer_1_bias, + WeightConverter, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_2", + f"{hf_prefix}.mlp.down_proj", + layer_2_bias, + MLPLayer2Converter, + drop_on_export=drop_on_export, + ), + ]) converters.extend([ *LlamaNormalizationConverter.get_converters( diff --git a/fast_llm_external_models/apriel2/conversion/__init__.py b/fast_llm_external_models/apriel2/conversion/__init__.py index 983a632e0..60fc0ef0a 100644 --- a/fast_llm_external_models/apriel2/conversion/__init__.py +++ b/fast_llm_external_models/apriel2/conversion/__init__.py @@ -63,7 +63,7 @@ target_config = compose_configs(source_config, surgery_spec) # 2. Build plan for weight transformation - plan = plan_surgery(source_config, surgery_spec) + plan = plan_surgery(source_config, target_config) # 3. 
Execute plan to transform weights target_weights = execute(plan, source_weights, seed=42) diff --git a/fast_llm_external_models/apriel2/conversion/config.py b/fast_llm_external_models/apriel2/conversion/config.py index 48f8ff44b..f5b19e208 100644 --- a/fast_llm_external_models/apriel2/conversion/config.py +++ b/fast_llm_external_models/apriel2/conversion/config.py @@ -1,56 +1,59 @@ """Config composition for Apriel2 architecture transformations. This module handles STRUCTURAL composition of configs, independent of weight handling. -The `init` field in surgery specs is preserved as metadata for the plan builder but -does not affect how configs are composed. +The `init` field in surgery specs is metadata for plan_surgery(), not for composition. -Composition Cases -================= +Algebraic Structure +=================== + +The system has a precise algebraic structure with two interacting components: -compose_configs(base, overlay) handles four cases based on completeness: +**Surgery Specs (Monoid)** + Partial config dicts form a monoid under deep merge: + - Identity: {} (empty dict) + - Operation: compose_configs(partial1, partial2) = deep_merge(partial1, partial2) + - Associativity: (a ∘ b) ∘ c = a ∘ (b ∘ c) -1. **Complete + Partial** β†’ Apply surgery semantics (inheritance, cross-type derivation) -2. **Partial + Partial** β†’ Deep merge (monoid operation on surgery specs) -3. **Partial + Complete** β†’ Overlay wins (complete config replaces partial) -4. **Complete + Complete** β†’ Deep merge, then strip `init` fields +**Complete Configs (Monoid Action)** + Surgery specs ACT on complete configs: + - Action: compose_configs(complete, partial) β†’ complete + - For additive surgeries: (s Β· t₁) Β· tβ‚‚ = s Β· (t₁ ∘ tβ‚‚) + - For replacement surgeries: action law intentionally fails (last-write-wins) -A config is "complete" if it has `hidden_size` and `decoder` (i.e., it's a full model -config, not a surgery spec). +This separation is fundamental: surgery specs compose declaratively (what fields to +merge), while the action on configs interprets those fields with inheritance semantics. -Surgery Semantics +Composition Cases ================= -When applying a surgery spec to a complete config: +compose_configs(base, overlay) dispatches based on completeness: + +1. **Complete + Partial** β†’ Monoid action (inheritance, cross-type derivation) +2. **Partial + Partial** β†’ Monoid operation (deep merge) +3. **Partial + Complete** β†’ Overlay wins (complete replaces partial) +4. **Complete + Complete** β†’ Deep merge, strip `init` fields -**Inheritance** - Unspecified parameters inherit from the source config. New blocks inherit - from the "default" block (first block in pattern, or the single fixed block). +A config is "complete" if it has `hidden_size` and `decoder`. 
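To make the algebraic structure described above concrete, here is a minimal sketch. The field values are illustrative, and it assumes compose_configs deep-merges partial specs and dispatches on completeness exactly as this docstring describes.

from fast_llm_external_models.apriel2.conversion.config import compose_configs

# Partial surgery specs (the monoid elements).
s1 = {"decoder": {"block": {"mixer": {"mixers": {"attn": {"type": "attention"}}}}}}
s2 = {"decoder": {"block": {"mixer": {"mixers": {"swa": {"type": "attention", "window_size": 512}}}}}}
s3 = {"decoder": {"block": {"mlp": {"intermediate_size": 1024}}}}

# Monoid laws on partial specs: {} is the identity, deep merge is associative.
assert compose_configs(s1, {}) == s1
assert compose_configs(compose_configs(s1, s2), s3) == compose_configs(s1, compose_configs(s2, s3))

# A complete config (same shape as the test fixtures later in this patch).
source = {
    "hidden_size": 256,
    "vocab_size": 1000,
    "decoder": {
        "type": "fixed",
        "num_blocks": 2,
        "block": {
            "mixer": {"type": "attention", "heads": 8, "head_groups": 4, "head_size": 32},
            "mlp": {"type": "mlp", "intermediate_size": 512, "gated": True},
            "normalization": {"type": "rms_norm", "epsilon": 1e-5},
        },
    },
}

# Monoid action: a partial spec acts on a complete config and yields a complete config.
surgery = {
    "decoder": {"block": {"mixer": {
        "type": "stochastic",
        "main_mixer_name": "attention",
        "mixers": {"attention": {"init": "transfer"}},
    }}},
}
target = compose_configs(source, surgery)
# The attention sub-mixer inherits heads/head_groups/head_size from `source`,
# and all `init` fields are stripped, so `target` is purely structural.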
-**Cross-Type Derivation** - When changing mixer types, geometric parameters are derived where possible: - - attention β†’ sliding_window: preserve heads, head_groups, head_size - - attention β†’ gdn: heads β†’ value_heads, head_groups β†’ key_heads - - attention β†’ mamba: derive d_inner, d_xb, dt_rank from hidden_size - - attention β†’ kda: preserve heads, head_size β†’ head_dim +Inheritance Semantics +===================== -**Stochastic Mixer Composition** - Two semantics based on whether surgery declares `type: stochastic`: - - Replacement: surgery declares type β†’ only surgery's sub-mixers included - - Additive: surgery omits type β†’ source sub-mixers preserved, surgery adds/modifies +When the monoid action applies a surgery to a complete config: - This distinction means the monoid action law holds for additive surgeries but - intentionally fails for replacement surgeries (they have "last-write-wins" semantics). +- Unspecified fields inherit from source +- New blocks inherit from the "default" block +- Cross-type derivation maps geometry (attention.heads β†’ gdn.value_heads, etc.) +- Stochastic mixers: additive (no type decl) preserves source, replacement replaces The `init` Field ================ -The `init` field is metadata for the plan builder, NOT for config composition: -- `init: transfer` β†’ plan builder creates weight transfer mappings -- `init: random` β†’ plan builder creates random initialization +The `init` field is metadata for plan_surgery(), NOT for config composition: +- `init: transfer` β†’ plan uses weight transfer/conversion +- `init: random` β†’ plan uses random initialization -After surgery is applied to produce a complete config, ALL `init` fields are stripped. -This ensures configs are purely structural and plan creation is Markovian (depends only -on current config + surgery, not on history). +After composition produces a complete config, ALL `init` fields are stripped. +This ensures configs are purely structural and plan creation is Markovian. """ from __future__ import annotations @@ -65,14 +68,49 @@ def is_complete(config: dict) -> bool: def compose_configs(base: dict, overlay: dict | None) -> dict: - """Compose two configs. + """Compose two configs using monoid or monoid action semantics. + + This function implements two algebraic operations depending on argument types: + + 1. **Monoid Action** (complete + partial): Apply surgery to a complete config. + Unspecified fields inherit from base; `init` fields are stripped from result. + + 2. **Monoid Operation** (partial + partial): Merge two surgery specs. + Deep merge with overlay winning on conflicts; `init` fields preserved. Args: - base: Base config (complete or partial surgery spec). - overlay: Overlay config (complete or partial surgery spec). + base: Base config (complete) or surgery spec (partial). + overlay: Surgery spec to apply (partial) or config to merge. Returns: - Composed config. + - If base is complete: Complete config with surgery applied, `init` stripped. + - If both partial: Merged surgery spec with `init` preserved. + + Algebraic Properties: + Surgery specs form a monoid: (a ∘ b) ∘ c = a ∘ (b ∘ c), identity = {} + + For additive surgeries, the action law holds: + compose(compose(s, t1), t2) == compose(s, compose(t1, t2)) + + For replacement surgeries (declaring type:), action law intentionally fails. 
+ + Example: + # Apply surgery to complete config (monoid action) + source = {"hidden_size": 256, "decoder": {...}} # complete + surgery = {"decoder": {"block": {"mixer": {"type": "mamba"}}}} # partial + + target = compose_configs(source, surgery) + # target is complete with inherited fields, init stripped + + # Merge two surgery specs (monoid operation) + s1 = {"decoder": {"block": {"mixer": {"mixers": {"a": {...}}}}}} + s2 = {"decoder": {"block": {"mixer": {"mixers": {"b": {...}}}}}} + + merged = compose_configs(s1, s2) + # merged has both mixers a and b, init preserved + + # Use composed config with plan_surgery + plan = plan_surgery(source, target) """ if not overlay: return copy.deepcopy(base) @@ -134,20 +172,24 @@ def _strip_keys(config: Any, keys_to_strip: set[str]) -> None: def apply_surgery(source_config: dict, surgery_config: dict | None) -> dict: - """Apply surgery specification to a complete source config. + """Apply surgery spec to complete config (the monoid action). + + This is the internal implementation of the monoid action: surgery specs + acting on complete configs. Called by compose_configs when base is complete + and overlay is partial. - This handles: - - Top-level scalar overrides - - Decoder composition (fixed vs pattern) - - Stochastic mixer sub-mixer inheritance - - Cross-type derivation (attention β†’ gdn, attention β†’ mamba) + Implements inheritance semantics: + - Unspecified fields inherit from source + - Cross-type derivation maps geometry (attention β†’ gdn, etc.) + - Stochastic sub-mixers inherit from source's main mixer + - All `init` fields stripped from result Args: - source_config: Complete Apriel2 config. - surgery_config: Partial surgery specification. + source_config: Complete Apriel2 config (the state being acted on). + surgery_config: Partial surgery spec (the monoid element acting). Returns: - Complete Apriel2 config with surgery applied. + Complete config with surgery applied, `init` fields stripped. """ if not surgery_config: return copy.deepcopy(source_config) @@ -392,6 +434,12 @@ def _compose_single_mixer(source: dict, surgery: dict, hidden_size: int) -> dict result[key] = surgery[key] elif key in source: result[key] = source[key] + # Copy per-layer bias settings (query_layer, key_layer, value_layer, dense_layer) + for key in ["query_layer", "key_layer", "value_layer", "dense_layer", "add_linear_biases"]: + if key in surgery: + result[key] = surgery[key] + elif key in source: + result[key] = copy.deepcopy(source[key]) # Preserve init if "init" in surgery: result["init"] = surgery["init"] diff --git a/fast_llm_external_models/apriel2/conversion/converters.py b/fast_llm_external_models/apriel2/conversion/converters.py index 6d1350c54..b54bb5a87 100644 --- a/fast_llm_external_models/apriel2/conversion/converters.py +++ b/fast_llm_external_models/apriel2/conversion/converters.py @@ -79,6 +79,21 @@ # This is the single source of truth for each mixer's weight schema. +def _get_attention_bias_enabled(config: dict, layer_name: str) -> bool: + """Get whether bias is enabled for an attention layer. + + Checks per-layer bias config (e.g., query_layer.bias.enabled). + Falls back to add_linear_biases if not set. 
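For concreteness, the resolution order is: explicit per-layer setting, then add_linear_biases, then False. A small sketch against illustrative config literals, assuming the helper is imported from the converters module shown in this hunk:

from fast_llm_external_models.apriel2.conversion.converters import _get_attention_bias_enabled

qwen_style = {
    "query_layer": {"bias": {"enabled": True}},
    "key_layer": {"bias": {"enabled": True}},
    "value_layer": {"bias": {"enabled": True}},
    "dense_layer": {"bias": {"enabled": False}},
    # no "add_linear_biases" key: unspecified layers would default to False
}

assert _get_attention_bias_enabled(qwen_style, "query_layer") is True    # explicit per-layer setting
assert _get_attention_bias_enabled(qwen_style, "dense_layer") is False   # explicit per-layer setting
assert _get_attention_bias_enabled({"add_linear_biases": True}, "key_layer") is True   # fallback default
assert _get_attention_bias_enabled({}, "key_layer") is False             # nothing set: bias disabled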
+ """ + layer_cfg = config.get(layer_name, {}) + bias_cfg = layer_cfg.get("bias", {}) + enabled = bias_cfg.get("enabled") + if enabled is not None: + return enabled + # Fall back to add_linear_biases + return config.get("add_linear_biases", False) + + def _plan_attention_mixer( *, prefix: W, @@ -90,9 +105,13 @@ def _plan_attention_mixer( Weight schema: - q_proj.weight: (q_size, hidden_size) + - q_proj.bias: (q_size,) [if query_layer.bias.enabled] - k_proj.weight: (kv_size, hidden_size) + - k_proj.bias: (kv_size,) [if key_layer.bias.enabled] - v_proj.weight: (kv_size, hidden_size) + - v_proj.bias: (kv_size,) [if value_layer.bias.enabled] - o_proj.weight: (hidden_size, q_size) + - o_proj.bias: (hidden_size,) [if dense_layer.bias.enabled] Args: prefix: Target weight path prefix. @@ -100,12 +119,28 @@ def _plan_attention_mixer( hidden_size: Model hidden size. source_prefix: If provided, passthrough from source. If None, random init. """ + # Check per-layer bias configuration + q_bias = _get_attention_bias_enabled(config, "query_layer") + k_bias = _get_attention_bias_enabled(config, "key_layer") + v_bias = _get_attention_bias_enabled(config, "value_layer") + o_bias = _get_attention_bias_enabled(config, "dense_layer") + if source_prefix is not None: - # Passthrough - return ExprPlan(mappings={ + # Passthrough weights + mappings: dict[W, Expr] = { prefix / proj / "weight": Ref(key=source_prefix / proj / "weight") for proj in ["q_proj", "k_proj", "v_proj", "o_proj"] - }) + } + # Passthrough biases if enabled + if q_bias: + mappings[prefix / "q_proj" / "bias"] = Ref(key=source_prefix / "q_proj" / "bias") + if k_bias: + mappings[prefix / "k_proj" / "bias"] = Ref(key=source_prefix / "k_proj" / "bias") + if v_bias: + mappings[prefix / "v_proj" / "bias"] = Ref(key=source_prefix / "v_proj" / "bias") + if o_bias: + mappings[prefix / "o_proj" / "bias"] = Ref(key=source_prefix / "o_proj" / "bias") + return ExprPlan(mappings=mappings) # Random init heads = config["heads"] @@ -114,12 +149,22 @@ def _plan_attention_mixer( q_size = heads * head_size kv_size = head_groups * head_size - return ExprPlan(mappings={ + mappings = { prefix / "q_proj" / "weight": Init(shape=(q_size, hidden_size), init_type="kaiming"), prefix / "k_proj" / "weight": Init(shape=(kv_size, hidden_size), init_type="kaiming"), prefix / "v_proj" / "weight": Init(shape=(kv_size, hidden_size), init_type="kaiming"), prefix / "o_proj" / "weight": Init(shape=(hidden_size, q_size), init_type="kaiming"), - }) + } + # Random init biases if enabled + if q_bias: + mappings[prefix / "q_proj" / "bias"] = Init(shape=(q_size,), init_type="zeros") + if k_bias: + mappings[prefix / "k_proj" / "bias"] = Init(shape=(kv_size,), init_type="zeros") + if v_bias: + mappings[prefix / "v_proj" / "bias"] = Init(shape=(kv_size,), init_type="zeros") + if o_bias: + mappings[prefix / "o_proj" / "bias"] = Init(shape=(hidden_size,), init_type="zeros") + return ExprPlan(mappings=mappings) def _plan_mamba_mixer( @@ -786,7 +831,45 @@ def plan_surgery( source_config: dict, target_config: dict, ) -> ExprPlan: - """Build plan for Apriel2β†’Apriel2 surgery (MIL, DIL, KIL, stochastic mixers, etc.).""" + """Build a weight conversion plan between two Apriel2 configurations. + + This function creates an ExprPlan that maps source weight keys to expressions + defining how to compute target weights. The plan handles same-type passthrough, + cross-type conversions (MIL, DIL, KIL), and stochastic mixer routing. 
+ + Args: + source_config: Complete Apriel2 config dict describing the source architecture. + Must have all structural fields (hidden_size, decoder, etc.) fully specified. + target_config: Complete Apriel2 config dict describing the target architecture. + Must be fully specified with all inherited fields resolved. Use + compose_configs(source_config, surgery_spec) to produce this from a + partial surgery specification. + + Returns: + ExprPlan mapping target weight keys to expressions over source weights. + + Example: + # Apply a surgery that wraps attention in a stochastic mixer + surgery_spec = { + "decoder": {"block": {"mixer": { + "type": "stochastic", + "mixers": {"attention": {"type": "attention", "init": "transfer"}} + }}} + } + + # First compose to get complete target config with inherited fields + target_config = compose_configs(source_config, surgery_spec) + + # Then build the plan from two complete configs + plan = plan_surgery(source_config, target_config) + new_weights = execute(plan, source_weights, seed=0) + + Note: + Both arguments must be complete configs. The target_config determines the + full target architecture including all inherited fields (bias settings, + rotary config, etc.). Passing a partial surgery spec directly will result + in missing inherited fields and incorrect plans. + """ hidden_size = target_config.get("hidden_size", source_config.get("hidden_size")) assert hidden_size is not None, "hidden_size must be specified in source or target config" @@ -845,8 +928,8 @@ def _plan_non_decoder_weights(config: dict) -> ExprPlan: norm = W("model", "norm", "weight") mappings[norm] = Ref(key=norm) - if "vision_encoder" in config: - vision_config = config["vision_encoder"] + vision_config = config.get("vision_encoder") + if vision_config: vision = W("model", "vision_encoder") patch_emb = vision / "embeddings" / "patch_embeddings" / "weight" @@ -986,6 +1069,24 @@ def _plan_mixer( ) +def _get_mlp_bias_enabled(config: dict, layer_name: str) -> bool: + """Get whether bias is enabled for an MLP layer. + + Checks per-layer bias config (e.g., layer_1.bias.enabled, layer_2.bias.enabled). + Falls back to add_linear_biases if not set. + + Note: layer_1 corresponds to gate_proj and up_proj (gated MLP) or just up_proj (non-gated) + layer_2 corresponds to down_proj + """ + layer_cfg = config.get(layer_name, {}) + bias_cfg = layer_cfg.get("bias", {}) + enabled = bias_cfg.get("enabled") + if enabled is not None: + return enabled + # Fall back to add_linear_biases + return config.get("add_linear_biases", False) + + def _plan_mlp( target_layer_idx: int, source_layer_idx: int, @@ -1006,7 +1107,7 @@ def _plan_mlp_transfer( target_mlp: dict, hidden_size: int, ) -> ExprPlan: - """Passthrough for MLP weights.""" + """Passthrough for MLP weights and biases.""" source_mlp_path = W("model", "decoder", "blocks", source_layer_idx, "mlp") target_mlp_path = W("model", "decoder", "blocks", target_layer_idx, "mlp") @@ -1019,10 +1120,37 @@ def _plan_mlp_transfer( f"Use 'init: random' to initialize randomly." 
) - return ExprPlan(mappings={ + # Check per-layer bias configuration + layer_1_bias = _get_mlp_bias_enabled(target_mlp, "layer_1") + layer_2_bias = _get_mlp_bias_enabled(target_mlp, "layer_2") + + # Check if gated MLP (has gate_proj) or non-gated (only up_proj) + gated = target_mlp.get("gated", True) # Default to gated for backwards compatibility + + # Passthrough weights + # layer_1 = gate_proj + up_proj (gated) or just up_proj (non-gated) + # layer_2 = down_proj + if gated: + weight_projs = ["gate_proj", "up_proj", "down_proj"] + else: + weight_projs = ["up_proj", "down_proj"] + + mappings: dict[W, Expr] = { target_mlp_path / proj / "weight": Ref(key=source_mlp_path / proj / "weight") - for proj in ["gate_proj", "up_proj", "down_proj"] - }) + for proj in weight_projs + } + + # Passthrough biases if enabled + if layer_1_bias: + if gated: + mappings[target_mlp_path / "gate_proj" / "bias"] = Ref(key=source_mlp_path / "gate_proj" / "bias") + mappings[target_mlp_path / "up_proj" / "bias"] = Ref(key=source_mlp_path / "up_proj" / "bias") + + # layer_2 = down_proj + if layer_2_bias: + mappings[target_mlp_path / "down_proj" / "bias"] = Ref(key=source_mlp_path / "down_proj" / "bias") + + return ExprPlan(mappings=mappings) def _plan_random_mlp( @@ -1030,20 +1158,41 @@ def _plan_random_mlp( target_mlp: dict, hidden_size: int, ) -> ExprPlan: - """Random initialization for MLP.""" + """Random initialization for MLP weights and biases.""" target_mlp_path = W("model", "decoder", "blocks", target_layer_idx, "mlp") intermediate_size = target_mlp["intermediate_size"] - return ExprPlan(mappings={ - target_mlp_path / "gate_proj" / "weight": Init( - shape=(intermediate_size, hidden_size), init_type="kaiming" - ), - target_mlp_path / "up_proj" / "weight": Init( + + # Check per-layer bias configuration + layer_1_bias = _get_mlp_bias_enabled(target_mlp, "layer_1") + layer_2_bias = _get_mlp_bias_enabled(target_mlp, "layer_2") + + # Check if gated MLP (has gate_proj) or non-gated (only up_proj) + gated = target_mlp.get("gated", True) # Default to gated for backwards compatibility + + # Random init weights + mappings: dict[W, Expr] = {} + if gated: + mappings[target_mlp_path / "gate_proj" / "weight"] = Init( shape=(intermediate_size, hidden_size), init_type="kaiming" - ), - target_mlp_path / "down_proj" / "weight": Init( - shape=(hidden_size, intermediate_size), init_type="kaiming" - ), - }) + ) + mappings[target_mlp_path / "up_proj" / "weight"] = Init( + shape=(intermediate_size, hidden_size), init_type="kaiming" + ) + mappings[target_mlp_path / "down_proj" / "weight"] = Init( + shape=(hidden_size, intermediate_size), init_type="kaiming" + ) + + # Random init biases if enabled + if layer_1_bias: + if gated: + mappings[target_mlp_path / "gate_proj" / "bias"] = Init(shape=(intermediate_size,), init_type="zeros") + mappings[target_mlp_path / "up_proj" / "bias"] = Init(shape=(intermediate_size,), init_type="zeros") + + # layer_2 = down_proj + if layer_2_bias: + mappings[target_mlp_path / "down_proj" / "bias"] = Init(shape=(hidden_size,), init_type="zeros") + + return ExprPlan(mappings=mappings) def _plan_norms( diff --git a/fast_llm_external_models/apriel2/conversion/qwen2/config.py b/fast_llm_external_models/apriel2/conversion/qwen2/config.py index 36df744c0..70629fe0e 100644 --- a/fast_llm_external_models/apriel2/conversion/qwen2/config.py +++ b/fast_llm_external_models/apriel2/conversion/qwen2/config.py @@ -23,11 +23,7 @@ def convert_config(qwen2_config: dict) -> dict: num_key_value_heads = 
qwen2_config.get("num_key_value_heads", num_attention_heads) head_dim = hidden_size // num_attention_heads - # Qwen2 uses QKV bias but not O bias - # The add_linear_biases in Apriel2 attention config controls all biases uniformly, - # but we can set it to True and the o_proj bias will just be missing from weights - # (handled by strict=False loading or explicit handling in the plan) - + # Qwen2 uses QKV bias but not O bias - mirror Fast-LLM's per-layer config return { "model_type": "apriel2_text", "architectures": ["Apriel2ForCausalLM"], @@ -48,9 +44,11 @@ def convert_config(qwen2_config: dict) -> dict: "heads": num_attention_heads, "head_groups": num_key_value_heads, "head_size": head_dim, - # Qwen2 has QKV bias but not O bias - # We set True and handle O bias separately - "add_linear_biases": True, + # Per-layer bias config matching Fast-LLM structure + "query_layer": {"bias": {"enabled": True}}, + "key_layer": {"bias": {"enabled": True}}, + "value_layer": {"bias": {"enabled": True}}, + "dense_layer": {"bias": {"enabled": False}}, "rotary": { "type": "mistral_1d", "theta": qwen2_config.get("rope_theta", 1000000.0), diff --git a/fast_llm_external_models/apriel2/conversion/qwen2/plan.py b/fast_llm_external_models/apriel2/conversion/qwen2/plan.py index e5ae3e9d8..7752d37c9 100644 --- a/fast_llm_external_models/apriel2/conversion/qwen2/plan.py +++ b/fast_llm_external_models/apriel2/conversion/qwen2/plan.py @@ -3,7 +3,6 @@ from fast_llm_external_models.apriel2.conversion.expr import ( Expr, ExprPlan, - Init, Ref, W, ) @@ -23,15 +22,19 @@ def plan_qwen2_to_apriel2(qwen2_config: dict) -> ExprPlan: model.layers.{i}.input_layernorm.weight -> model.decoder.blocks.{i}.input_layernorm.weight model.layers.{i}.post_attention_layernorm.weight -> model.decoder.blocks.{i}.post_attention_layernorm.weight model.layers.{i}.self_attn.q_proj.weight -> model.decoder.blocks.{i}.mixer.q_proj.weight + model.layers.{i}.self_attn.q_proj.bias -> model.decoder.blocks.{i}.mixer.q_proj.bias model.layers.{i}.self_attn.k_proj.weight -> model.decoder.blocks.{i}.mixer.k_proj.weight + model.layers.{i}.self_attn.k_proj.bias -> model.decoder.blocks.{i}.mixer.k_proj.bias model.layers.{i}.self_attn.v_proj.weight -> model.decoder.blocks.{i}.mixer.v_proj.weight + model.layers.{i}.self_attn.v_proj.bias -> model.decoder.blocks.{i}.mixer.v_proj.bias model.layers.{i}.self_attn.o_proj.weight -> model.decoder.blocks.{i}.mixer.o_proj.weight model.layers.{i}.mlp.gate_proj.weight -> model.decoder.blocks.{i}.mlp.gate_proj.weight model.layers.{i}.mlp.up_proj.weight -> model.decoder.blocks.{i}.mlp.up_proj.weight model.layers.{i}.mlp.down_proj.weight -> model.decoder.blocks.{i}.mlp.down_proj.weight - Note: Qwen2 has QKV biases but no O bias. We skip the biases in the conversion - since Apriel2 is configured with add_linear_biases=False for uniform handling. + Note: Qwen2 has QKV biases but no O bias. The Apriel2 config uses per-layer + bias settings (query_layer.bias.enabled=True, dense_layer.bias.enabled=False) + to match this exactly - no workarounds needed. Args: qwen2_config: HuggingFace Qwen2Config as dict @@ -42,7 +45,6 @@ def plan_qwen2_to_apriel2(qwen2_config: dict) -> ExprPlan: mappings: dict[str, Expr] = {} num_layers = qwen2_config["num_hidden_layers"] - hidden_size = qwen2_config["hidden_size"] # Static mappings (embeddings and final norm) # Note: Qwen2 safetensor keys have "model." 
prefix @@ -66,8 +68,7 @@ def plan_qwen2_to_apriel2(qwen2_config: dict) -> ExprPlan: qwen_layer = W("model", "layers", layer) apriel_layer = W("model", "decoder", "blocks", layer) - # Attention projections (weights and biases) - # Qwen2 has QKV bias but no O bias + # Attention projection weights for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]: src = qwen_layer / "self_attn" / proj / "weight" tgt = apriel_layer / "mixer" / proj / "weight" @@ -79,12 +80,7 @@ def plan_qwen2_to_apriel2(qwen2_config: dict) -> ExprPlan: tgt = apriel_layer / "mixer" / proj / "bias" mappings[tgt] = Ref(key=src) - # O bias - Qwen2 doesn't have this, so initialize to zeros - # Shape is hidden_size (d_model) - mappings[apriel_layer / "mixer" / "o_proj" / "bias"] = Init( - shape=(hidden_size,), - init_type="zeros", - ) + # Note: o_proj has no bias in Qwen2, and Apriel2 config has dense_layer.bias.enabled=False # MLP projections for proj in ["gate_proj", "up_proj", "down_proj"]: diff --git a/fast_llm_external_models/apriel2/modeling_apriel2.py b/fast_llm_external_models/apriel2/modeling_apriel2.py index 4c263b4e2..878677653 100644 --- a/fast_llm_external_models/apriel2/modeling_apriel2.py +++ b/fast_llm_external_models/apriel2/modeling_apriel2.py @@ -24,8 +24,8 @@ is_torch_flex_attn_available, ) -from fast_llm_external_models.apriel2.cache import Apriel2Cache -from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config, Apriel2TextConfig +from .cache import Apriel2Cache +from .configuration_apriel2 import Apriel2Config, Apriel2TextConfig # GDN implementation - matches Fast-LLM's gdn.py exactly try: @@ -395,14 +395,30 @@ def __init__(self, d_model: int, mixer_config: dict, layer_idx: int, config): # cross_document_attention: if False, use cu_seqlens to isolate sequences (e.g., images) self.cross_document_attention = mixer_config.get("cross_document_attention", True) - # Whether to add biases to linear projections - add_bias = mixer_config.get("add_linear_biases", False) - - # Projections (Fast-LLM weight names: q_proj, k_proj, v_proj, o_proj) - self.q_proj = nn.Linear(d_model, self.num_heads * self.head_dim, bias=add_bias) - self.k_proj = nn.Linear(d_model, self.num_key_value_heads * self.head_dim, bias=add_bias) - self.v_proj = nn.Linear(d_model, self.num_key_value_heads * self.head_dim, bias=add_bias) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, d_model, bias=add_bias) + # Bias configuration mirroring Fast-LLM's structure: + # - add_linear_biases: bool (default for all projections) + # - query_layer: {"bias": {"enabled": bool}} (per-layer override) + # - key_layer: {"bias": {"enabled": bool}} + # - value_layer: {"bias": {"enabled": bool}} + # - dense_layer: {"bias": {"enabled": bool}} + default_bias = mixer_config.get("add_linear_biases", False) + + def get_layer_bias(layer_name: str) -> bool: + layer_cfg = mixer_config.get(layer_name, {}) + bias_cfg = layer_cfg.get("bias", {}) + enabled = bias_cfg.get("enabled") + return default_bias if enabled is None else enabled + + q_bias = get_layer_bias("query_layer") + k_bias = get_layer_bias("key_layer") + v_bias = get_layer_bias("value_layer") + o_bias = get_layer_bias("dense_layer") + + # Projections + self.q_proj = nn.Linear(d_model, self.num_heads * self.head_dim, bias=q_bias) + self.k_proj = nn.Linear(d_model, self.num_key_value_heads * self.head_dim, bias=k_bias) + self.v_proj = nn.Linear(d_model, self.num_key_value_heads * self.head_dim, bias=v_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, d_model, bias=o_bias) 
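Combined with the Qwen2 converter above, this yields biased Q/K/V projections and an unbiased output projection. A minimal sketch of the effect; the hidden size and head counts are the Qwen2.5-0.5B values used later in the integration tests, and the config literal is illustrative:

import torch.nn as nn

mixer_config = {
    "type": "attention",
    "heads": 14, "head_groups": 2, "head_size": 64,
    "query_layer": {"bias": {"enabled": True}},
    "key_layer": {"bias": {"enabled": True}},
    "value_layer": {"bias": {"enabled": True}},
    "dense_layer": {"bias": {"enabled": False}},
}

def layer_bias(name: str) -> bool:
    # Same resolution as get_layer_bias above: explicit per-layer setting wins,
    # otherwise fall back to add_linear_biases (absent here, so False).
    enabled = mixer_config.get(name, {}).get("bias", {}).get("enabled")
    return mixer_config.get("add_linear_biases", False) if enabled is None else enabled

q_proj = nn.Linear(896, 14 * 64, bias=layer_bias("query_layer"))   # bias tensor present
o_proj = nn.Linear(14 * 64, 896, bias=layer_bias("dense_layer"))   # no bias, matching Qwen2 checkpoints
assert q_proj.bias is not None and o_proj.bias is None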
@classmethod def setup( @@ -1828,16 +1844,36 @@ def __init__( self.post_attention_layernorm = self._create_norm(norm_config, hidden_size, rms_norm_eps) def _create_mlp(self, mlp_config: dict, hidden_size: int): - """Create MLP based on config.""" + """Create MLP based on config. + + Supports per-layer bias configuration mirroring Fast-LLM: + - add_linear_biases: default bias setting for all layers + - layer_1.bias.enabled: override for up_proj/gate_proj + - layer_2.bias.enabled: override for down_proj + """ mlp_type = mlp_config.get("type", "mlp") if mlp_type == "mlp": intermediate_size = mlp_config["intermediate_size"] activation = mlp_config.get("activation", "silu") - gated = mlp_config["gated"] - bias = mlp_config.get("add_linear_biases", False) + gated = mlp_config.get("gated", False) + + # Per-layer bias configuration (mirrors Fast-LLM structure) + default_bias = mlp_config.get("add_linear_biases", False) + + def get_layer_bias(layer_name: str) -> bool: + layer_cfg = mlp_config.get(layer_name, {}) + bias_cfg = layer_cfg.get("bias", {}) + enabled = bias_cfg.get("enabled") + return default_bias if enabled is None else enabled + + layer_1_bias = get_layer_bias("layer_1") + layer_2_bias = get_layer_bias("layer_2") if gated: + # MistralMLP uses gate_proj, up_proj, down_proj (all bias controlled together) + # For now, we use the default bias setting for gated MLPs + # TODO: Add per-layer bias support to gated MLP mlp_cfg = SimpleNamespace( hidden_size=hidden_size, intermediate_size=intermediate_size, @@ -1845,7 +1881,13 @@ def _create_mlp(self, mlp_config: dict, hidden_size: int): ) return MistralMLP(mlp_cfg) else: - return SimpleMLP(hidden_size, intermediate_size, activation, bias) + return SimpleMLP( + hidden_size, + intermediate_size, + activation, + layer_1_bias=layer_1_bias, + layer_2_bias=layer_2_bias, + ) else: raise ValueError(f"Unknown MLP type: {mlp_type}") @@ -2179,6 +2221,8 @@ def forward( class Apriel2ForCausalLM(Apriel2PreTrainedModel, GenerationMixin): """Apriel2 model with a language modeling head (text-only).""" + _tied_weights_keys = ["lm_head.weight"] + def __init__(self, config: Apriel2TextConfig): super().__init__(config) self.model = Apriel2TextModel(config) @@ -2186,6 +2230,7 @@ def __init__(self, config: Apriel2TextConfig): self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing + # post_init() calls init_weights() which calls tie_weights() if config.tie_word_embeddings self.post_init() def get_input_embeddings(self): @@ -2583,14 +2628,26 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: class SimpleMLP(nn.Module): - """Non-gated MLP: up_proj -> activation -> down_proj.""" + """Non-gated MLP: up_proj -> activation -> down_proj. 
- def __init__(self, hidden_size: int, intermediate_size: int, activation: str = "silu", bias: bool = False): + Supports per-layer bias configuration mirroring Fast-LLM: + - layer_1_bias: bias for up_proj (layer_1 in Fast-LLM naming) + - layer_2_bias: bias for down_proj (layer_2 in Fast-LLM naming) + """ + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + activation: str = "silu", + layer_1_bias: bool = False, + layer_2_bias: bool = False, + ): super().__init__() from transformers.activations import ACT2FN - self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=bias) - self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=bias) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=layer_1_bias) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=layer_2_bias) self.act_fn = ACT2FN[activation] def forward(self, x): diff --git a/fast_llm_external_models/tests/test_apriel2/conftest.py b/fast_llm_external_models/tests/test_apriel2/conftest.py index 5c127d97e..cf190b50a 100644 --- a/fast_llm_external_models/tests/test_apriel2/conftest.py +++ b/fast_llm_external_models/tests/test_apriel2/conftest.py @@ -10,16 +10,40 @@ from fast_llm_external_models.apriel2.cache import _AttentionCache, _SSMCache +# Register custom marks +def pytest_configure(config): + config.addinivalue_line("markers", "slow: marks tests as slow (deselect with '-m \"not slow\"')") + + +def _can_import_fast_llm(): + """Check if Fast-LLM is available.""" + try: + from fast_llm.engine.checkpoint.convert import ConvertConfig + return True + except ImportError: + return False + + # Skip marker for tests that require CUDA for Mamba forward pass requires_cuda = pytest.mark.skipif( not torch.cuda.is_available(), reason="SSM mixers (Mamba) require CUDA for forward pass" ) +# Skip marker for tests that require Fast-LLM +requires_fastllm = pytest.mark.skipif( + not _can_import_fast_llm(), + reason="Fast-LLM not available" +) + -@pytest.fixture(autouse=True) +@pytest.fixture(scope="module", autouse=True) def set_default_device(): - """Set default device to CUDA for all tests (Mamba requires CUDA).""" + """Set default device to CUDA for all tests (Mamba requires CUDA). + + Module-scoped to ensure it runs before any module-scoped fixtures + that load models (e.g., qwen2_model_and_tokenizer). + """ if torch.cuda.is_available(): old_device = torch.get_default_device() torch.set_default_device("cuda") @@ -29,9 +53,12 @@ def set_default_device(): yield -@pytest.fixture(autouse=True) +@pytest.fixture(scope="module", autouse=True) def set_default_dtype(): - """Set default dtype to float32 for numerical comparison tests.""" + """Set default dtype to float32 for numerical comparison tests. + + Module-scoped to ensure it runs before any module-scoped fixtures. + """ old_dtype = torch.get_default_dtype() torch.set_default_dtype(torch.float32) yield @@ -763,6 +790,52 @@ def apriel2_config_comprehensive(): ) +@pytest.fixture +def apriel2_config_with_bias(): + """Apriel2 config with Qwen-style per-layer bias and non-gated MLP. + + This config exercises: + - Per-layer attention bias (QKV bias enabled, O bias disabled) + - Non-gated MLP with per-layer bias (layer_1 enabled, layer_2 disabled) + - Config structure parity with Fast-LLM's AffineLinearConfig + + Critical for testing bias inheritance through surgery operations. 
+ """ + from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config + + return Apriel2Config( + vocab_size=100, + hidden_size=64, + decoder={ + "type": "fixed", + "num_blocks": 2, + "block": { + "mixer": { + "type": "attention", + "heads": 4, + "head_groups": 2, + "head_size": 16, + "rotary": {"type": "mistral_1d", "theta": 10000.0}, + # Qwen-style: QKV bias enabled, O bias disabled + "query_layer": {"bias": {"enabled": True}}, + "key_layer": {"bias": {"enabled": True}}, + "value_layer": {"bias": {"enabled": True}}, + "dense_layer": {"bias": {"enabled": False}}, + }, + "mlp": { + "type": "mlp", + "intermediate_size": 256, + "gated": False, # Non-gated MLP (SimpleMLP) + # Per-layer MLP bias + "layer_1": {"bias": {"enabled": True}}, + "layer_2": {"bias": {"enabled": False}}, + }, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + ) + + @pytest.fixture def apriel2_cache(apriel2_config_tiny): """Create empty Apriel2Cache from tiny config.""" @@ -865,6 +938,77 @@ def additive_surgery_chain(): ] +@pytest.fixture +def bias_surgery_chain(): + """Surgery chain that exercises bias inheritance through surgery operations. + + Designed to be used with apriel2_config_with_bias as the source config. + Tests that per-layer bias settings (Qwen-style QKV bias, non-gated MLP bias) + are correctly inherited through: + - Stochastic wrapper creation + - Adding new sub-mixers that inherit from source + - Cross-type derivation (attention β†’ sliding_window) + + Source config has: + - Attention: query/key/value bias enabled, dense bias disabled + - MLP: layer_1 bias enabled, layer_2 bias disabled (non-gated) + """ + return [ + # S1: Wrap in stochastic - bias should transfer to attention sub-mixer + { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + }, + }, + "mlp": {"init": "transfer"}, + "normalization": {"init": "transfer"}, + }, + }, + }, + # S2: Add sliding_window - should inherit bias from source attention + { + "decoder": { + "block": { + "mixer": { + "mixers": { + "sliding_window": { + "type": "attention", + "init": "transfer", + "window_size": 512, + }, + }, + }, + }, + }, + }, + # S3: Add new attention with DIFFERENT bias config (random init) + { + "decoder": { + "block": { + "mixer": { + "mixers": { + "full_bias_attn": { + "type": "attention", + "init": "random", + "heads": 4, + "head_groups": 2, + "head_size": 16, + "rotary": {"type": "mistral_1d", "theta": 10000.0}, + "add_linear_biases": True, # All biases enabled + }, + }, + }, + }, + }, + }, + ] + + @pytest.fixture def comprehensive_torture_chain(): """Comprehensive torture chain exercising ALL conversion paths. diff --git a/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py b/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py index 0bd6ac88d..4380b1fbd 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py +++ b/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py @@ -288,6 +288,163 @@ def test_init_random_still_inherits_config(self, source_config): assert mixer["window_size"] == 512 +class TestBiasConfigInheritance: + """Test per-layer bias inheritance through surgery composition. + + These tests verify that the per-layer bias configuration (mirroring Fast-LLM's + AffineLinearConfig) is correctly inherited through surgery operations: + - query_layer.bias.enabled, key_layer.bias.enabled, etc. 
for attention + - layer_1.bias.enabled, layer_2.bias.enabled for MLP + """ + + @pytest.fixture + def source_config_with_bias(self): + """Source config with Qwen-style bias (QKV enabled, O disabled).""" + return { + "model_type": "apriel2", + "architectures": ["Apriel2ForCausalLM"], + "hidden_size": 256, + "vocab_size": 1000, + "decoder": { + "type": "fixed", + "num_blocks": 4, + "block": { + "mixer": { + "type": "attention", + "heads": 8, + "head_groups": 4, + "head_size": 32, + "rotary": {"type": "mistral_1d", "theta": 10000.0}, + # Qwen-style per-layer bias + "query_layer": {"bias": {"enabled": True}}, + "key_layer": {"bias": {"enabled": True}}, + "value_layer": {"bias": {"enabled": True}}, + "dense_layer": {"bias": {"enabled": False}}, + }, + "mlp": { + "type": "mlp", + "intermediate_size": 512, + "gated": False, + # Per-layer MLP bias + "layer_1": {"bias": {"enabled": True}}, + "layer_2": {"bias": {"enabled": False}}, + }, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + } + + def test_same_type_inherits_attention_bias(self, source_config_with_bias): + """Same-type surgery inherits per-layer attention bias settings.""" + surgery = { + "decoder": { + "block": { + "mixer": { + "window_size": 512, # Add sliding window behavior + }, + }, + }, + } + result = compose_configs(source_config_with_bias, surgery) + + mixer = result["decoder"]["block"]["mixer"] + assert mixer["query_layer"]["bias"]["enabled"] is True + assert mixer["key_layer"]["bias"]["enabled"] is True + assert mixer["value_layer"]["bias"]["enabled"] is True + assert mixer["dense_layer"]["bias"]["enabled"] is False + + def test_same_type_inherits_mlp_bias(self, source_config_with_bias): + """Same-type surgery inherits per-layer MLP bias settings.""" + surgery = { + "decoder": { + "block": { + "mlp": { + "intermediate_size": 1024, # Change size + }, + }, + }, + } + result = compose_configs(source_config_with_bias, surgery) + + mlp = result["decoder"]["block"]["mlp"] + assert mlp["layer_1"]["bias"]["enabled"] is True + assert mlp["layer_2"]["bias"]["enabled"] is False + assert mlp["intermediate_size"] == 1024 + + def test_cross_type_attention_to_sliding_window_preserves_bias(self, source_config_with_bias): + """attentionβ†’sliding_window cross-type preserves per-layer bias.""" + surgery = { + "decoder": { + "block": { + "mixer": { + "type": "sliding_window", # Cross-type derivation + "window_size": 512, + }, + }, + }, + } + result = compose_configs(source_config_with_bias, surgery) + + mixer = result["decoder"]["block"]["mixer"] + assert mixer["type"] == "sliding_window" + # Bias settings preserved through cross-type + assert mixer["query_layer"]["bias"]["enabled"] is True + assert mixer["key_layer"]["bias"]["enabled"] is True + assert mixer["value_layer"]["bias"]["enabled"] is True + assert mixer["dense_layer"]["bias"]["enabled"] is False + + def test_stochastic_wrapper_inherits_bias(self, source_config_with_bias): + """Wrapping in stochastic inherits bias settings to all sub-mixers.""" + surgery = { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + "sliding_window": { + "type": "sliding_window", + "window_size": 512, + "init": "transfer", + }, + }, + }, + }, + }, + } + result = compose_configs(source_config_with_bias, surgery) + + mixers = result["decoder"]["block"]["mixer"]["mixers"] + + # Attention sub-mixer inherits bias + assert mixers["attention"]["query_layer"]["bias"]["enabled"] is True + assert 
mixers["attention"]["dense_layer"]["bias"]["enabled"] is False + + # Sliding window sub-mixer also inherits bias + assert mixers["sliding_window"]["query_layer"]["bias"]["enabled"] is True + assert mixers["sliding_window"]["dense_layer"]["bias"]["enabled"] is False + + def test_surgery_can_override_bias(self, source_config_with_bias): + """Surgery can explicitly override inherited bias settings.""" + surgery = { + "decoder": { + "block": { + "mixer": { + "dense_layer": {"bias": {"enabled": True}}, # Override O bias + }, + }, + }, + } + result = compose_configs(source_config_with_bias, surgery) + + mixer = result["decoder"]["block"]["mixer"] + # Q/K/V unchanged + assert mixer["query_layer"]["bias"]["enabled"] is True + # O bias overridden + assert mixer["dense_layer"]["bias"]["enabled"] is True + + class TestComposeConfigsRealYAML: """Test compose_configs with real YAML surgery files.""" diff --git a/fast_llm_external_models/tests/test_apriel2/test_expr_plan.py b/fast_llm_external_models/tests/test_apriel2/test_expr_plan.py index c487ab3a3..569ed88fd 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_expr_plan.py +++ b/fast_llm_external_models/tests/test_apriel2/test_expr_plan.py @@ -1711,3 +1711,205 @@ def test_conversion_plan_targets_match_model_state_dict(self, llava_pixtral_conf assert not missing_from_plan, f"Plan missing keys that model expects: {sorted(missing_from_plan)}" assert not extra_in_plan, f"Plan has extra keys model doesn't expect: {sorted(extra_in_plan)}" + + +class TestBiasPlanGeneration: + """Test that surgery plans correctly handle per-layer bias configurations. + + These tests verify that plan_surgery correctly includes/excludes bias + weight mappings based on the per-layer bias settings: + - query_layer.bias.enabled, key_layer.bias.enabled, etc. 
for attention + - layer_1.bias.enabled, layer_2.bias.enabled for MLP + """ + + @pytest.fixture + def source_config_with_bias(self): + """Source config with Qwen-style bias (QKV enabled, O disabled).""" + return { + "model_type": "apriel2", + "hidden_size": 256, + "vocab_size": 1000, + "decoder": { + "type": "fixed", + "num_blocks": 2, + "block": { + "mixer": { + "type": "attention", + "heads": 8, + "head_groups": 4, + "head_size": 32, + "rotary": {"type": "mistral_1d", "theta": 10000.0}, + # Qwen-style: QKV bias enabled, O bias disabled + "query_layer": {"bias": {"enabled": True}}, + "key_layer": {"bias": {"enabled": True}}, + "value_layer": {"bias": {"enabled": True}}, + "dense_layer": {"bias": {"enabled": False}}, + }, + "mlp": { + "type": "mlp", + "intermediate_size": 512, + "gated": False, + # Per-layer MLP bias: layer_1 enabled, layer_2 disabled + "layer_1": {"bias": {"enabled": True}}, + "layer_2": {"bias": {"enabled": False}}, + }, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + } + + def test_plan_includes_enabled_attention_biases(self, source_config_with_bias): + """Surgery plan includes bias mappings for enabled attention biases.""" + from fast_llm_external_models.apriel2.conversion.config import compose_configs + from fast_llm_external_models.apriel2.conversion.converters import plan_surgery + + target_config = compose_configs(source_config_with_bias, { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + }, + }, + "mlp": {"init": "transfer"}, + }, + }, + }) + + plan = plan_surgery(source_config_with_bias, target_config) + mapping_strs = [str(k) for k in plan.mappings.keys()] + + # Should have q_proj.bias, k_proj.bias, v_proj.bias mappings + q_bias = [m for m in mapping_strs if "q_proj.bias" in m] + k_bias = [m for m in mapping_strs if "k_proj.bias" in m] + v_bias = [m for m in mapping_strs if "v_proj.bias" in m] + + assert len(q_bias) > 0, "Should have q_proj.bias mappings" + assert len(k_bias) > 0, "Should have k_proj.bias mappings" + assert len(v_bias) > 0, "Should have v_proj.bias mappings" + + def test_plan_excludes_disabled_attention_biases(self, source_config_with_bias): + """Surgery plan excludes bias mappings for disabled attention biases.""" + from fast_llm_external_models.apriel2.conversion.config import compose_configs + from fast_llm_external_models.apriel2.conversion.converters import plan_surgery + + target_config = compose_configs(source_config_with_bias, { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + }, + }, + "mlp": {"init": "transfer"}, + }, + }, + }) + + plan = plan_surgery(source_config_with_bias, target_config) + mapping_strs = [str(k) for k in plan.mappings.keys()] + + # Should NOT have o_proj.bias mappings (disabled) + o_bias = [m for m in mapping_strs if "o_proj.bias" in m] + assert len(o_bias) == 0, f"Should not have o_proj.bias mappings, found: {o_bias}" + + def test_plan_includes_enabled_mlp_biases(self, source_config_with_bias): + """Surgery plan includes bias mappings for enabled MLP biases.""" + from fast_llm_external_models.apriel2.conversion.config import compose_configs + from fast_llm_external_models.apriel2.conversion.converters import plan_surgery + + target_config = compose_configs(source_config_with_bias, { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": 
{ + "attention": {"init": "transfer"}, + }, + }, + "mlp": {"init": "transfer"}, + }, + }, + }) + + plan = plan_surgery(source_config_with_bias, target_config) + mapping_strs = [str(k) for k in plan.mappings.keys()] + + # Should have up_proj.bias (layer_1) mappings + up_bias = [m for m in mapping_strs if "up_proj.bias" in m] + assert len(up_bias) > 0, "Should have up_proj.bias mappings" + + def test_plan_excludes_disabled_mlp_biases(self, source_config_with_bias): + """Surgery plan excludes bias mappings for disabled MLP biases.""" + from fast_llm_external_models.apriel2.conversion.config import compose_configs + from fast_llm_external_models.apriel2.conversion.converters import plan_surgery + + target_config = compose_configs(source_config_with_bias, { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + }, + }, + "mlp": {"init": "transfer"}, + }, + }, + }) + + plan = plan_surgery(source_config_with_bias, target_config) + mapping_strs = [str(k) for k in plan.mappings.keys()] + + # Should NOT have down_proj.bias (layer_2) mappings + down_bias = [m for m in mapping_strs if "down_proj.bias" in m] + assert len(down_bias) == 0, f"Should not have down_proj.bias mappings, found: {down_bias}" + + def test_plan_random_init_creates_init_expressions_for_bias(self, source_config_with_bias): + """Random init creates Init expressions for bias weights.""" + from fast_llm_external_models.apriel2.conversion.converters import plan_surgery + + # Surgery spec - pass directly to plan_surgery (NOT composed, to preserve init) + surgery = { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + "new_attention": { + "type": "attention", + "init": "random", # This triggers random init + "heads": 8, + "head_groups": 4, + "head_size": 32, + "rotary": {"type": "mistral_1d", "theta": 10000.0}, + "add_linear_biases": True, # All biases enabled + }, + }, + }, + }, + }, + } + + # Pass surgery spec directly - init fields are preserved + plan = plan_surgery(source_config_with_bias, surgery) + + # Check that new_attention biases use Init expressions + new_mixer_bias_keys = [ + k for k in plan.mappings.keys() + if "new_attention" in str(k) and "bias" in str(k) + ] + + assert len(new_mixer_bias_keys) > 0, "Should have bias mappings for new_attention" + + for key in new_mixer_bias_keys: + expr = plan.mappings[key] + assert isinstance(expr, Init), f"{key} should be Init, got {type(expr)}" diff --git a/fast_llm_external_models/tests/test_apriel2/test_integration.py b/fast_llm_external_models/tests/test_apriel2/test_integration.py new file mode 100644 index 000000000..c11302d22 --- /dev/null +++ b/fast_llm_external_models/tests/test_apriel2/test_integration.py @@ -0,0 +1,335 @@ +"""Integration tests for Qwen2 -> Apriel2 -> Fast-LLM conversion pipeline. + +Tests verify the full conversion chain: +1. Qwen2 -> Apriel2 (external module conversion) +2. Apriel2 + Surgery -> Supernet (stochastic mixer creation) +3. 
Supernet -> Fast-LLM -> Supernet (roundtrip through training format) + +Test Strategy: +- Use real HuggingFace model (Qwen2.5-0.5B) for meaningful validation +- Separate config preservation tests from numerical equivalence tests +- Parameterize both conversion stages AND input variations +- Single test implementation applied across all stages +""" + +import json +import tempfile +from pathlib import Path + +import pytest +import torch + +from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config +from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM +from fast_llm_external_models.apriel2.conversion import ( + compose, + compose_configs, + execute, + plan_surgery, +) +from fast_llm_external_models.apriel2.conversion.expr import W +from fast_llm_external_models.apriel2.conversion.qwen2.config import convert_config as convert_qwen2_config +from fast_llm_external_models.apriel2.conversion.qwen2.plan import plan_qwen2_to_apriel2 + +from .conftest import requires_fastllm + + +# ============================================================================= +# Test Input Variations +# ============================================================================= + +TEST_INPUTS = pytest.mark.parametrize( + "prompts,max_new_tokens", + [ + pytest.param(["Hello world"], 10, id="single_short"), + pytest.param(["Hi", "The quick brown fox jumps over the lazy dog"], 20, id="batch_varied"), + pytest.param(["Once upon a time"], 50, id="long_generation"), + ], +) + + +# ============================================================================= +# Conversion Fixtures +# ============================================================================= + + +@pytest.fixture(scope="module") +def qwen2_source(): + """Load Qwen2.5-0.5B as the source/reference model.""" + from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig + + model_name = "Qwen/Qwen2.5-0.5B" + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + model = AutoModelForCausalLM.from_pretrained( + model_name, torch_dtype=torch.float32, trust_remote_code=True + ) + config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + model.eval() + + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "left" + + return { + "model": model, + "tokenizer": tokenizer, + "config_dict": config.to_dict(), + "state_dict": model.state_dict(), + } + + +@pytest.fixture(scope="module") +def apriel2_converted(qwen2_source): + """Stage 1: Qwen2 -> Apriel2.""" + config_dict = convert_qwen2_config(qwen2_source["config_dict"]) + plan = plan_qwen2_to_apriel2(qwen2_source["config_dict"]) + weights = execute(plan, {W(k): v for k, v in qwen2_source["state_dict"].items()}, seed=42) + + config = Apriel2Config(**config_dict) + model = Apriel2ForCausalLM(config) + model.load_state_dict({str(k): v for k, v in weights.items()}, strict=False) + model.eval() + + return {"model": model, "config_dict": config_dict, "plan": plan, "name": "Apriel2"} + + +@pytest.fixture(scope="module") +def supernet_converted(qwen2_source, apriel2_converted): + """Stage 2: Apriel2 + Surgery -> Supernet.""" + surgery_spec = { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"type": "attention", "init": "transfer"}, + "sliding_window": { + "type": "attention", + "init": "transfer", + "window_size": 4096, + }, + }, + }, + }, + }, + } + + apriel_config = 
apriel2_converted["config_dict"] + supernet_config = compose_configs(apriel_config, surgery_spec) + + full_plan = compose( + apriel2_converted["plan"], + plan_surgery(apriel_config, supernet_config), + ) + + weights = execute(full_plan, {W(k): v for k, v in qwen2_source["state_dict"].items()}, seed=42) + + config = Apriel2Config(**supernet_config) + model = Apriel2ForCausalLM(config) + model.load_state_dict({str(k): v for k, v in weights.items()}, strict=False) + model.eval() + + return {"model": model, "config_dict": supernet_config, "name": "Supernet"} + + +@pytest.fixture(scope="module") +def roundtrip_converted(supernet_converted, qwen2_source): + """Stage 3: Supernet -> Fast-LLM -> Supernet.""" + from fast_llm.engine.checkpoint.config import ( + CheckpointLoadConfig, + CheckpointSaveConfig, + FastLLMCheckpointFormat, + ) + from fast_llm.engine.checkpoint.convert import ConvertConfig + from fast_llm.models.gpt.config import GPTModelConfig + from fast_llm.models.gpt.conversion.config import Apriel2TextCheckpointFormat + + with tempfile.TemporaryDirectory() as tmpdir: + tmpdir = Path(tmpdir) + supernet_path = tmpdir / "supernet" + fastllm_path = tmpdir / "fastllm" + roundtrip_path = tmpdir / "roundtrip" + + supernet_converted["model"].save_pretrained(supernet_path) + qwen2_source["tokenizer"].save_pretrained(supernet_path) + + ConvertConfig( + model=GPTModelConfig, + input=CheckpointLoadConfig(path=supernet_path, format=Apriel2TextCheckpointFormat), + output=CheckpointSaveConfig(path=fastllm_path, format=FastLLMCheckpointFormat), + ).run() + + ConvertConfig( + model=GPTModelConfig, + input=CheckpointLoadConfig(path=fastllm_path, format=FastLLMCheckpointFormat), + output=CheckpointSaveConfig(path=roundtrip_path, format=Apriel2TextCheckpointFormat), + ).run() + + model = Apriel2ForCausalLM.from_pretrained(roundtrip_path) + model.eval() + + with open(roundtrip_path / "config.json") as f: + config_dict = json.load(f) + + yield {"model": model, "config_dict": config_dict, "name": "Roundtrip"} + + +# ============================================================================= +# Parameterized Fixture: All Conversion Stages +# ============================================================================= + + +@pytest.fixture(params=["apriel2", "supernet", "roundtrip"]) +def converted_model(request, apriel2_converted, supernet_converted, roundtrip_converted): + """Parameterized fixture providing each conversion stage for testing. + + This allows a single test to run against all stages automatically. 
+ """ + if request.param == "roundtrip": + pytest.importorskip("fast_llm") + + return { + "apriel2": apriel2_converted, + "supernet": supernet_converted, + "roundtrip": roundtrip_converted, + }[request.param] + + +# ============================================================================= +# Config Preservation Tests +# ============================================================================= + + +@pytest.mark.slow +class TestConfigPreservation: + """Verify configs are correctly preserved through the conversion chain.""" + + def test_apriel2_structure(self, qwen2_source, apriel2_converted): + """Qwen2 -> Apriel2 preserves model dimensions.""" + qwen = qwen2_source["config_dict"] + apriel = apriel2_converted["config_dict"] + + assert apriel["hidden_size"] == qwen["hidden_size"] + assert apriel["vocab_size"] == qwen["vocab_size"] + assert apriel["decoder"]["num_blocks"] == qwen["num_hidden_layers"] + + def test_apriel2_bias_pattern(self, apriel2_converted): + """Qwen2 -> Apriel2 preserves Qwen-style bias (QKV yes, O no).""" + mixer = apriel2_converted["config_dict"]["decoder"]["block"]["mixer"] + + assert mixer["query_layer"]["bias"]["enabled"] is True + assert mixer["key_layer"]["bias"]["enabled"] is True + assert mixer["value_layer"]["bias"]["enabled"] is True + assert mixer["dense_layer"]["bias"]["enabled"] is False + + def test_supernet_structure(self, supernet_converted): + """Surgery creates correct stochastic mixer structure.""" + mixer = supernet_converted["config_dict"]["decoder"]["block"]["mixer"] + + assert mixer["type"] == "stochastic" + assert mixer["main_mixer_name"] == "attention" + assert set(mixer["mixers"].keys()) == {"attention", "sliding_window"} + + def test_supernet_bias_inheritance(self, supernet_converted): + """Submixers inherit bias settings from source.""" + mixer = supernet_converted["config_dict"]["decoder"]["block"]["mixer"] + + for name in ["attention", "sliding_window"]: + assert mixer["mixers"][name]["query_layer"]["bias"]["enabled"] is True + assert mixer["mixers"][name]["dense_layer"]["bias"]["enabled"] is False + + @requires_fastllm + def test_roundtrip_structure(self, roundtrip_converted): + """Fast-LLM roundtrip preserves stochastic mixer structure.""" + mixer = roundtrip_converted["config_dict"]["decoder"]["block"]["mixer"] + + assert mixer["type"] == "stochastic" + assert mixer["main_mixer_name"] == "attention" + assert set(mixer["mixers"].keys()) == {"attention", "sliding_window"} + + @requires_fastllm + def test_roundtrip_bias_preservation(self, roundtrip_converted): + """Fast-LLM roundtrip preserves per-layer bias settings.""" + mixer = roundtrip_converted["config_dict"]["decoder"]["block"]["mixer"] + + for name in ["attention", "sliding_window"]: + assert mixer["mixers"][name]["query_layer"]["bias"]["enabled"] is True + assert mixer["mixers"][name]["dense_layer"]["bias"]["enabled"] is False + + +# ============================================================================= +# Numerical Equivalence Tests +# ============================================================================= + + +@pytest.mark.slow +class TestNumericalEquivalence: + """Verify all conversion stages produce numerically identical outputs. + + Uses parameterized fixtures to test all stages with all input variations, + giving us 3 stages Γ— 3 inputs = 9 test cases from a single test function. 
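+
+    "Identical" here means logits agree within rtol/atol of 1e-4 and greedy
+    generations match token-for-token.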
+ """ + + @TEST_INPUTS + def test_logits_match(self, qwen2_source, converted_model, prompts, max_new_tokens): + """Converted model produces identical logits to source.""" + tokenizer = qwen2_source["tokenizer"] + ref_model = qwen2_source["model"] + test_model = converted_model["model"] + stage = converted_model["name"] + + inputs = tokenizer(prompts, return_tensors="pt", padding=True) + ref_device = next(ref_model.parameters()).device + test_device = next(test_model.parameters()).device + + with torch.no_grad(): + ref_logits = ref_model( + input_ids=inputs.input_ids.to(ref_device), + attention_mask=inputs.attention_mask.to(ref_device), + ).logits.cpu() + + test_logits = test_model( + input_ids=inputs.input_ids.to(test_device), + attention_mask=inputs.attention_mask.to(test_device), + ).logits.cpu() + + max_diff = (ref_logits - test_logits).abs().max().item() + assert torch.allclose(ref_logits, test_logits, rtol=1e-4, atol=1e-4), ( + f"{stage} logits mismatch: max diff = {max_diff:.6f}" + ) + + @TEST_INPUTS + def test_generation_match(self, qwen2_source, converted_model, prompts, max_new_tokens): + """Converted model produces identical generation to source.""" + tokenizer = qwen2_source["tokenizer"] + ref_model = qwen2_source["model"] + test_model = converted_model["model"] + stage = converted_model["name"] + + inputs = tokenizer(prompts, return_tensors="pt", padding=True) + ref_device = next(ref_model.parameters()).device + test_device = next(test_model.parameters()).device + + with torch.no_grad(): + ref_gen = ref_model.generate( + input_ids=inputs.input_ids.to(ref_device), + attention_mask=inputs.attention_mask.to(ref_device), + max_new_tokens=max_new_tokens, + do_sample=False, + pad_token_id=tokenizer.pad_token_id, + ).cpu() + + test_gen = test_model.generate( + input_ids=inputs.input_ids.to(test_device), + attention_mask=inputs.attention_mask.to(test_device), + max_new_tokens=max_new_tokens, + do_sample=False, + pad_token_id=tokenizer.pad_token_id, + ).cpu() + + assert torch.equal(ref_gen, test_gen), ( + f"{stage} generation mismatch:\n" + f" Reference: {tokenizer.batch_decode(ref_gen, skip_special_tokens=True)}\n" + f" Test: {tokenizer.batch_decode(test_gen, skip_special_tokens=True)}" + ) diff --git a/fast_llm_external_models/tests/test_apriel2/test_modeling.py b/fast_llm_external_models/tests/test_apriel2/test_modeling.py index 5dbd36159..47c877d09 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_modeling.py +++ b/fast_llm_external_models/tests/test_apriel2/test_modeling.py @@ -12,7 +12,8 @@ class TestApriel2Modeling: "apriel2_config_tiny", "apriel2_config_stochastic", "apriel2_config_multi_mixer", - "apriel2_config_all_mixers" # Tests all 4 mixer types + "apriel2_config_all_mixers", # Tests all 4 mixer types + "apriel2_config_with_bias", # Tests per-layer bias and non-gated MLP ]) def test_model_end_to_end(self, config_name, request): """Test instantiation, forward pass, cache correctness, and generation. 
diff --git a/fast_llm_external_models/tests/test_apriel2/test_plan_composition_torture.py b/fast_llm_external_models/tests/test_apriel2/test_plan_composition_torture.py index 3b4adc7f5..76a77ccb6 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_plan_composition_torture.py +++ b/fast_llm_external_models/tests/test_apriel2/test_plan_composition_torture.py @@ -1980,3 +1980,151 @@ def test_expand_surgery_chain_preserves_invariant(self): # After cycling and restore, we should be back to the same state assert current_config == config_after_original + + +class TestBiasSurgeryChain: + """Torture tests for per-layer bias inheritance through surgery operations. + + Uses apriel2_config_with_bias + bias_surgery_chain to test that: + - Qwen-style per-layer attention bias (QKV enabled, O disabled) survives surgery + - Non-gated MLP per-layer bias (layer_1 enabled, layer_2 disabled) survives surgery + - Bias settings are correctly inherited by new sub-mixers + - Bias is correctly tracked in surgery plans + """ + + @pytest.fixture + def bias_source_config(self, apriel2_config_with_bias): + """Convert Apriel2Config to dict for surgery operations.""" + return apriel2_config_with_bias.to_dict() + + def test_bias_survives_stochastic_wrapper(self, bias_source_config, bias_surgery_chain): + """Test that bias settings survive wrapping in stochastic mixer.""" + # Apply first surgery (wrap in stochastic) + result = compose_configs(bias_source_config, bias_surgery_chain[0]) + + # Check attention sub-mixer inherited bias settings + mixer = result["decoder"]["block"]["mixer"] + assert mixer["type"] == "stochastic" + + attn_mixer = mixer["mixers"]["attention"] + assert attn_mixer["query_layer"]["bias"]["enabled"] is True + assert attn_mixer["key_layer"]["bias"]["enabled"] is True + assert attn_mixer["value_layer"]["bias"]["enabled"] is True + assert attn_mixer["dense_layer"]["bias"]["enabled"] is False + + # Check MLP bias survived + mlp = result["decoder"]["block"]["mlp"] + assert mlp["layer_1"]["bias"]["enabled"] is True + assert mlp["layer_2"]["bias"]["enabled"] is False + + def test_new_submixer_inherits_bias(self, bias_source_config, bias_surgery_chain): + """Test that new sub-mixers inherit bias from source attention.""" + # Apply S1 + S2 (wrap in stochastic, add sliding_window) + config = bias_source_config + for surgery in bias_surgery_chain[:2]: + config = compose_configs(config, surgery) + + # sliding_window should inherit bias from source attention + mixer = config["decoder"]["block"]["mixer"] + sw_mixer = mixer["mixers"]["sliding_window"] + + assert sw_mixer["query_layer"]["bias"]["enabled"] is True + assert sw_mixer["key_layer"]["bias"]["enabled"] is True + assert sw_mixer["value_layer"]["bias"]["enabled"] is True + assert sw_mixer["dense_layer"]["bias"]["enabled"] is False + + def test_full_bias_chain_produces_valid_config(self, bias_source_config, bias_surgery_chain): + """Test that full bias surgery chain produces valid config.""" + config = bias_source_config + for surgery in bias_surgery_chain: + config = compose_configs(config, surgery) + + # Verify final config structure + mixer = config["decoder"]["block"]["mixer"] + assert mixer["type"] == "stochastic" + assert "attention" in mixer["mixers"] + assert "sliding_window" in mixer["mixers"] + assert "full_bias_attn" in mixer["mixers"] + + # attention and sliding_window inherit Qwen-style bias + for name in ["attention", "sliding_window"]: + sub = mixer["mixers"][name] + assert sub["query_layer"]["bias"]["enabled"] is True + assert 
sub["dense_layer"]["bias"]["enabled"] is False + + # full_bias_attn has add_linear_biases=True but per-layer settings inherited from + # source take precedence, so O bias is still disabled + full_bias = mixer["mixers"]["full_bias_attn"] + assert full_bias.get("add_linear_biases") is True + # Per-layer dense_layer.bias.enabled=False inherited from source takes precedence + assert full_bias["dense_layer"]["bias"]["enabled"] is False + + def test_bias_plan_has_correct_mappings(self, bias_source_config, bias_surgery_chain): + """Test that surgery plan correctly includes/excludes bias weight mappings.""" + # Compose config first to get full target config with inherited bias settings + target_config = compose_configs(bias_source_config, bias_surgery_chain[0]) + plan = plan_surgery(bias_source_config, target_config) + mapping_strs = [str(k) for k in plan.mappings.keys()] + + # Should have q_proj.bias (enabled) + q_bias = [m for m in mapping_strs if "q_proj.bias" in m] + assert len(q_bias) > 0, "Should have q_proj.bias mappings" + + # Should NOT have o_proj.bias (disabled) + o_bias = [m for m in mapping_strs if "o_proj.bias" in m] + assert len(o_bias) == 0, "Should not have o_proj.bias mappings" + + # Should have up_proj.bias (layer_1 enabled) + up_bias = [m for m in mapping_strs if "up_proj.bias" in m] + assert len(up_bias) > 0, "Should have up_proj.bias mappings" + + # Should NOT have down_proj.bias (layer_2 disabled) + down_bias = [m for m in mapping_strs if "down_proj.bias" in m] + assert len(down_bias) == 0, "Should not have down_proj.bias mappings" + + def test_bias_chain_produces_working_model(self, bias_source_config, bias_surgery_chain): + """Test that bias surgery chain produces a working model.""" + from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM + + # Apply full chain + config = bias_source_config + for surgery in bias_surgery_chain: + config = compose_configs(config, surgery) + + # Create model + apriel_config = Apriel2Config(**config) + model = Apriel2ForCausalLM(apriel_config) + model.eval() + + # Verify model structure has correct biases + block = model.model.decoder.blocks[0] + + # attention sub-mixer should have QKV bias, no O bias + attn = block.mixer.mixers["attention"] + assert attn.q_proj.bias is not None + assert attn.k_proj.bias is not None + assert attn.v_proj.bias is not None + assert attn.o_proj.bias is None + + # sliding_window should also inherit bias settings + sw = block.mixer.mixers["sliding_window"] + assert sw.q_proj.bias is not None + assert sw.o_proj.bias is None + + # full_bias_attn inherits per-layer bias from source (even with add_linear_biases=True, + # per-layer settings take precedence in same-type inheritance) + full_bias = block.mixer.mixers["full_bias_attn"] + assert full_bias.q_proj.bias is not None + # O bias is disabled because inherited per-layer dense_layer.bias.enabled=False + # takes precedence over add_linear_biases=True + assert full_bias.o_proj.bias is None + + # MLP should have layer_1 bias, no layer_2 bias + assert block.mlp.up_proj.bias is not None + assert block.mlp.down_proj.bias is None + + # Forward pass should work + input_ids = torch.randint(0, config["vocab_size"], (1, 10)) + with torch.no_grad(): + outputs = model(input_ids, use_cache=False) + assert outputs.logits.shape == (1, 10, config["vocab_size"]) From 7053d8cdf7ec941d84233fd4aa89be3ebf04b645 Mon Sep 17 00:00:00 2001 From: Torsten Scholak Date: Mon, 15 Dec 2025 19:30:40 +0000 Subject: [PATCH 05/12] Add conversation format support for SFT data 
preparation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Enable automatic loss masking span computation for chat/conversation datasets using HuggingFace's {% generation %}...{% endgeneration %} markers. This allows preparing SFT data (e.g., Tulu 3) with proper masking of non-assistant content. - Add ConversationSourceConfig with `type: conversation` for chat data - Add validate_chat_template() to verify tokenizer has generation markers - Add apply_chat_template_with_spans() for text + masking span extraction - Tokenizer must have built-in chat template with generation markers πŸ€– Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- fast_llm/data/preparator/gpt_memmap/config.py | 106 ++++++++++++++++- .../data/preparator/gpt_memmap/prepare.py | 20 +++- fast_llm/data/preprocessing/tokenizer.py | 108 ++++++++++++++++++ tests/data/test_preparator.py | 31 ++++- tests/data/test_tokenizer.py | 93 +++++++++++---- 5 files changed, 333 insertions(+), 25 deletions(-) diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 503b400c3..2aa0fbf31 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -15,11 +15,14 @@ from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator -@config_class() +@config_class(registry=True) class LanguageModelSourceConfig(Config): """ A schema holding the name of each relevant column in the dataset. Setting optional entries will enable the associated feature. + + This is the base class for source schemas. Use `type: text` (default) for + plain text datasets, or `type: conversation` for chat/conversation datasets. """ text: str = Field( @@ -48,6 +51,8 @@ def columns(self) -> list[str]: columns.append(self.loss_masking_spans) if self.has_preference_spans: columns.extend([self.chosen_span, self.rejected_span]) + if self.has_images: + columns.extend([self.images, self.image_positions]) return columns @functools.cached_property @@ -64,12 +69,111 @@ def has_images(self) -> bool: Assert.eq(self.images is None, self.image_positions is None) return self.images is not None + @functools.cached_property + def has_conversation(self) -> bool: + """Whether this is a conversation source schema.""" + return False + def _validate(self): super()._validate() if self.has_preference_spans and self.has_loss_masking_span: raise ValueError(f"Can not enable both loss masking and preference spans.") +@config_class(dynamic_type={LanguageModelSourceConfig: "text"}) +class TextSourceConfig(LanguageModelSourceConfig): + """ + Source schema for plain text datasets (default). + + The dataset should have a text column containing the document text. + Optionally, it can have additional columns for loss masking spans, + preference spans (for DPO), or images. + """ + + pass + + +@config_class(dynamic_type={LanguageModelSourceConfig: "conversation"}) +class ConversationSourceConfig(LanguageModelSourceConfig): + """ + Source schema for chat/conversation datasets (e.g., Tulu 3, ShareGPT, OpenAI format). + + The dataset should have a messages column containing a list of message dicts, + where each message has 'role' and 'content' keys. 
Common roles include: + - 'system': System prompt + - 'user': User input + - 'assistant': Model response (trained on by default) + - 'tool': Tool/function results + - 'ipython': Code execution results + + The conversation is formatted using the tokenizer's chat template, which must + contain {% generation %}...{% endgeneration %} markers to define which content + to train on. Loss masking spans are automatically computed from these markers. + + Example dataset format: + { + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello!"}, + {"role": "assistant", "content": "Hi there!"}, + ] + } + """ + + # Override text field - not used directly for conversation format + text: None | str = Field( + default=None, + desc="Not used for conversation format. Text is generated from messages.", + hint=FieldHint.optional, + ) + + # Conversation-specific fields + messages: str = Field( + default="messages", + desc="Field containing the conversation messages list. Each message should have 'role' and 'content' keys.", + hint=FieldHint.core, + ) + + add_generation_prompt: bool = Field( + default=False, + desc="Whether to add a generation prompt at the end of the conversation. " + "Typically False for training data.", + hint=FieldHint.optional, + ) + + @functools.cached_property + def columns(self) -> list[str]: + # For conversation format, we read the messages column, not text + columns = [self.messages] + # Images can still be used with conversation format + if self.has_images: + columns.extend([self.images, self.image_positions]) + return columns + + @functools.cached_property + def has_conversation(self) -> bool: + return True + + @functools.cached_property + def has_loss_masking_span(self) -> bool: + # Conversation format always generates loss masking spans + return True + + def _validate(self): + # Skip parent validation that checks text field + Config._validate(self) + if self.has_preference_spans: + raise ValueError("Preference spans are not supported with conversation format.") + if self.has_images: + # Images with conversation format would require computing image positions in the + # chat-template-formatted text, which is complex and format-dependent. + # For VLM training with conversations, preprocess the data to plain text format first. + raise ValueError( + "Images are not yet supported with conversation format. " + "For multimodal conversation data, preprocess to plain text format with image positions." 
+ ) + + @config_class() class GPTHuggingfaceDatasetConfig(Config): path: str | pathlib.Path = Field( diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 2ea81d8a6..f349b1979 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -132,6 +132,10 @@ def run(self) -> None: # Load tokenizer self._tokenizer = self._config.tokenizer.get_tokenizer() + # Validate chat template for conversation format + if self._source_schema.has_conversation: + self._tokenizer.validate_chat_template() + # Decide the datatype based on the tokenizer vocabulary size self._data_type = ( get_unsigned_integer_type(self._tokenizer.vocab_size) @@ -216,9 +220,21 @@ def _preprocessing_config(self) -> LanguageModelPreprocessingConfig: ) def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample: - text = sample[self._source_schema.text] all_spans = [] - if self._source_schema.has_loss_masking_span: + + if self._source_schema.has_conversation: + # Conversation format: apply chat template and compute loss masking spans + messages = sample[self._source_schema.messages] + text, loss_masking_spans = self._tokenizer.apply_chat_template_with_spans( + messages, + add_generation_prompt=self._source_schema.add_generation_prompt, + ) + all_spans.extend([(SpanType.loss_masking, span) for span in loss_masking_spans]) + else: + # Plain text format + text = sample[self._source_schema.text] + + if self._source_schema.has_loss_masking_span and not self._source_schema.has_conversation: # Spans are typically stored in the (begin, last) format. We convert to (begin, end) range format. loss_masking_spans = _sort_spans( (SpanType.loss_masking, (begin, last + 1)) diff --git a/fast_llm/data/preprocessing/tokenizer.py b/fast_llm/data/preprocessing/tokenizer.py index abfb5b3d2..924dc64b2 100644 --- a/fast_llm/data/preprocessing/tokenizer.py +++ b/fast_llm/data/preprocessing/tokenizer.py @@ -213,3 +213,111 @@ def _remove_delimiters( @property def eod(self): return self.eod_id + + @staticmethod + def _has_generation_markers(template: str | None) -> bool: + """Check if a template has generation markers.""" + return template is not None and "{% generation %}" in template + + def validate_chat_template(self) -> None: + """ + Validate the tokenizer's chat template has generation markers. + + Raises: + ValueError: If the tokenizer lacks a chat template or generation markers. + """ + template = self.tokenizer.chat_template + + if template is None: + raise ValueError( + "Tokenizer does not have a chat template. " + "Conversation format requires a tokenizer with a built-in chat template " + "containing {% generation %}...{% endgeneration %} markers." + ) + + if not self._has_generation_markers(template): + raise ValueError( + "Tokenizer's chat template does not contain {% generation %}...{% endgeneration %} markers. " + "These markers are required to determine which tokens to train on. " + "Please use a tokenizer with generation markers in its chat template." + ) + + def apply_chat_template_with_spans( + self, + messages: list[dict[str, str]], + *, + add_generation_prompt: bool = False, + ) -> tuple[str, list[tuple[int, int]]]: + """ + Apply the tokenizer's chat template to messages and compute loss masking spans. + + This method converts a list of messages (OpenAI/Tulu format) into formatted + text and computes character-level spans that should be MASKED (not trained on). 
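+
+        Illustrative example, assuming `tokenizer` is a Tokenizer instance whose
+        chat template wraps only assistant content in {% generation %} markers:
+
+            messages = [
+                {"role": "user", "content": "Hi"},
+                {"role": "assistant", "content": "Hello"},
+            ]
+            text, spans = tokenizer.apply_chat_template_with_spans(messages)
+            # `spans` lists (start, end) character ranges to mask: the user turn
+            # and any template scaffolding outside the generation markers.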
+ + Note: Call validate_chat_template() once before using this method to ensure + the tokenizer has a valid chat template with generation markers. + + Args: + messages: List of message dicts with 'role' and 'content' keys. + add_generation_prompt: Whether to add a generation prompt at the end. + + Returns: + Tuple of (formatted_text, loss_masking_spans) where loss_masking_spans + is a list of (start, end) character positions to MASK (not train on). + """ + if not messages: + return "", [] + + return self._apply_chat_template(messages, add_generation_prompt) + + def _apply_chat_template( + self, + messages: list[dict[str, str]], + add_generation_prompt: bool, + ) -> tuple[str, list[tuple[int, int]]]: + """Use HF's return_assistant_tokens_mask for precise token-level masking.""" + # Get tokens and assistant mask + result = self.tokenizer.apply_chat_template( + messages, + tokenize=True, + return_assistant_tokens_mask=True, + return_dict=True, + add_generation_prompt=add_generation_prompt, + ) + + tokens = result["input_ids"] + train_mask = result["assistant_masks"] + + # Get text for output + full_text = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=add_generation_prompt, + ) + + # Convert token mask to character spans using detokenization + # We need spans for tokens where train_mask=0 (should be masked/not trained on) + loss_masking_spans = [] + current_span_start = None + + # Track character positions by decoding incrementally + char_positions = [0] + for i in range(len(tokens)): + decoded = self.tokenizer.decode(tokens[: i + 1]) + char_positions.append(len(decoded)) + + for i, is_train in enumerate(train_mask): + if not is_train: # This token should be masked + if current_span_start is None: + current_span_start = char_positions[i] + else: # This token should be trained on + if current_span_start is not None: + loss_masking_spans.append((current_span_start, char_positions[i])) + current_span_start = None + + # Close any open span + if current_span_start is not None: + loss_masking_spans.append((current_span_start, char_positions[-1])) + + return full_text, loss_masking_spans + diff --git a/tests/data/test_preparator.py b/tests/data/test_preparator.py index f4f6fab82..ccef94d03 100644 --- a/tests/data/test_preparator.py +++ b/tests/data/test_preparator.py @@ -6,7 +6,11 @@ from fast_llm.data.dataset.config import BlendedDatasetConfig, MemmapDatasetConfig, SamplingParameters from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig from fast_llm.data.dataset.memmap import MemmapDataset -from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig +from fast_llm.data.preparator.gpt_memmap.config import ( + ConversationSourceConfig, + GPTMemmapDatasetPreparatorConfig, + LanguageModelSourceConfig, +) from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.preprocessing.tokenizer import TokenizerConfig from fast_llm.utils import Assert @@ -198,3 +202,28 @@ def test_dataset_preparator_from_hub(): tokenizer.detokenize(dataset.get_document(index).tokens.tokens), f"<|endoftext|>{hf_dataset[index]["answer"]}<|endoftext|>", ) + + +# ============================================================================= +# Conversation Format Tests +# ============================================================================= + + +def test_conversation_source_config(): + """Test conversation source schema configuration.""" + config = LanguageModelSourceConfig.from_dict({"type": 
"conversation"}) + Assert.custom(isinstance, config, ConversationSourceConfig) + Assert.eq(config.messages, "messages") + Assert.eq(config.has_conversation, True) + Assert.eq(config.has_loss_masking_span, True) + Assert.eq(config.columns, ["messages"]) + + +def test_conversation_config_validation(): + """Test conversation config validation errors.""" + with pytest.raises(ValueError, match="Images are not yet supported"): + LanguageModelSourceConfig.from_dict({ + "type": "conversation", + "images": "images", + "image_positions": "positions", + }) diff --git a/tests/data/test_tokenizer.py b/tests/data/test_tokenizer.py index c7fdef9ca..b7e1d3e9b 100644 --- a/tests/data/test_tokenizer.py +++ b/tests/data/test_tokenizer.py @@ -1,42 +1,93 @@ import pytest -from fast_llm.data.preprocessing.tokenizer import Tokenizer, TokenizerConfig +from fast_llm.data.preprocessing.tokenizer import TokenizerConfig from fast_llm.utils import Assert from tests.utils.dataset import download_santacoder_tokenizer from tests.utils.global_variables import TOKENIZER_PATH +TEXT = "hello world" + -@pytest.fixture(scope="session") -def common_tokenizer() -> Tokenizer: +@pytest.fixture(scope="module") +def tokenizer(): download_santacoder_tokenizer() return TokenizerConfig(path=TOKENIZER_PATH).get_tokenizer() -TEXT = "hello world" - - @pytest.mark.parametrize("extra_tokens", (False, True)) @pytest.mark.parametrize( ("spans", "expected_token_spans", "expected_tokens"), ( - ([], [], [7196, 5297]), # No span - ([(1, 3)], [(1, 2)], [71, 325, 303, 5297]), # Simple span - ([(2, 2)], [(1, 1)], [284, 47443, 5297]), # Empty span - ([(0, 11)], [(0, 2)], [7196, 5297]), # Full span - ([(1, 4), (6, 7)], [(1, 2), (4, 5)], [71, 1498, 78, 207, 86, 2231]), # Two spans - ([(1, 6), (4, 7)], [(1, 4), (2, 5)], [71, 1498, 78, 207, 86, 2231]), # Overlapping spans - ([(1, 7), (4, 6)], [(1, 5), (2, 4)], [71, 1498, 78, 207, 86, 2231]), # Nested spans - ([(1, 5), (5, 7)], [(1, 3), (3, 4)], [71, 325, 303, 365, 2231]), # Consecutive spans - ([(2, 4), (2, 4)], [(1, 2), (1, 2)], [284, 683, 78, 5297]), # Duplicate spans - ([(2, 3), (5, 8), (9, 11)], [(1, 2), (3, 4), (5, 6)], [284, 75, 303, 48485, 81, 1382]), # Three spans + ([], [], [7196, 5297]), + ([(1, 3)], [(1, 2)], [71, 325, 303, 5297]), + ([(2, 2)], [(1, 1)], [284, 47443, 5297]), + ([(0, 11)], [(0, 2)], [7196, 5297]), + ([(1, 4), (6, 7)], [(1, 2), (4, 5)], [71, 1498, 78, 207, 86, 2231]), + ([(1, 6), (4, 7)], [(1, 4), (2, 5)], [71, 1498, 78, 207, 86, 2231]), + ([(1, 7), (4, 6)], [(1, 5), (2, 4)], [71, 1498, 78, 207, 86, 2231]), + ([(1, 5), (5, 7)], [(1, 3), (3, 4)], [71, 325, 303, 365, 2231]), + ([(2, 4), (2, 4)], [(1, 2), (1, 2)], [284, 683, 78, 5297]), + ([(2, 3), (5, 8), (9, 11)], [(1, 2), (3, 4), (5, 6)], [284, 75, 303, 48485, 81, 1382]), ), ) -def test_tokenize_with_spans(common_tokenizer, spans, expected_token_spans, expected_tokens, extra_tokens): - tokens, token_spans = common_tokenizer.tokenize_with_spans( - TEXT, begin=extra_tokens, end=extra_tokens, text_spans=spans - ) +def test_tokenize_with_spans(tokenizer, spans, expected_token_spans, expected_tokens, extra_tokens): + tokens, token_spans = tokenizer.tokenize_with_spans(TEXT, begin=extra_tokens, end=extra_tokens, text_spans=spans) if extra_tokens: - expected_tokens = [common_tokenizer.bod_id, *expected_tokens, common_tokenizer.eod_id] + expected_tokens = [tokenizer.bod_id, *expected_tokens, tokenizer.eod_id] expected_token_spans = [(begin + 1, end + 1) for begin, end in expected_token_spans] Assert.eq(tokens.tolist(), expected_tokens) 
Assert.eq(token_spans, expected_token_spans) + + +def test_validate_chat_template_no_template(tokenizer): + """Tokenizer without chat template raises.""" + with pytest.raises(ValueError, match="does not have a chat template"): + tokenizer.validate_chat_template() + + +def test_validate_chat_template_no_markers(tokenizer): + """Tokenizer with chat template but no markers raises.""" + tokenizer.tokenizer.chat_template = "{{ messages }}" + with pytest.raises(ValueError, match="does not contain.*generation"): + tokenizer.validate_chat_template() + + +def test_validate_chat_template_with_markers(tokenizer): + """Tokenizer with generation markers validates.""" + tokenizer.tokenizer.chat_template = "{% generation %}{{ m }}{% endgeneration %}" + tokenizer.validate_chat_template() + + +CHAT_TEMPLATE = ( + "{% for message in messages %}" + "{% if message.role == 'assistant' %}" + "{% generation %}{{ message.content }}{% endgeneration %}" + "{% else %}" + "<{{ message.role }}>{{ message.content }}" + "{% endif %}" + "{% endfor %}" +) + + +@pytest.mark.parametrize( + ("messages", "expected_text", "expected_spans"), + ( + ([], "", []), + ( + [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello"}], + "HiHello", + [(0, 26), (31, 43)], + ), + ( + [{"role": "user", "content": "A"}, {"role": "assistant", "content": "B"}, {"role": "user", "content": "C"}, {"role": "assistant", "content": "D"}], + "ABCD", + [(0, 25), (26, 63), (64, 76)], + ), + ), +) +def test_apply_chat_template_with_spans(tokenizer, messages, expected_text, expected_spans): + """Chat template produces correct text and masking spans.""" + tokenizer.tokenizer.chat_template = CHAT_TEMPLATE + text, spans = tokenizer.apply_chat_template_with_spans(messages) + Assert.eq(text, expected_text) + Assert.eq(spans, expected_spans) From 53d657069e46f6a830aef86ce852b2be79fa203a Mon Sep 17 00:00:00 2001 From: Torsten Scholak Date: Mon, 15 Dec 2025 20:25:48 +0000 Subject: [PATCH 06/12] Cleanup: remove private method indirection, revert test changes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Inline _apply_chat_template into apply_chat_template_with_spans - Revert unnecessary test refactoring in test_tokenizer.py - Remove trivial config tests from test_preparator.py πŸ€– Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- fast_llm/data/preprocessing/tokenizer.py | 9 ---- tests/data/test_preparator.py | 31 +----------- tests/data/test_tokenizer.py | 61 +++++++++++++----------- 3 files changed, 33 insertions(+), 68 deletions(-) diff --git a/fast_llm/data/preprocessing/tokenizer.py b/fast_llm/data/preprocessing/tokenizer.py index 924dc64b2..372d8cd90 100644 --- a/fast_llm/data/preprocessing/tokenizer.py +++ b/fast_llm/data/preprocessing/tokenizer.py @@ -267,15 +267,6 @@ def apply_chat_template_with_spans( """ if not messages: return "", [] - - return self._apply_chat_template(messages, add_generation_prompt) - - def _apply_chat_template( - self, - messages: list[dict[str, str]], - add_generation_prompt: bool, - ) -> tuple[str, list[tuple[int, int]]]: - """Use HF's return_assistant_tokens_mask for precise token-level masking.""" # Get tokens and assistant mask result = self.tokenizer.apply_chat_template( messages, diff --git a/tests/data/test_preparator.py b/tests/data/test_preparator.py index ccef94d03..f4f6fab82 100644 --- a/tests/data/test_preparator.py +++ b/tests/data/test_preparator.py @@ -6,11 +6,7 @@ from fast_llm.data.dataset.config 
import BlendedDatasetConfig, MemmapDatasetConfig, SamplingParameters from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig from fast_llm.data.dataset.memmap import MemmapDataset -from fast_llm.data.preparator.gpt_memmap.config import ( - ConversationSourceConfig, - GPTMemmapDatasetPreparatorConfig, - LanguageModelSourceConfig, -) +from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.preprocessing.tokenizer import TokenizerConfig from fast_llm.utils import Assert @@ -202,28 +198,3 @@ def test_dataset_preparator_from_hub(): tokenizer.detokenize(dataset.get_document(index).tokens.tokens), f"<|endoftext|>{hf_dataset[index]["answer"]}<|endoftext|>", ) - - -# ============================================================================= -# Conversation Format Tests -# ============================================================================= - - -def test_conversation_source_config(): - """Test conversation source schema configuration.""" - config = LanguageModelSourceConfig.from_dict({"type": "conversation"}) - Assert.custom(isinstance, config, ConversationSourceConfig) - Assert.eq(config.messages, "messages") - Assert.eq(config.has_conversation, True) - Assert.eq(config.has_loss_masking_span, True) - Assert.eq(config.columns, ["messages"]) - - -def test_conversation_config_validation(): - """Test conversation config validation errors.""" - with pytest.raises(ValueError, match="Images are not yet supported"): - LanguageModelSourceConfig.from_dict({ - "type": "conversation", - "images": "images", - "image_positions": "positions", - }) diff --git a/tests/data/test_tokenizer.py b/tests/data/test_tokenizer.py index b7e1d3e9b..4b8f45d8d 100644 --- a/tests/data/test_tokenizer.py +++ b/tests/data/test_tokenizer.py @@ -1,61 +1,64 @@ import pytest -from fast_llm.data.preprocessing.tokenizer import TokenizerConfig +from fast_llm.data.preprocessing.tokenizer import Tokenizer, TokenizerConfig from fast_llm.utils import Assert from tests.utils.dataset import download_santacoder_tokenizer from tests.utils.global_variables import TOKENIZER_PATH -TEXT = "hello world" - -@pytest.fixture(scope="module") -def tokenizer(): +@pytest.fixture(scope="session") +def common_tokenizer() -> Tokenizer: download_santacoder_tokenizer() return TokenizerConfig(path=TOKENIZER_PATH).get_tokenizer() +TEXT = "hello world" + + @pytest.mark.parametrize("extra_tokens", (False, True)) @pytest.mark.parametrize( ("spans", "expected_token_spans", "expected_tokens"), ( - ([], [], [7196, 5297]), - ([(1, 3)], [(1, 2)], [71, 325, 303, 5297]), - ([(2, 2)], [(1, 1)], [284, 47443, 5297]), - ([(0, 11)], [(0, 2)], [7196, 5297]), - ([(1, 4), (6, 7)], [(1, 2), (4, 5)], [71, 1498, 78, 207, 86, 2231]), - ([(1, 6), (4, 7)], [(1, 4), (2, 5)], [71, 1498, 78, 207, 86, 2231]), - ([(1, 7), (4, 6)], [(1, 5), (2, 4)], [71, 1498, 78, 207, 86, 2231]), - ([(1, 5), (5, 7)], [(1, 3), (3, 4)], [71, 325, 303, 365, 2231]), - ([(2, 4), (2, 4)], [(1, 2), (1, 2)], [284, 683, 78, 5297]), - ([(2, 3), (5, 8), (9, 11)], [(1, 2), (3, 4), (5, 6)], [284, 75, 303, 48485, 81, 1382]), + ([], [], [7196, 5297]), # No span + ([(1, 3)], [(1, 2)], [71, 325, 303, 5297]), # Simple span + ([(2, 2)], [(1, 1)], [284, 47443, 5297]), # Empty span + ([(0, 11)], [(0, 2)], [7196, 5297]), # Full span + ([(1, 4), (6, 7)], [(1, 2), (4, 5)], [71, 1498, 78, 207, 86, 2231]), # Two spans + ([(1, 6), (4, 7)], [(1, 4), (2, 5)], [71, 1498, 78, 
207, 86, 2231]), # Overlapping spans + ([(1, 7), (4, 6)], [(1, 5), (2, 4)], [71, 1498, 78, 207, 86, 2231]), # Nested spans + ([(1, 5), (5, 7)], [(1, 3), (3, 4)], [71, 325, 303, 365, 2231]), # Consecutive spans + ([(2, 4), (2, 4)], [(1, 2), (1, 2)], [284, 683, 78, 5297]), # Duplicate spans + ([(2, 3), (5, 8), (9, 11)], [(1, 2), (3, 4), (5, 6)], [284, 75, 303, 48485, 81, 1382]), # Three spans ), ) -def test_tokenize_with_spans(tokenizer, spans, expected_token_spans, expected_tokens, extra_tokens): - tokens, token_spans = tokenizer.tokenize_with_spans(TEXT, begin=extra_tokens, end=extra_tokens, text_spans=spans) +def test_tokenize_with_spans(common_tokenizer, spans, expected_token_spans, expected_tokens, extra_tokens): + tokens, token_spans = common_tokenizer.tokenize_with_spans( + TEXT, begin=extra_tokens, end=extra_tokens, text_spans=spans + ) if extra_tokens: - expected_tokens = [tokenizer.bod_id, *expected_tokens, tokenizer.eod_id] + expected_tokens = [common_tokenizer.bod_id, *expected_tokens, common_tokenizer.eod_id] expected_token_spans = [(begin + 1, end + 1) for begin, end in expected_token_spans] Assert.eq(tokens.tolist(), expected_tokens) Assert.eq(token_spans, expected_token_spans) -def test_validate_chat_template_no_template(tokenizer): +def test_validate_chat_template_no_template(common_tokenizer): """Tokenizer without chat template raises.""" with pytest.raises(ValueError, match="does not have a chat template"): - tokenizer.validate_chat_template() + common_tokenizer.validate_chat_template() -def test_validate_chat_template_no_markers(tokenizer): +def test_validate_chat_template_no_markers(common_tokenizer): """Tokenizer with chat template but no markers raises.""" - tokenizer.tokenizer.chat_template = "{{ messages }}" + common_tokenizer.tokenizer.chat_template = "{{ messages }}" with pytest.raises(ValueError, match="does not contain.*generation"): - tokenizer.validate_chat_template() + common_tokenizer.validate_chat_template() -def test_validate_chat_template_with_markers(tokenizer): +def test_validate_chat_template_with_markers(common_tokenizer): """Tokenizer with generation markers validates.""" - tokenizer.tokenizer.chat_template = "{% generation %}{{ m }}{% endgeneration %}" - tokenizer.validate_chat_template() + common_tokenizer.tokenizer.chat_template = "{% generation %}{{ m }}{% endgeneration %}" + common_tokenizer.validate_chat_template() CHAT_TEMPLATE = ( @@ -85,9 +88,9 @@ def test_validate_chat_template_with_markers(tokenizer): ), ), ) -def test_apply_chat_template_with_spans(tokenizer, messages, expected_text, expected_spans): +def test_apply_chat_template_with_spans(common_tokenizer, messages, expected_text, expected_spans): """Chat template produces correct text and masking spans.""" - tokenizer.tokenizer.chat_template = CHAT_TEMPLATE - text, spans = tokenizer.apply_chat_template_with_spans(messages) + common_tokenizer.tokenizer.chat_template = CHAT_TEMPLATE + text, spans = common_tokenizer.apply_chat_template_with_spans(messages) Assert.eq(text, expected_text) Assert.eq(spans, expected_spans) From d053d47d41abb23cdb729f01b02faad6fde7433f Mon Sep 17 00:00:00 2001 From: Torsten Scholak Date: Tue, 16 Dec 2025 12:28:39 +0000 Subject: [PATCH 07/12] Refactor test organization: rename modules and remove duplication MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Rename test_plan_composition_torture.py β†’ test_conversion_e2e.py (reflects actual purpose: end-to-end integration tests) - Rename test_algebraic_properties.py β†’ 
test_plan_execution.py (clearer: tests plan execution and algebraic composition laws) - Remove stale NOTE comments referencing deleted tests - Fix fixture naming collision: attention_config β†’ attention_config_dict in TestMarkovianProperty to avoid shadowing conftest fixtures - Consolidate shared fixtures in conftest.py Test organization now follows clear separation: - test_compose_configs.py: Config dict composition (structure/completeness) - test_plan_execution.py: Plan execution (weight transfer/correctness) - test_conversion_e2e.py: Full pipeline integration tests πŸ€– Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- .../tests/test_apriel2/conftest.py | 142 +++++ .../test_apriel2/test_compose_configs.py | 261 +++----- ...tion_torture.py => test_conversion_e2e.py} | 342 +--------- .../tests/test_apriel2/test_plan_execution.py | 597 ++++++++++++++++++ 4 files changed, 833 insertions(+), 509 deletions(-) rename fast_llm_external_models/tests/test_apriel2/{test_plan_composition_torture.py => test_conversion_e2e.py} (84%) create mode 100644 fast_llm_external_models/tests/test_apriel2/test_plan_execution.py diff --git a/fast_llm_external_models/tests/test_apriel2/conftest.py b/fast_llm_external_models/tests/test_apriel2/conftest.py index cf190b50a..320813747 100644 --- a/fast_llm_external_models/tests/test_apriel2/conftest.py +++ b/fast_llm_external_models/tests/test_apriel2/conftest.py @@ -1680,6 +1680,148 @@ def torture_surgery_chain(): ] +# ============================================================================= +# Shared Config Dict Fixtures (for compose_configs / plan_surgery tests) +# ============================================================================= + + +@pytest.fixture +def base_config_dict(): + """Complete Apriel2 config dict without biases (Llama-style). + + Use this as the base config for testing compose_configs and plan_surgery. + Returns a dict (not Apriel2Config) for direct use with compose_configs. + """ + return { + "model_type": "apriel2", + "hidden_size": 256, + "vocab_size": 1000, + "bos_token_id": 1, + "eos_token_id": 2, + "tie_word_embeddings": False, + "decoder": { + "type": "fixed", + "num_blocks": 2, + "block": { + "mixer": { + "type": "attention", + "heads": 8, + "head_groups": 4, + "head_size": 32, + "rotary": {"type": "mistral_1d", "theta": 10000.0}, + }, + "mlp": {"type": "mlp", "intermediate_size": 512, "gated": True}, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + } + + +@pytest.fixture +def base_config_with_bias_dict(): + """Complete Apriel2 config dict with Qwen-style biases. + + - QKV bias enabled, O bias disabled + - Gated MLP (no per-layer bias control in this style) + + Use this for testing bias inheritance through surgery operations. + Returns a dict (not Apriel2Config) for direct use with compose_configs. 
+ """ + return { + "model_type": "apriel2", + "hidden_size": 256, + "vocab_size": 1000, + "bos_token_id": 1, + "eos_token_id": 2, + "tie_word_embeddings": False, + "decoder": { + "type": "fixed", + "num_blocks": 2, + "block": { + "mixer": { + "type": "attention", + "heads": 8, + "head_groups": 4, + "head_size": 32, + "rotary": {"type": "mistral_1d", "theta": 10000.0}, + "query_layer": {"bias": {"enabled": True}}, + "key_layer": {"bias": {"enabled": True}}, + "value_layer": {"bias": {"enabled": True}}, + "dense_layer": {"bias": {"enabled": False}}, + }, + "mlp": {"type": "mlp", "intermediate_size": 512, "gated": True}, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + } + + +def make_weights_for_config(config: dict) -> dict: + """Create random weights matching a config's expected schema. + + This is a helper function (not a fixture) for creating test weights. + Use it in tests that need weights for plan execution. + + Args: + config: Complete Apriel2 config dict + + Returns: + Dict mapping weight key strings to torch tensors + """ + from fast_llm_external_models.apriel2.conversion import W + + hidden = config["hidden_size"] + vocab = config["vocab_size"] + decoder = config["decoder"] + num_blocks = decoder["num_blocks"] + block = decoder["block"] + mixer = block["mixer"] + mlp = block["mlp"] + + heads = mixer["heads"] + head_groups = mixer["head_groups"] + head_size = mixer["head_size"] + inter = mlp["intermediate_size"] + + # Check bias settings + has_q_bias = mixer.get("query_layer", {}).get("bias", {}).get("enabled", False) + has_k_bias = mixer.get("key_layer", {}).get("bias", {}).get("enabled", False) + has_v_bias = mixer.get("value_layer", {}).get("bias", {}).get("enabled", False) + + weights = {} + weights["model.embed_tokens.weight"] = torch.randn(vocab, hidden) + + for i in range(num_blocks): + p = f"model.decoder.blocks.{i}" + + # Attention + weights[f"{p}.mixer.q_proj.weight"] = torch.randn(heads * head_size, hidden) + weights[f"{p}.mixer.k_proj.weight"] = torch.randn(head_groups * head_size, hidden) + weights[f"{p}.mixer.v_proj.weight"] = torch.randn(head_groups * head_size, hidden) + weights[f"{p}.mixer.o_proj.weight"] = torch.randn(hidden, heads * head_size) + + if has_q_bias: + weights[f"{p}.mixer.q_proj.bias"] = torch.randn(heads * head_size) + if has_k_bias: + weights[f"{p}.mixer.k_proj.bias"] = torch.randn(head_groups * head_size) + if has_v_bias: + weights[f"{p}.mixer.v_proj.bias"] = torch.randn(head_groups * head_size) + + # MLP + weights[f"{p}.mlp.up_proj.weight"] = torch.randn(inter, hidden) + weights[f"{p}.mlp.gate_proj.weight"] = torch.randn(inter, hidden) + weights[f"{p}.mlp.down_proj.weight"] = torch.randn(hidden, inter) + + # Norms + weights[f"{p}.input_layernorm.weight"] = torch.randn(hidden) + weights[f"{p}.post_attention_layernorm.weight"] = torch.randn(hidden) + + weights["model.norm.weight"] = torch.randn(hidden) + weights["lm_head.weight"] = torch.randn(vocab, hidden) + + return {W(k): v for k, v in weights.items()} + + # ============================================================================= # Cache Test Fixtures - Tensor Dimensions # ============================================================================= diff --git a/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py b/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py index 4380b1fbd..b1ee15d54 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py +++ 
b/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py @@ -75,14 +75,10 @@ def source_config(self): }, } - def test_identity_empty_surgery(self, source_config): - """Law 1: compose_configs(config, {}) == config""" - result = compose_configs(source_config, {}) - assert result == source_config - - def test_identity_none_surgery(self, source_config): - """Law 1: compose_configs(config, None) == config""" - result = compose_configs(source_config, None) + @pytest.mark.parametrize("empty_surgery", [{}, None]) + def test_identity(self, source_config, empty_surgery): + """Law 1: compose_configs(config, empty) == config for empty in [{}, None]""" + result = compose_configs(source_config, empty_surgery) assert result == source_config def test_override_explicit_values(self, source_config): @@ -114,7 +110,7 @@ def test_same_type_inheritance(self, source_config): assert mixer["head_size"] == 32 # Inherited assert mixer["rope_theta"] == 10000.0 # Inherited assert mixer["window_size"] == 512 # Added - assert "init" not in mixer # Stripped by apply_surgery + # init is preserved for plan_surgery to see (stripped only at final output) def test_cross_type_attention_to_gdn(self, source_config): """Law 5: attentionβ†’gdn derives GDN dims from attention geometry.""" @@ -239,8 +235,14 @@ def test_null_deletion(self, source_config): assert "vision_encoder" not in result - def test_init_stripped_from_result(self, source_config): - """Verify `init` keys are stripped from final result.""" + def test_init_preserved_for_plan_surgery(self, source_config): + """Verify `init` keys are preserved so plan_surgery can see them. + + The `init` field controls weight initialization (transfer vs random). + It's preserved through composition and only stripped at final output. 
+ """ + from fast_llm_external_models.apriel2.conversion.config import strip_init_fields + surgery = { "decoder": { "block": { @@ -252,20 +254,20 @@ def test_init_stripped_from_result(self, source_config): "gdn": {"type": "gdn", "init": "random", "convolution_layer": {"kernel_size": 4}}, }, }, - "mlp": {"init": "transfer"}, - "normalization": {"init": "transfer"}, }, }, } result = compose_configs(source_config, surgery) - def check_no_init(d, path=""): - assert "init" not in d, f"Found 'init' key at {path}" - for k, v in d.items(): - if isinstance(v, dict): - check_no_init(v, f"{path}.{k}") + # init is preserved in composed config + mixers = result["decoder"]["block"]["mixer"]["mixers"] + assert mixers["attention"].get("init") == "transfer" + assert mixers["gdn"].get("init") == "random" - check_no_init(result) + # strip_init_fields removes them for final output + stripped = strip_init_fields(result) + assert "init" not in stripped["decoder"]["block"]["mixer"]["mixers"]["attention"] + assert "init" not in stripped["decoder"]["block"]["mixer"]["mixers"]["gdn"] def test_init_random_still_inherits_config(self, source_config): """init: random is for weights only - config params still inherited.""" @@ -287,6 +289,49 @@ def test_init_random_still_inherits_config(self, source_config): assert mixer["head_groups"] == 4 assert mixer["window_size"] == 512 + # ========================================================================= + # Monoid Laws: compose_configs forms a monoid action on configs + # ========================================================================= + + def test_surgery_monoid_associativity(self): + """MONOID: merge(merge(A, B), C) == merge(A, merge(B, C)) for partial configs.""" + surgery_a = {"decoder": {"block": {"mixer": {"type": "stochastic", "main_mixer_name": "attention"}}}} + surgery_b = {"decoder": {"block": {"mixer": {"mixers": {"sliding_window": {"window_size": 512}}}}}} + surgery_c = {"decoder": {"block": {"mixer": {"mixers": {"gdn": {"type": "gdn"}}}}}} + + # Left-associated: (A ∘ B) ∘ C + ab_c = compose_configs(compose_configs(surgery_a, surgery_b), surgery_c) + # Right-associated: A ∘ (B ∘ C) + a_bc = compose_configs(surgery_a, compose_configs(surgery_b, surgery_c)) + + assert ab_c == a_bc, "Surgery monoid should be associative" + + @pytest.mark.parametrize("num_surgeries", [2, 3]) + def test_monoid_action_compatibility(self, source_config, num_surgeries): + """MONOID ACTION: apply(apply(c, A), B) == apply(c, merge(A, B)) + + This is the key law: applying surgeries sequentially equals merging first. + Parameterized to test with 2 and 3 surgeries. + """ + surgeries = [ + {"decoder": {"block": {"mixer": {"type": "stochastic", "main_mixer_name": "attention", "mixers": {"attention": {}}}}}}, + {"decoder": {"block": {"mixer": {"mixers": {"sliding_window": {"window_size": 512}}}}}}, + {"decoder": {"block": {"mixer": {"mixers": {"gdn": {"type": "gdn"}}}}}}, + ][:num_surgeries] + + # Sequential: ((c ⊳ A) ⊳ B) ⊳ ... + result_sequential = source_config + for s in surgeries: + result_sequential = compose_configs(result_sequential, s) + + # Merged: c ⊳ (A ∘ B ∘ ...) + merged = surgeries[0] + for s in surgeries[1:]: + merged = compose_configs(merged, s) + result_merged = compose_configs(source_config, merged) + + assert result_sequential == result_merged, f"Monoid action compatibility failed for {num_surgeries} surgeries" + class TestBiasConfigInheritance: """Test per-layer bias inheritance through surgery composition. 
@@ -555,160 +600,12 @@ def test_build_plan_returns_complete_config(self, llava_pixtral_checkpoint): mixer = config.decoder["block"]["mixer"] assert mixer["type"] == "stochastic" - # Each sub-mixer should have complete config (no init keys) + # Each sub-mixer should have complete config + # (init is preserved for plan_surgery, stripped only at final output) for name, sub_mixer in mixer["mixers"].items(): - assert "init" not in sub_mixer, f"Sub-mixer {name} still has 'init' key" assert "type" in sub_mixer -class TestMonoidLaws: - """Test the algebraic laws of compose_configs. - - Surgery specs form a MONOID under deep-merge: - - Identity: {} - - Operation: deep merge (overlay wins) - - Associativity: merge(merge(A, B), C) == merge(A, merge(B, C)) - - compose_configs is a MONOID ACTION on configs: - - Identity action: apply(config, {}) == config - - Compatibility: apply(apply(c, A), B) == apply(c, merge(A, B)) - """ - - @pytest.fixture - def complete_config(self): - """A complete Apriel2 config.""" - return { - "model_type": "apriel2", - "architectures": ["Apriel2ForConditionalGeneration"], - "hidden_size": 256, - "vocab_size": 1000, - "bos_token_id": 1, - "eos_token_id": 2, - "tie_word_embeddings": False, - "image_token_index": 100, - "decoder": { - "type": "fixed", - "num_blocks": 4, - "block": { - "mixer": { - "type": "attention", - "heads": 8, - "head_groups": 4, - "head_size": 32, - "rope_theta": 10000.0, - }, - "mlp": {"type": "mlp", "intermediate_size": 512}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - } - - @pytest.fixture - def surgery_a(self): - """First surgery: wrap in stochastic with attention.""" - return { - "decoder": { - "block": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": { - "attention": {"init": "transfer"}, - }, - }, - }, - }, - } - - @pytest.fixture - def surgery_b(self): - """Second surgery: add sliding window mixer.""" - return { - "decoder": { - "block": { - "mixer": { - "mixers": { - "sliding_window": {"init": "transfer", "window_size": 512}, - }, - }, - }, - }, - } - - def test_identity_action(self, complete_config): - """apply(config, {}) == config""" - result = compose_configs(complete_config, {}) - assert result == complete_config - - def test_surgery_monoid_associativity(self, surgery_a, surgery_b): - """merge(merge(A, B), C) == merge(A, merge(B, C)) for partial configs.""" - surgery_c = { - "decoder": { - "block": { - "mixer": { - "mixers": { - "gdn": {"type": "gdn", "init": "transfer", "convolution_layer": {"kernel_size": 4}}, - }, - }, - }, - }, - } - - # Left-associated: (A ∘ B) ∘ C - ab = compose_configs(surgery_a, surgery_b) - ab_c = compose_configs(ab, surgery_c) - - # Right-associated: A ∘ (B ∘ C) - bc = compose_configs(surgery_b, surgery_c) - a_bc = compose_configs(surgery_a, bc) - - assert ab_c == a_bc, "Surgery monoid should be associative" - - def test_monoid_action_compatibility(self, complete_config, surgery_a, surgery_b): - """apply(apply(c, A), B) == apply(c, merge(A, B)) - - This is the key law: applying surgeries sequentially should equal - merging the surgeries first, then applying once. 
- """ - # Sequential application: (c ⊳ A) ⊳ B - result_sequential = compose_configs(compose_configs(complete_config, surgery_a), surgery_b) - - # Merged application: c ⊳ (A ∘ B) - merged_surgery = compose_configs(surgery_a, surgery_b) - result_merged = compose_configs(complete_config, merged_surgery) - - # These should be equivalent - assert result_sequential == result_merged, "Monoid action should satisfy compatibility law" - - def test_three_way_compatibility(self, complete_config, surgery_a, surgery_b): - """Test with three surgeries for stronger confidence.""" - surgery_c = { - "decoder": { - "block": { - "mixer": { - "mixers": { - "gdn": {"type": "gdn", "init": "transfer", "convolution_layer": {"kernel_size": 4}}, - }, - }, - }, - }, - } - - # Sequential: ((c ⊳ A) ⊳ B) ⊳ C - seq = compose_configs( - compose_configs(compose_configs(complete_config, surgery_a), surgery_b), - surgery_c - ) - - # Merged: c ⊳ ((A ∘ B) ∘ C) - merged = compose_configs( - complete_config, - compose_configs(compose_configs(surgery_a, surgery_b), surgery_c) - ) - - assert seq == merged, "Three-way monoid action should satisfy compatibility" - - class TestCompositionTortureTest: """Comprehensive stress test for config composition. @@ -807,19 +704,29 @@ def test_final_config_structure(self, complete_config, additive_surgery_chain): assert mixer["mixers"]["sliding_window"]["window_size"] == 512 assert mixer["mixers"]["gdn"]["value_heads"] == 16 - def test_no_init_keys_in_result(self, complete_config, additive_surgery_chain): - """Verify no 'init' keys leak through.""" + def test_init_keys_preserved_for_planning(self, complete_config, additive_surgery_chain): + """Verify 'init' keys are preserved for plan_surgery to see. - def check_no_init(d, path=""): - if isinstance(d, dict): - assert "init" not in d, f"Found 'init' key at {path}" - for k, v in d.items(): - check_no_init(v, f"{path}.{k}") + The `init` field is metadata for weight initialization. It's preserved + through composition and only stripped when saving final output. + """ + from fast_llm_external_models.apriel2.conversion.config import strip_init_fields result = complete_config for i, surgery in enumerate(additive_surgery_chain): result = compose_configs(result, surgery) - check_no_init(result, f"step_{i+1}") + + # init should be in the composed config + mixer = result["decoder"]["block"]["mixer"] + if "mixers" in mixer: + has_init = any("init" in m for m in mixer["mixers"].values()) + assert has_init, "init should be preserved in composed config" + + # strip_init_fields removes them + stripped = strip_init_fields(result) + mixer = stripped["decoder"]["block"]["mixer"] + if "mixers" in mixer: + assert all("init" not in m for m in mixer["mixers"].values()) def test_full_torture_chain(self, complete_config, torture_surgery_chain): """Test the full 10-step torture chain produces valid configs.""" diff --git a/fast_llm_external_models/tests/test_apriel2/test_plan_composition_torture.py b/fast_llm_external_models/tests/test_apriel2/test_conversion_e2e.py similarity index 84% rename from fast_llm_external_models/tests/test_apriel2/test_plan_composition_torture.py rename to fast_llm_external_models/tests/test_apriel2/test_conversion_e2e.py index 76a77ccb6..09fb9fa13 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_plan_composition_torture.py +++ b/fast_llm_external_models/tests/test_apriel2/test_conversion_e2e.py @@ -1,6 +1,6 @@ -"""End-to-end torture test for plan composition. +"""test_conversion_e2e.py - End-to-end conversion integration tests. 
-This tests the FULL pipeline at every step of a surgery chain: +Tests the FULL pipeline at every step of a surgery chain: 1. Config composition produces valid configs 2. Plan building works for each surgery 3. Plan execution produces valid weights @@ -1083,66 +1083,6 @@ def mamba_config(self): }, } - def test_config_composition_identical_regardless_of_init_mode(self, base_config): - """Config composition produces same structure with init: transfer vs init: random.""" - # Surgery with init: transfer - surgery_transfer = { - "decoder": { - "block": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": { - "attention": {"type": "attention", "init": "transfer"}, - "swa": { - "type": "attention", - "init": "transfer", - "sliding_window": 512, - }, - }, - }, - }, - }, - } - - # Surgery with init: random - surgery_random = { - "decoder": { - "block": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": { - "attention": {"type": "attention", "init": "random"}, - "swa": { - "type": "attention", - "init": "random", - "sliding_window": 512, - }, - }, - }, - }, - }, - } - - # Compose configs - result_transfer = compose_configs(base_config, surgery_transfer) - result_random = compose_configs(base_config, surgery_random) - - # Both should produce identical structure (init is stripped) - assert result_transfer == result_random, ( - "Config composition should produce identical structure regardless of init mode" - ) - - # Verify the structure is correct - mixer = result_transfer["decoder"]["block"]["mixer"] - assert mixer["type"] == "stochastic" - assert "attention" in mixer["mixers"] - assert "swa" in mixer["mixers"] - # init should be stripped - assert "init" not in mixer["mixers"]["attention"] - assert "init" not in mixer["mixers"]["swa"] - def test_plan_surgery_random_succeeds_for_any_type_pair(self, mamba_config): """plan_surgery with init: random should succeed even for mamba -> attention.""" # This surgery changes mamba to attention with random init @@ -1313,8 +1253,8 @@ class TestMarkovianProperty: """ @pytest.fixture - def attention_config(self): - """Base config with attention.""" + def attention_config_dict(self): + """Base config dict with attention mixer for compose_configs tests.""" return { "model_type": "apriel2", "hidden_size": 256, @@ -1335,43 +1275,7 @@ def attention_config(self): }, } - @pytest.fixture - def stochastic_config(self): - """Config with stochastic mixer.""" - return { - "model_type": "apriel2", - "hidden_size": 256, - "vocab_size": 1000, - "decoder": { - "type": "fixed", - "num_blocks": 2, - "block": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": { - "attention": { - "type": "attention", - "heads": 8, - "head_groups": 4, - "head_size": 32, - }, - "swa": { - "type": "sliding_window", - "heads": 8, - "head_groups": 4, - "head_size": 32, - "window_size": 512, - }, - }, - }, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - } - - def test_different_paths_same_config_same_plan(self, attention_config): + def test_different_paths_same_config_same_plan(self, attention_config_dict): """Two different paths to the same config produce identical plans. 
Path A: attention -> stochastic{att, swa} @@ -1398,7 +1302,7 @@ def test_different_paths_same_config_same_plan(self, attention_config): }, }, } - config_a = compose_configs(attention_config, surgery_a) + config_a = compose_configs(attention_config_dict, surgery_a) # Path B: First add attention only, then add swa surgery_b1 = { @@ -1414,7 +1318,7 @@ def test_different_paths_same_config_same_plan(self, attention_config): }, }, } - intermediate_config = compose_configs(attention_config, surgery_b1) + intermediate_config = compose_configs(attention_config_dict, surgery_b1) surgery_b2 = { "decoder": { @@ -1469,7 +1373,7 @@ def test_different_paths_same_config_same_plan(self, attention_config): keys_b = set(str(k) for k in plan_from_b.mappings.keys()) assert keys_a == keys_b, "Plans from same config via different paths should be identical" - def test_init_in_source_config_does_not_affect_plan(self, attention_config): + def test_init_in_source_config_does_not_affect_plan(self, attention_config_dict): """Manually injecting init into source config doesn't change the plan. This tests that plan_surgery reads init from surgery, not source. @@ -1479,8 +1383,8 @@ def test_init_in_source_config_does_not_affect_plan(self, attention_config): import copy # Create two copies of the config - config_with_init = copy.deepcopy(attention_config) - config_without_init = copy.deepcopy(attention_config) + config_with_init = copy.deepcopy(attention_config_dict) + config_without_init = copy.deepcopy(attention_config_dict) # Manually inject init into one (bypassing compose_configs) config_with_init["decoder"]["block"]["mixer"]["init"] = "random" @@ -1510,232 +1414,6 @@ def test_init_in_source_config_does_not_affect_plan(self, attention_config): # Plans should be identical - source's init field is ignored assert keys_with == keys_without, "Plan should not depend on init in source config" - def test_associativity_of_surgery_composition(self, attention_config): - """Verify associativity: (A ∘ B) ∘ C == A ∘ (B ∘ C) for surgery specs. - - This tests that composing surgeries is associative, which is - equivalent to Markovianity for plan creation. - """ - surgery_a = { - "decoder": { - "block": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": { - "attention": {"type": "attention", "init": "transfer"}, - }, - }, - }, - }, - } - - surgery_b = { - "decoder": { - "block": { - "mixer": { - "mixers": { - "swa": { - "type": "sliding_window", - "init": "transfer", - "window_size": 512, - }, - }, - }, - }, - }, - } - - surgery_c = { - "decoder": { - "block": { - "mixer": { - "mixers": { - "gdn": { - "type": "gdn", - "init": "random", - "value_heads": 8, - "key_heads": 4, - "key_head_dim": 32, - "value_head_dim": 32, - "convolution_layer": {"kernel_size": 4}, - }, - }, - }, - }, - }, - } - - # Left association: ((attention_config ∘ A) ∘ B) ∘ C - left_1 = compose_configs(attention_config, surgery_a) - left_2 = compose_configs(left_1, surgery_b) - left_result = compose_configs(left_2, surgery_c) - - # Right association: (attention_config ∘ A) ∘ (B ∘ C) - # Note: B ∘ C is partial ∘ partial = deep merge of surgery specs - bc_merged = compose_configs(surgery_b, surgery_c) - right_1 = compose_configs(attention_config, surgery_a) - right_result = compose_configs(right_1, bc_merged) - - assert left_result == right_result, "Surgery composition should be associative" - - def test_complete_configs_have_no_init_fields(self, attention_config): - """Verify that compose_configs strips init from complete configs. 
- - This is the key invariant that enables Markovianity: - - Complete configs (states) have no init fields - - Surgery specs (transitions) have init fields - - Plans read init from surgery, not state - """ - surgery_with_init = { - "decoder": { - "block": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": { - "attention": {"type": "attention", "init": "transfer"}, - "swa": {"type": "sliding_window", "init": "random", "window_size": 512}, - }, - }, - }, - }, - } - - result = compose_configs(attention_config, surgery_with_init) - - # Recursively check for init fields - def has_init(obj): - if isinstance(obj, dict): - if "init" in obj: - return True - return any(has_init(v) for v in obj.values()) - if isinstance(obj, list): - return any(has_init(v) for v in obj) - return False - - assert not has_init(result), "Complete configs should have no init fields" - - def test_monoid_action_law_additive_surgeries(self): - """Monoid action law HOLDS for additive surgeries. - - Additive surgeries (no type: declaration) support: - apply(apply(s, t1), t2) == apply(s, t1 ∘ t2) - - This is because additive operations commute nicely: - "add {a}" then "add {b}" == "add {a, b}" - """ - # Start with stochastic (additive surgery target) - s = { - "model_type": "apriel2", - "hidden_size": 256, - "vocab_size": 1000, - "decoder": { - "type": "fixed", - "num_blocks": 2, - "block": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": { - "attention": {"type": "attention", "heads": 8, "head_groups": 4, "head_size": 32}, - }, - }, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - } - - # Additive surgeries (no type: declaration) - t1 = {"decoder": {"block": {"mixer": {"mixers": {"swa": {"type": "sliding_window", "window_size": 512}}}}}} - t2 = {"decoder": {"block": {"mixer": {"mixers": {"mamba": {"type": "mamba", "d_inner": 512}}}}}} - - # Path A: Sequential - s_prime = compose_configs(s, t1) - s_double_prime_A = compose_configs(s_prime, t2) - - # Path B: Composed - t1_t2 = compose_configs(t1, t2) - s_double_prime_B = compose_configs(s, t1_t2) - - assert s_double_prime_A == s_double_prime_B, "Monoid action law should hold for additive surgeries" - - def test_monoid_action_law_replacement_surgeries_fails(self): - """Monoid action law FAILS for replacement surgeries (by design). - - Replacement surgeries (type: stochastic declared) have: - apply(apply(s, t1), t2) != apply(s, t1 ∘ t2) - - This is FUNDAMENTAL, not a bug: - - Sequential: "set to {a}" then "set to {b}" β†’ {b} (second wins) - - Composed: merge({a}, {b}) = {a,b}, then apply β†’ {a,b} - - These are genuinely different semantics. The failure documents - the distinction between declarative composition (merge) and - operational composition (function application). 
- """ - s = { - "model_type": "apriel2", - "hidden_size": 256, - "vocab_size": 1000, - "decoder": { - "type": "fixed", - "num_blocks": 2, - "block": { - "mixer": {"type": "attention", "heads": 8, "head_groups": 4, "head_size": 32}, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - } - - # Replacement surgeries (both declare type: stochastic) - t1 = { - "decoder": { - "block": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": {"attention": {"type": "attention"}}, - } - } - } - } - t2 = { - "decoder": { - "block": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "swa", - "mixers": {"swa": {"type": "sliding_window", "window_size": 512}}, - } - } - } - } - - # Path A: Sequential (second replacement wins) - s_prime = compose_configs(s, t1) - s_double_prime_A = compose_configs(s_prime, t2) - - # Path B: Composed (declarations merged) - t1_t2 = compose_configs(t1, t2) - s_double_prime_B = compose_configs(s, t1_t2) - - # They should be DIFFERENT (law fails) - assert s_double_prime_A != s_double_prime_B, ( - "Monoid action law should FAIL for replacement surgeries" - ) - - # Verify the specific difference: - # Sequential: only swa (second replacement wins) - # Composed: both attention and swa (merged declarations) - mixers_A = set(s_double_prime_A["decoder"]["block"]["mixer"]["mixers"].keys()) - mixers_B = set(s_double_prime_B["decoder"]["block"]["mixer"]["mixers"].keys()) - - assert mixers_A == {"swa"}, "Sequential: second replacement wins" - assert mixers_B == {"attention", "swa"}, "Composed: declarations merged" - class TestCyclingSurgeryGeneration: """Tests for the cycling surgery generation functions. diff --git a/fast_llm_external_models/tests/test_apriel2/test_plan_execution.py b/fast_llm_external_models/tests/test_apriel2/test_plan_execution.py new file mode 100644 index 000000000..9a98ec13b --- /dev/null +++ b/fast_llm_external_models/tests/test_apriel2/test_plan_execution.py @@ -0,0 +1,597 @@ +"""test_plan_execution.py - Plan execution and algebraic composition laws. + +This module provides rigorous, parameterized tests for the mathematical properties +that the conversion system must satisfy. Each test class corresponds to one +algebraic structure, and each test method verifies one specific law. + +Conceptual Types +================ + +The conversion system operates on three conceptual types (all ``dict`` at runtime): + +- **S (State)**: Complete config without ``init`` fields +- **P (Partial Surgery)**: Incomplete config, may have ``init`` fields +- **T (Transition Spec)**: Complete config WITH ``init`` fields + +Algebraic Structures +==================== + +1. **Partial Surgeries (P)** form a **Monoid** under deep merge:: + + compose_configs : P Γ— P β†’ P + Identity: {} + Associativity: (p1 ∘ p2) ∘ p3 = p1 ∘ (p2 ∘ p3) + +2. **Surgeries act on States** to produce Transition Specs:: + + compose_configs : S Γ— P β†’ T + compose_configs : T Γ— P β†’ T + + Action law (additive surgeries): (s Β· p1) Β· p2 = s Β· (p1 ∘ p2) + +3. **Plans** form a **Category** with composition:: + + compose : Plan(Aβ†’B) Γ— Plan(Bβ†’C) β†’ Plan(Aβ†’C) + Associativity: (P1 ∘ P2) ∘ P3 = P1 ∘ (P2 ∘ P3) + +4. **plan_surgery is a Functor** from config pairs to plans:: + + plan_surgery : S Γ— T β†’ Plan + Functoriality: compose(plan(S,T1), plan(T1,T2)) ≑ plan(S,T2) + + This is semantic equivalence: both produce identical weights when executed. 
+ +Important Behaviors Tested +========================== + +- **init stripping**: Between surgery iterations, T β†’ S conversion via + ``strip_init_fields()`` ensures ``init: random`` applies only to the surgery + that introduces a component. + +- **Bias inheritance**: Per-layer bias settings propagate through surgery chains. + +- **Plan composition**: Composed plans produce identical weights to direct plans. + +Design Principles +================= + +- Each law gets ONE parameterized test, not multiple similar tests +- Fixtures provide diverse configs (with/without biases) +- Corner cases are covered via parameterization, not test proliferation +- Tests document the laws they verify in their docstrings +""" + +import pytest +import torch +from functools import reduce + +from fast_llm_external_models.apriel2.conversion import ( + compose, + compose_configs, + execute, + plan_surgery, + ExprPlan, + W, + Ref, + Concat, + Slice, + Init, +) + +# Import shared helper from conftest +from fast_llm_external_models.tests.test_apriel2.conftest import make_weights_for_config + + +# ============================================================================= +# Fixtures: Use shared fixtures from conftest.py where possible +# ============================================================================= +# - base_config_dict: Complete config without biases (Llama-style) +# - base_config_with_bias_dict: Complete config with QKV biases +# - additive_surgery_chain: [wrap_stochastic, add_sliding_window, add_gdn] +# ============================================================================= + + +# ============================================================================= +# Test: Plan Composition Associativity +# ============================================================================= + + +class TestPlanCompositionAssociativity: + """ + LAW: Plan composition is associative. + + (P₁ ∘ Pβ‚‚) ∘ P₃ = P₁ ∘ (Pβ‚‚ ∘ P₃) + + where ∘ denotes compose(P1, P2). + + This must hold for the AST structure, not just semantic equivalence. 
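For intuition, composing plans substitutes references with their definitions. A small sketch with the expression types used in these tests; the mapping in the comment is the expected shape, not an exhaustive spec::

    p1 = ExprPlan(mappings={W("b"): Ref(key=W("a"))})
    p2 = ExprPlan(mappings={W("c"): Ref(key=W("b"))})
    p12 = compose(p1, p2)   # roughly: {W("c"): Ref(key=W("a"))}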
+ """ + + @pytest.mark.parametrize("expr_type", ["ref_chain", "with_concat", "with_slice", "with_init"]) + def test_associativity(self, expr_type): + """Plan composition is associative for various expression types.""" + # Build three plans that can be composed + if expr_type == "ref_chain": + p1 = ExprPlan(mappings={W("b"): Ref(key=W("a"))}) + p2 = ExprPlan(mappings={W("c"): Ref(key=W("b"))}) + p3 = ExprPlan(mappings={W("d"): Ref(key=W("c"))}) + elif expr_type == "with_concat": + p1 = ExprPlan(mappings={W("x"): Ref(key=W("a")), W("y"): Ref(key=W("b"))}) + p2 = ExprPlan(mappings={W("xy"): Concat(exprs=(Ref(key=W("x")), Ref(key=W("y"))), dim=0)}) + p3 = ExprPlan(mappings={W("final"): Ref(key=W("xy"))}) + elif expr_type == "with_slice": + p1 = ExprPlan(mappings={W("full"): Ref(key=W("src"))}) + p2 = ExprPlan(mappings={W("part"): Slice(expr=Ref(key=W("full")), slices=((0, 5, None),))}) + p3 = ExprPlan(mappings={W("out"): Ref(key=W("part"))}) + elif expr_type == "with_init": + p1 = ExprPlan(mappings={W("x"): Ref(key=W("a"))}) + p2 = ExprPlan(mappings={W("y"): Concat(exprs=(Ref(key=W("x")), Init(shape=(5,), init_type="zeros")), dim=0)}) + p3 = ExprPlan(mappings={W("z"): Ref(key=W("y"))}) + + left = compose(compose(p1, p2), p3) + right = compose(p1, compose(p2, p3)) + + assert left.mappings == right.mappings, f"Associativity failed for {expr_type}" + + +# ============================================================================= +# Test: Functoriality of plan_surgery (THE CRITICAL PROPERTY) +# ============================================================================= + + +class TestPlanSurgeryFunctoriality: + """ + LAW: plan_surgery is functorial with respect to config composition. + + For a surgery chain P₁, Pβ‚‚, ..., Pβ‚™ applied to base state Sβ‚€:: + + T₁ = compose_configs(Sβ‚€, P₁) # S Γ— P β†’ T + Tβ‚‚ = compose_configs(T₁, Pβ‚‚) # T Γ— P β†’ T (no stripping!) + ... + Tβ‚™ = compose_configs(Tₙ₋₁, Pβ‚™) + + Plan functoriality says:: + + compose(plan(Sβ‚€,T₁), plan(T₁,Tβ‚‚), ...) ≑ plan(Sβ‚€, Tβ‚™) + + where ≑ denotes semantic equivalence (identical weights when executed). + + NOTE: This tests T Γ— P composition WITHOUT stripping between steps. + This differs from build_plan which strips (T β†’ S) between iterations. + Both patterns are valid: + + - Without stripping: init fields accumulate, testing plan composition purity + - With stripping: init consumed per-step, testing real usage (see + test_build_plan_strips_init_between_iterations) + + The functoriality law holds in both cases because plan composition + correctly substitutes Ref expressions with their definitions. + """ + + @pytest.mark.parametrize("chain_length", [1, 2, 3]) + @pytest.mark.parametrize("use_bias", [True, False]) + def test_functoriality( + self, + chain_length, + use_bias, + base_config_dict, + base_config_with_bias_dict, + additive_surgery_chain, + ): + """ + Composed incremental plans produce same weights as direct plan. + + Parameterized over: + - chain_length: Number of surgeries (1, 2, or 3) + - use_bias: Whether base config has biases + """ + base_config = base_config_with_bias_dict if use_bias else base_config_dict + surgeries = additive_surgery_chain[:chain_length] + + # Build config chain: Cβ‚€ β†’ C₁ β†’ ... 
β†’ Cβ‚™ + configs = [base_config] + for s in surgeries: + configs.append(compose_configs(configs[-1], s)) + + # Build incremental plans: Pβ‚– = plan_surgery(Cₖ₋₁, Cβ‚–) + plans = [plan_surgery(configs[i], configs[i+1]) for i in range(len(surgeries))] + + # Compose all incremental plans + composed_plan = reduce(compose, plans) + + # Build direct plan: plan_surgery(Cβ‚€, Cβ‚™) + direct_plan = plan_surgery(configs[0], configs[-1]) + + # Execute both on same weights + weights = make_weights_for_config(base_config) + composed_weights = execute(composed_plan, weights, seed=42) + direct_weights = execute(direct_plan, weights, seed=42) + + # Verify semantic equivalence + assert set(composed_weights.keys()) == set(direct_weights.keys()), \ + f"Key sets differ for chain_length={chain_length}, use_bias={use_bias}" + + for key in composed_weights: + assert torch.allclose(composed_weights[key], direct_weights[key], atol=1e-6), \ + f"Weight mismatch for {key} with chain_length={chain_length}, use_bias={use_bias}" + + @pytest.mark.parametrize("split_point", [1, 2]) + def test_arbitrary_grouping( + self, + split_point, + base_config_with_bias_dict, + additive_surgery_chain, + ): + """ + Any grouping of surgery chain produces same result. + + For surgeries [S₁, Sβ‚‚, S₃], tests that: + - compose(P₁, compose(Pβ‚‚, P₃)) + - compose(compose(P₁, Pβ‚‚), P₃) + - plan_surgery(Cβ‚€, C₃) + + all produce identical weights. + """ + surgeries = additive_surgery_chain + + # Build config chain + configs = [base_config_with_bias_dict] + for s in surgeries: + configs.append(compose_configs(configs[-1], s)) + + # Build incremental plans + plans = [plan_surgery(configs[i], configs[i+1]) for i in range(3)] + + # Different groupings + left_grouped = compose(compose(plans[0], plans[1]), plans[2]) + right_grouped = compose(plans[0], compose(plans[1], plans[2])) + direct = plan_surgery(configs[0], configs[-1]) + + # Execute all + weights = make_weights_for_config(base_config_with_bias_dict) + results = { + "left": execute(left_grouped, weights, seed=42), + "right": execute(right_grouped, weights, seed=42), + "direct": execute(direct, weights, seed=42), + } + + # All must match + keys = set(results["left"].keys()) + assert keys == set(results["right"].keys()) == set(results["direct"].keys()) + + for key in keys: + assert torch.allclose(results["left"][key], results["right"][key], atol=1e-6) + assert torch.allclose(results["left"][key], results["direct"][key], atol=1e-6) + + +# ============================================================================= +# Test: Bias Inheritance Preservation (Regression for the specific bug) +# ============================================================================= + + +class TestBiasInheritancePreservation: + """ + PROPERTY: Per-layer bias settings must be preserved through surgery chains. + + When a surgery spec does not mention bias settings, they must be inherited + from the source config. This is the specific failure mode of the build_plan + bug: passing partial surgery specs to plan_surgery lost inherited fields. + + This test verifies the SYMPTOM (missing biases) rather than the LAW + (functoriality). It's kept as a focused regression test. 
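A rough sketch of the inheritance at the config level, assuming the ``query_layer.bias.enabled`` layout used by ``base_config_with_bias_dict`` in conftest.py and a wrap-in-stochastic surgery like the one in ``additive_surgery_chain``::

    wrap = {"decoder": {"block": {"mixer": {
        "type": "stochastic",
        "main_mixer_name": "attention",
        "mixers": {"attention": {"init": "transfer"}},   # bias settings not mentioned
    }}}}
    t = compose_configs(base_config_with_bias_dict, wrap)
    attn = t["decoder"]["block"]["mixer"]["mixers"]["attention"]
    assert attn["query_layer"]["bias"]["enabled"] is True   # inherited from the source config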
+ """ + + @pytest.mark.parametrize("num_surgeries", [1, 2, 3]) + def test_qkv_biases_preserved_through_chain( + self, + num_surgeries, + base_config_with_bias_dict, + additive_surgery_chain, + ): + """QKV biases (enabled in source) appear in plan after N surgeries.""" + surgeries = additive_surgery_chain[:num_surgeries] + + # Build config and plan chain + configs = [base_config_with_bias_dict] + for s in surgeries: + configs.append(compose_configs(configs[-1], s)) + + plans = [plan_surgery(configs[i], configs[i+1]) for i in range(num_surgeries)] + final_plan = reduce(compose, plans) if len(plans) > 1 else plans[0] + + # Check bias keys present + target_keys = {str(k) for k in final_plan.target_keys()} + + assert any("q_proj.bias" in k for k in target_keys), \ + f"q_proj.bias missing after {num_surgeries} surgeries" + assert any("k_proj.bias" in k for k in target_keys), \ + f"k_proj.bias missing after {num_surgeries} surgeries" + assert any("v_proj.bias" in k for k in target_keys), \ + f"v_proj.bias missing after {num_surgeries} surgeries" + # O bias should NOT be present (disabled in source) + assert not any("o_proj.bias" in k for k in target_keys), \ + f"o_proj.bias should not be present (disabled in source)" + + def test_bias_values_preserved( + self, + base_config_with_bias_dict, + additive_surgery_chain, + ): + """Bias tensor values are correctly transferred, not just keys.""" + surgery = additive_surgery_chain[0] # wrap_stochastic + c1 = compose_configs(base_config_with_bias_dict, surgery) + plan = plan_surgery(base_config_with_bias_dict, c1) + + weights = make_weights_for_config(base_config_with_bias_dict) + result = execute(plan, weights, seed=42) + + # Verify values match (not just that keys exist) + for i in range(base_config_with_bias_dict["decoder"]["num_blocks"]): + src_key = W(f"model.decoder.blocks.{i}.mixer.q_proj.bias") + dst_key = W(f"model.decoder.blocks.{i}.mixer.mixers.attention.q_proj.bias") + + assert dst_key in result, f"Missing {dst_key}" + assert torch.allclose(weights[src_key], result[dst_key]), \ + f"Bias values differ for block {i}" + + +# ============================================================================= +# Test: build_plan Integration (Regression test for convert.py) +# ============================================================================= + + +class TestBuildPlanIntegration: + """ + REGRESSION: build_plan must compose configs before calling plan_surgery. + + The bug was: + plan_surgery(current_config, surgery_config) # WRONG: partial + + Should be: + target = compose_configs(current_config, surgery_config) + plan_surgery(current_config, target) # CORRECT: complete + + This test verifies the fix in convert.py's build_plan function. 
+ """ + + @pytest.mark.parametrize("num_surgeries", [1, 2]) + def test_build_plan_preserves_inherited_fields( + self, + num_surgeries, + base_config_with_bias_dict, + additive_surgery_chain, + ): + """build_plan produces plans with inherited bias mappings.""" + from fast_llm_external_models.apriel2.convert import build_plan + + surgeries = additive_surgery_chain[:num_surgeries] + + plan, final_config = build_plan( + base_config_with_bias_dict, + surgeries, + source_format="apriel2", + ) + + # Verify inherited biases in config + if num_surgeries >= 1: + attn = final_config["decoder"]["block"]["mixer"]["mixers"]["attention"] + assert attn.get("query_layer", {}).get("bias", {}).get("enabled") is True + + # Verify bias mappings in plan + target_keys = {str(k) for k in plan.target_keys()} + assert any("q_proj.bias" in k for k in target_keys), \ + f"build_plan with {num_surgeries} surgeries missing q_proj.bias" + + +# ============================================================================= +# Test: init Field Preservation (Critical for random initialization) +# ============================================================================= + + +class TestInitFieldPreservation: + """ + PROPERTY: The `init` field must be visible to plan_surgery. + + The `init` field controls weight initialization mode: + - `init: transfer` β†’ use weight transfer/conversion + - `init: random` β†’ use random initialization + + compose_configs must preserve `init` so plan_surgery can see it. + Stripping happens only at final output (when saving to disk). + """ + + def test_init_random_produces_init_expression(self, base_config_with_bias_dict): + """Surgery with init: random produces Init expressions in plan.""" + surgery = { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + "gdn": {"type": "gdn", "init": "random", "convolution_layer": {"kernel_size": 4}}, + }, + }, + }, + }, + } + + target = compose_configs(base_config_with_bias_dict, surgery) + plan = plan_surgery(base_config_with_bias_dict, target) + + # Check that GDN weights use Init expressions (random init) + target_keys = {str(k) for k in plan.target_keys()} + gdn_keys = [k for k in target_keys if "gdn" in k.lower()] + + assert len(gdn_keys) > 0, "No GDN keys in plan" + + # Verify at least one GDN weight uses Init (random initialization) + has_init_expr = False + for key in plan.target_keys(): + if "gdn" in str(key).lower(): + expr = plan.mappings[key] + if isinstance(expr, Init): + has_init_expr = True + break + # Also check inside Concat/other composite expressions + if hasattr(expr, 'exprs'): + for sub in expr.exprs: + if isinstance(sub, Init): + has_init_expr = True + break + + assert has_init_expr, "init: random should produce Init expressions for GDN weights" + + def test_init_transfer_produces_ref_expression(self, base_config_with_bias_dict): + """Surgery with init: transfer produces Ref expressions (weight transfer).""" + surgery = { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + }, + }, + }, + }, + } + + target = compose_configs(base_config_with_bias_dict, surgery) + plan = plan_surgery(base_config_with_bias_dict, target) + + # Check that attention weights use Ref expressions (transfer) + has_ref = False + for key in plan.target_keys(): + if "attention" in str(key) and "q_proj.weight" in str(key): + expr = plan.mappings[key] + if 
isinstance(expr, Ref): + has_ref = True + break + + assert has_ref, "init: transfer should produce Ref expressions for attention weights" + + def test_build_plan_respects_init_random(self, base_config_with_bias_dict): + """build_plan correctly uses init: random for weight initialization.""" + from fast_llm_external_models.apriel2.convert import build_plan + + # Mamba requires many config fields for random init + surgery = { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + "mamba": { + "type": "mamba", + "init": "random", + "d_inner": 512, + "d_state": 16, + "dt_rank": 16, + "d_xb": 64, + "d_conv": 4, + "repeat_kv_before_conv": False, + "conv_bias": True, + "dt_proj_bias": True, + "dt_min": 0.001, + "dt_max": 0.1, + "dt_init_floor": 1e-4, + }, + }, + }, + }, + }, + } + + plan, final_config = build_plan( + base_config_with_bias_dict, + [surgery], + source_format="apriel2", + ) + + # Verify mamba weights use Init (random init) + has_mamba_init = False + for key in plan.target_keys(): + key_str = str(key) + if "mamba" in key_str: + expr = plan.mappings[key] + if isinstance(expr, Init): + has_mamba_init = True + break + + assert has_mamba_init, "build_plan should use Init for init: random components" + + def test_build_plan_strips_init_between_iterations(self, base_config_with_bias_dict): + """build_plan strips init between iterations (T β†’ S conversion). + + This tests that the intermediate state between surgeries has no init fields. + The composed plan will show Init expressions because plan composition + substitutes Ref β†’ Init, but the semantics are correct: GDN is initialized + once (in surgery 1), not re-randomized in surgery 2. + """ + from fast_llm_external_models.apriel2.conversion import ( + compose_configs, strip_init_fields, plan_surgery, compose + ) + + # Surgery 1: Add GDN with random init + surgery1 = { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + "gdn": { + "type": "gdn", + "init": "random", + "convolution_layer": {"kernel_size": 4}, + }, + }, + }, + }, + }, + } + + # Surgery 2: Add sliding window (doesn't mention GDN) + surgery2 = { + "decoder": { + "block": { + "mixer": { + "mixers": { + "sliding_window": {"init": "transfer", "window_size": 512}, + }, + }, + }, + }, + } + + # Simulate build_plan's iteration loop + s0 = base_config_with_bias_dict + + # Iteration 1 + t1 = compose_configs(s0, surgery1) + assert t1["decoder"]["block"]["mixer"]["mixers"]["gdn"].get("init") == "random" + s1 = strip_init_fields(t1) + assert s1["decoder"]["block"]["mixer"]["mixers"]["gdn"].get("init") is None + + # Iteration 2: s1 has no init for GDN + t2 = compose_configs(s1, surgery2) + assert t2["decoder"]["block"]["mixer"]["mixers"]["gdn"].get("init") is None, \ + "GDN should have no init in T2 (wasn't in surgery2, stripped from s1)" + + # plan_surgery(s1, t2) should use Ref for GDN (transfer, not random) + plan2 = plan_surgery(s1, t2) + gdn_uses_ref = False + for key in plan2.target_keys(): + if "gdn" in str(key): + expr = plan2.mappings[key] + if isinstance(expr, Ref): + gdn_uses_ref = True + break + + assert gdn_uses_ref, "plan_surgery(s1, t2) should use Ref for GDN (transfer from s1)" From b6dd6dc563db08de16ba76c89f335cdb8a014818 Mon Sep 17 00:00:00 2001 From: Torsten Scholak Date: Tue, 16 Dec 2025 21:47:34 +0000 Subject: [PATCH 08/12] 
=?UTF-8?q?Fix=20O(n=C2=B2)=20tokenization=20and=20a?= =?UTF-8?q?dd=20Qwen2=20training=20examples?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace apply_chat_template_with_spans with tokenize_chat (O(n) token-level) - Add _mask_to_spans helper to convert boolean mask to loss masking spans - Fix chat template docs: entire assistant turn must be in {% generation %} - Add parameterized tests with exact expected tokens and trainable indices - Add prepare_tulu3.yaml and train_supernet_qwen2.yaml examples - Document performance tuning (~8k tokens/s, ~61GB memory, ~25h for 1B tokens) πŸ€– Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- .../data/preparator/gpt_memmap/prepare.py | 44 ++-- fast_llm/data/preprocessing/tokenizer.py | 77 +++---- .../apriel2/examples/prepare_tulu3.yaml | 103 ++++++++++ .../examples/train_supernet_qwen2.yaml | 193 ++++++++++++++++++ tests/data/test_tokenizer.py | 89 ++++++-- 5 files changed, 427 insertions(+), 79 deletions(-) create mode 100644 fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml create mode 100644 fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index f349b1979..a9beca42f 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -220,21 +220,25 @@ def _preprocessing_config(self) -> LanguageModelPreprocessingConfig: ) def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample: - all_spans = [] - if self._source_schema.has_conversation: - # Conversation format: apply chat template and compute loss masking spans - messages = sample[self._source_schema.messages] - text, loss_masking_spans = self._tokenizer.apply_chat_template_with_spans( - messages, - add_generation_prompt=self._source_schema.add_generation_prompt, + tokens, train_mask = self._tokenizer.tokenize_chat( + sample[self._source_schema.messages], + self._source_schema.add_generation_prompt, + data_type=self._data_type, + ) + return LanguageModelSample( + TokenSample(tokens, [len(tokens)]), + RangeSample(_mask_to_spans(train_mask), len(tokens)), + None, + None, + None, ) - all_spans.extend([(SpanType.loss_masking, span) for span in loss_masking_spans]) - else: - # Plain text format - text = sample[self._source_schema.text] - if self._source_schema.has_loss_masking_span and not self._source_schema.has_conversation: + # Text format: use the text-spans pipeline + text = sample[self._source_schema.text] + all_spans = [] + + if self._source_schema.has_loss_masking_span: # Spans are typically stored in the (begin, last) format. We convert to (begin, end) range format. 
loss_masking_spans = _sort_spans( (SpanType.loss_masking, (begin, last + 1)) @@ -495,3 +499,19 @@ def _get_nearest_split(cumsum: np.ndarray, value: float) -> int: if left == len(cumsum): return left.item() return left.item() + 1 if (value - cumsum[left]) / (cumsum[left + 1] - cumsum[left]) > 0.5 else left.item() + + +def _mask_to_spans(mask: list[bool]) -> list[tuple[int, int]]: + """Convert a boolean train mask to loss masking spans (where mask[i] == False).""" + spans = [] + start = None + for i, value in enumerate(mask): + if not value: + if start is None: + start = i + elif start is not None: + spans.append((start, i)) + start = None + if start is not None: + spans.append((start, len(mask))) + return spans diff --git a/fast_llm/data/preprocessing/tokenizer.py b/fast_llm/data/preprocessing/tokenizer.py index 372d8cd90..f3b5a51a8 100644 --- a/fast_llm/data/preprocessing/tokenizer.py +++ b/fast_llm/data/preprocessing/tokenizer.py @@ -242,32 +242,17 @@ def validate_chat_template(self) -> None: "Please use a tokenizer with generation markers in its chat template." ) - def apply_chat_template_with_spans( + def tokenize_chat( self, messages: list[dict[str, str]], - *, add_generation_prompt: bool = False, - ) -> tuple[str, list[tuple[int, int]]]: - """ - Apply the tokenizer's chat template to messages and compute loss masking spans. - - This method converts a list of messages (OpenAI/Tulu format) into formatted - text and computes character-level spans that should be MASKED (not trained on). - - Note: Call validate_chat_template() once before using this method to ensure - the tokenizer has a valid chat template with generation markers. - - Args: - messages: List of message dicts with 'role' and 'content' keys. - add_generation_prompt: Whether to add a generation prompt at the end. + begin: bool = True, + end: bool = True, + data_type: DataType = DataType.int64, + ) -> tuple["torch.Tensor", list[bool]]: + """Apply chat template and return (tokens, train_mask) where train_mask[i]=True means train on token i.""" + import torch - Returns: - Tuple of (formatted_text, loss_masking_spans) where loss_masking_spans - is a list of (start, end) character positions to MASK (not train on). 
- """ - if not messages: - return "", [] - # Get tokens and assistant mask result = self.tokenizer.apply_chat_template( messages, tokenize=True, @@ -275,40 +260,24 @@ def apply_chat_template_with_spans( return_dict=True, add_generation_prompt=add_generation_prompt, ) - tokens = result["input_ids"] train_mask = result["assistant_masks"] - # Get text for output - full_text = self.tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=add_generation_prompt, - ) + # Prepend BOS / append EOS if needed (avoid O(n) insert) + prepend_bos = begin and (not tokens or tokens[0] != self.bod_id) + append_eos = end and (not tokens or tokens[-1] != self.eod_id) + tokens = [self.bod_id] * prepend_bos + list(tokens) + [self.eod_id] * append_eos + train_mask = [False] * prepend_bos + [bool(m) for m in train_mask] + [False] * append_eos - # Convert token mask to character spans using detokenization - # We need spans for tokens where train_mask=0 (should be masked/not trained on) - loss_masking_spans = [] - current_span_start = None - - # Track character positions by decoding incrementally - char_positions = [0] - for i in range(len(tokens)): - decoded = self.tokenizer.decode(tokens[: i + 1]) - char_positions.append(len(decoded)) - - for i, is_train in enumerate(train_mask): - if not is_train: # This token should be masked - if current_span_start is None: - current_span_start = char_positions[i] - else: # This token should be trained on - if current_span_start is not None: - loss_masking_spans.append((current_span_start, char_positions[i])) - current_span_start = None - - # Close any open span - if current_span_start is not None: - loss_masking_spans.append((current_span_start, char_positions[-1])) - - return full_text, loss_masking_spans + if self._config.max_vocab_size is not None: + tokens = ( + torch.tensor( + tokens, + dtype=torch.int64 if len(self.tokenizer) > torch.iinfo(data_type.torch).max else data_type.torch, + ) + % self._config.max_vocab_size + ).to(data_type.torch) + else: + tokens = torch.tensor(tokens, dtype=data_type.torch) + return tokens, train_mask diff --git a/fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml b/fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml new file mode 100644 index 000000000..ba85c1aed --- /dev/null +++ b/fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml @@ -0,0 +1,103 @@ +# Dataset preparation config for Tulu 3 SFT mixture with Qwen2 tokenizer +# +# This config converts the Tulu 3 SFT dataset (conversation format) to +# Fast-LLM's memmap format, with automatic loss masking span computation +# to train only on assistant responses. +# +# ============================================================================= +# TOKENIZER SETUP (one-time) +# ============================================================================= +# +# The tokenizer must have a chat template with {% generation %} markers. +# Qwen2's default template doesn't have these, so we need to patch it. +# +# IMPORTANT: The entire assistant turn (opening tag + content + closing tag) +# must be inside the {% generation %} block. This ensures the model learns to +# produce the full assistant response including special tokens like <|im_end|>. 
+# Reference: https://huggingface.co/HuggingFaceTB/SmolLM3-3B/blob/main/chat_template.jinja +# +# Run this Python script to create a patched tokenizer: +# +# from transformers import AutoTokenizer +# +# tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") +# +# # Patch chat template: wrap ENTIRE assistant turn in generation markers +# tokenizer.chat_template = '''{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system +# You are a helpful assistant.<|im_end|> +# ' }}{% endif %}{% if message['role'] == 'assistant' %}{% generation %}{{ '<|im_start|>assistant +# ' + message['content'] + '<|im_end|> +# ' }}{% endgeneration %}{% else %}{{ '<|im_start|>' + message['role'] + ' +# ' + message['content'] + '<|im_end|> +# ' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant +# ' }}{% endif %}''' +# +# tokenizer.save_pretrained("/path/to/qwen2-instruct-with-markers") +# +# ============================================================================= +# DATA PREPARATION +# ============================================================================= +# +# Small dataset (for testing): +# +# fast-llm prepare gpt_memmap \ +# -c fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml \ +# dataset.split=train[:1000] \ +# output_path=/path/to/tulu3-prepared-small +# +# Full dataset (~939K samples, ~6 minutes): +# +# fast-llm prepare gpt_memmap \ +# -c fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml +# +# ============================================================================= +# VERIFICATION +# ============================================================================= +# +# To verify the prepared dataset has loss masking spans: +# +# import pathlib +# from fast_llm.data.dataset.memmap import MemmapDataset +# from fast_llm.data.sample.language_model import LanguageModelSample +# from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig +# +# dataset = MemmapDataset[LanguageModelSample]( +# 'tulu3', +# pathlib.Path('/path/to/tulu3-prepared/shard_0_0.fast_llm_dataset'), +# LanguageModelPreprocessingConfig(use_loss_masking_spans=True) +# ) +# +# doc = dataset.get_document(0) +# print(f'Tokens: {len(doc.tokens.tokens)}') +# print(f'Loss masking spans: {doc.loss_masking_spans.ranges}') +# +# ============================================================================= + +# Dataset configuration +dataset: + # Tulu 3 SFT mixture from AllenAI + path: allenai/tulu-3-sft-mixture + split: train + + # Source schema for conversation format + source_schema: + # Use conversation type (vs default "text" type) + type: conversation + + # Column containing the messages list + messages: messages + +# Tokenizer configuration +# IMPORTANT: Must use a tokenizer with {% generation %} markers in its chat template. +# See instructions above to create a patched Qwen2 tokenizer. 
+tokenizer: + path: /path/to/qwen2-instruct-with-markers + # Qwen2 doesn't have a BOS token by default, use <|endoftext|> as BOS + bos_token: "<|endoftext|>" + +# Output configuration +output_path: /path/to/tulu3-prepared + +# Processing configuration +num_workers: 8 +documents_per_shard: 100000 diff --git a/fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml b/fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml new file mode 100644 index 000000000..5b190955f --- /dev/null +++ b/fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml @@ -0,0 +1,193 @@ +# Training config for Qwen2-based Apriel2 stochastic supernet on Tulu 3 SFT data +# +# This config trains a stochastic supernet where each layer can sample from +# multiple mixer types (attention, sliding window, gated delta net, KDA). +# Only the mixer weights are trained; all other weights are frozen. +# Activation-level distillation from a teacher model guides the training. +# +# ============================================================================= +# PREREQUISITES +# ============================================================================= +# +# 1. TOKENIZER SETUP +# +# Qwen2's default chat template doesn't have generation markers needed for +# loss masking. Create a patched tokenizer following the SmolLM3 pattern: +# https://huggingface.co/HuggingFaceTB/SmolLM3-3B/blob/main/chat_template.jinja +# +# IMPORTANT: The ENTIRE assistant turn (opening tag + content + closing tag) +# must be inside {% generation %}...{% endgeneration %} markers. +# +# from transformers import AutoTokenizer +# tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") +# # Wrap entire assistant turn in generation markers (NOT just content!) +# tokenizer.chat_template = '''{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system +# You are a helpful assistant.<|im_end|> +# ' }}{% endif %}{% if message['role'] == 'assistant' %}{% generation %}{{ '<|im_start|>assistant +# ' + message['content'] + '<|im_end|> +# ' }}{% endgeneration %}{% else %}{{ '<|im_start|>' + message['role'] + ' +# ' + message['content'] + '<|im_end|> +# ' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant +# ' }}{% endif %}''' +# tokenizer.save_pretrained("/path/to/qwen2-instruct-with-markers") +# +# 2. PREPARE TULU 3 DATASET +# +# Small dataset (for testing): +# +# fast-llm prepare gpt_memmap \ +# -c fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml \ +# tokenizer.path=/path/to/qwen2-instruct-with-markers \ +# dataset.split=train[:1000] \ +# output_path=/path/to/tulu3-prepared-small +# +# Full dataset (~939K samples, ~6 minutes): +# +# fast-llm prepare gpt_memmap \ +# -c fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml \ +# tokenizer.path=/path/to/qwen2-instruct-with-markers \ +# output_path=/path/to/tulu3-prepared +# +# 3. CONVERT QWEN2 TO APRIEL2 SUPERNET (student model) +# +# This creates a stochastic supernet with multiple mixer types per layer: +# +# python fast_llm_external_models/apriel2/convert.py \ +# Qwen/Qwen2.5-0.5B-Instruct \ +# /path/to/qwen2-supernet \ +# --surgery fast_llm_external_models/apriel2/examples/stochastic_supernet.yaml +# +# 4. CONVERT QWEN2 TO APRIEL2 (teacher model) +# +# The teacher is the original model without surgery, used for distillation: +# +# python fast_llm_external_models/apriel2/convert.py \ +# Qwen/Qwen2.5-0.5B-Instruct \ +# /path/to/qwen2-teacher +# +# 5. 
RUN TRAINING +# +# Update paths below and run: +# +# fast-llm train gpt \ +# -c fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml +# +# For long runs, use nohup: +# +# nohup fast-llm train gpt \ +# -c fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml \ +# > training.log 2>&1 & +# tail -f training.log +# +# ============================================================================= +# PERFORMANCE TUNING +# ============================================================================= +# +# Default config uses seq=4096, micro_batch=2, batch=16 which gives: +# - ~8k tokens/s/gpu throughput +# - ~61GB GPU memory usage +# - ~25 hours for 1B tokens on single GPU +# +# Adjust batch settings based on your GPU memory: +# - Reduce micro_batch_size if OOM +# - Increase micro_batch_size/batch_size if memory available +# +# ============================================================================= +# OUTPUT +# ============================================================================= +# +# Checkpoints: /path/to/qwen2-supernet-trained/checkpoints/{iteration}/ +# Exports: /path/to/qwen2-supernet-trained/export/apriel2_text/{iteration}/ +# +# ============================================================================= + +# Load pretrained model (Qwen2 converted to Apriel2 supernet) +pretrained: + path: /path/to/qwen2-supernet + format: apriel2_text + model_weights: true + load_config: model + +# Model config +model: + base_model: + # Freeze all components except the mixer + decoder: + block: + mlp: + lr_scale: 0.0 # Freeze MLP + normalization: + lr_scale: 0.0 # Freeze layer norms + # Activation-level distillation from teacher + distillation_model: teacher + activation_distillation_factor: 0.8 + embeddings: + lr_scale: 0.0 # Freeze word embeddings + head: + lr_scale: 0.0 # Freeze output head + cross_entropy_implementation: torch + multi_stage: + zero_stage: 2 + distributed: + compute_dtype: bf16 + seed: 42 + +# Teacher model for activation-level distillation +reference_models: + teacher: + model: + type: gpt + pretrained: + path: /path/to/qwen2-teacher + format: apriel2_text + model_weights: true + load_config: model + +# Batch configuration (tuned for ~61GB GPU memory, ~8k tokens/s) +batch: + sequence_length: 4096 + micro_batch_size: 2 + batch_size: 16 + +# Data configuration (prepared Tulu 3 dataset) +data: + datasets: + training: + type: file + path: /path/to/tulu3-prepared/fast_llm_config.yaml + +# Optimizer configuration +optimizer: + learning_rate: + base: 1.0e-05 + decay_style: cosine + warmup_iterations: 100 + decay_iterations: 10000 + minimum: 1.0e-06 + weight_decay: 0.1 + beta_1: 0.9 + beta_2: 0.95 + +# Training configuration +# At seq=4096, batch=16: ~65k tokens/iter, ~280 iters/hour +# 10000 iters β‰ˆ 650M tokens β‰ˆ 35 hours +training: + train_iters: 10000 + num_workers: 4 + logs: + interval: 10 + checkpoint: + interval: 280 # ~hourly + export: + interval: 280 # ~hourly (useful for development/testing during training) + format: apriel2_text + test_iters: 0 + evaluators: {} + # Weights & Biases configuration (optional, uncomment to enable) + # wandb: + # entity_name: your-entity + # project_name: your-project + +# Experiment directory +run: + experiment_dir: /path/to/qwen2-supernet-trained diff --git a/tests/data/test_tokenizer.py b/tests/data/test_tokenizer.py index 4b8f45d8d..97f16c6d6 100644 --- a/tests/data/test_tokenizer.py +++ b/tests/data/test_tokenizer.py @@ -61,10 +61,13 @@ def test_validate_chat_template_with_markers(common_tokenizer): 
common_tokenizer.validate_chat_template() +# Realistic chat template following HF conventions (e.g., SmolLM3): +# The generation block includes the full assistant turn: opening tag, content, and closing tag. +# This ensures the model learns to emit the closing tag. CHAT_TEMPLATE = ( "{% for message in messages %}" "{% if message.role == 'assistant' %}" - "{% generation %}{{ message.content }}{% endgeneration %}" + "{% generation %}{{ message.content }}{% endgeneration %}" "{% else %}" "<{{ message.role }}>{{ message.content }}" "{% endif %}" @@ -73,24 +76,84 @@ def test_validate_chat_template_with_markers(common_tokenizer): @pytest.mark.parametrize( - ("messages", "expected_text", "expected_spans"), + ("messages", "expected_tokens", "expected_trainable_indices"), ( - ([], "", []), + # Single turn: full assistant turn (Hello) is trainable ( [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello"}], - "HiHello", - [(0, 26), (31, 43)], + [49152, 27, 789, 29, 16946, 750, 789, 2293, 17822, 29, 7371, 750, 17822, 29, 49152], + [7, 8, 9, 10, 11, 12, 13], ), + # Multi-turn: both assistant turns are fully trainable ( - [{"role": "user", "content": "A"}, {"role": "assistant", "content": "B"}, {"role": "user", "content": "C"}, {"role": "assistant", "content": "D"}], - "ABCD", - [(0, 25), (26, 63), (64, 76)], + [ + {"role": "user", "content": "A"}, + {"role": "assistant", "content": "B"}, + {"role": "user", "content": "C"}, + {"role": "assistant", "content": "D"}, + ], + [49152, 27, 789, 29, 32, 750, 789, 2293, 17822, 29, 33, 750, 17822, 2293, 789, 29, 34, 750, 789, 2293, 17822, 29, 35, 750, 17822, 29, 49152], + [7, 8, 9, 10, 11, 12, 13, 19, 20, 21, 22, 23, 24, 25], + ), + # System + user + assistant: full assistant turn trainable + ( + [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello"}, + ], + [49152, 27, 3144, 29, 5815, 1139, 44569, 6928, 3144, 2293, 789, 29, 16946, 750, 789, 2293, 17822, 29, 7371, 750, 17822, 29, 49152], + [15, 16, 17, 18, 19, 20, 21], + ), + # User only: no trainable tokens + ( + [{"role": "user", "content": "Hi"}], + [49152, 27, 789, 29, 16946, 750, 789, 29, 49152], + [], + ), + # Long multi-turn (85 tokens, 3 assistant responses with tags, tests span machinery) + ( + [ + {"role": "system", "content": "You are a helpful assistant that answers questions."}, + {"role": "user", "content": "What is the capital of France?"}, + {"role": "assistant", "content": "The capital of France is Paris."}, + {"role": "user", "content": "What about Germany?"}, + {"role": "assistant", "content": "The capital of Germany is Berlin."}, + {"role": "user", "content": "And Italy?"}, + {"role": "assistant", "content": "The capital of Italy is Rome."}, + ], + [49152, 27, 3144, 29, 5815, 1139, 373, 44569, 2424, 11886, 954, 15737, 14516, 6928, 3144, 2293, 789, 29, 13938, 438, 331, 25016, 457, 12409, 562, 35838, 789, 2293, 17822, 29, 2111, 25016, 457, 12409, 562, 438, 4235, 280, 6928, 17822, 2293, 789, 29, 13938, 5028, 759, 42226, 35838, 789, 2293, 17822, 29, 2111, 25016, 457, 759, 42226, 438, 29784, 3556, 6928, 17822, 2293, 789, 29, 1996, 4413, 3326, 35838, 789, 2293, 17822, 29, 2111, 25016, 457, 4413, 3326, 438, 613, 1361, 6928, 17822, 29, 49152], + list(range(27, 41)) + list(range(49, 63)) + list(range(70, 84)), ), ), ) -def test_apply_chat_template_with_spans(common_tokenizer, messages, expected_text, expected_spans): - """Chat template produces correct text and masking spans.""" +def 
test_tokenize_chat(common_tokenizer, messages, expected_tokens, expected_trainable_indices): common_tokenizer.tokenizer.chat_template = CHAT_TEMPLATE - text, spans = common_tokenizer.apply_chat_template_with_spans(messages) - Assert.eq(text, expected_text) - Assert.eq(spans, expected_spans) + tokens, train_mask = common_tokenizer.tokenize_chat(messages) + Assert.eq(tokens.tolist(), expected_tokens) + Assert.eq([i for i, m in enumerate(train_mask) if m], expected_trainable_indices) + + +@pytest.mark.parametrize( + ("train_mask", "expected_loss_spans"), + ( + # All masked (no trainable tokens) + ([False, False, False], [(0, 3)]), + # All trainable (no spans) + ([True, True, True], []), + # Single trainable at start + ([True, False, False], [(1, 3)]), + # Single trainable at end + ([False, False, True], [(0, 2)]), + # Single trainable in middle + ([False, True, False], [(0, 1), (2, 3)]), + # Multiple trainable regions (simulates multi-turn conversation) + ([False, False, True, True, False, False, True, True, True, False], [(0, 2), (4, 6), (9, 10)]), + # Alternating + ([False, True, False, True, False], [(0, 1), (2, 3), (4, 5)]), + ), +) +def test_mask_to_spans(train_mask, expected_loss_spans): + from fast_llm.data.preparator.gpt_memmap.prepare import _mask_to_spans + + Assert.eq(_mask_to_spans(train_mask), expected_loss_spans) From f61a6d1088d1124afce4e0cd05ef4396134d7b77 Mon Sep 17 00:00:00 2001 From: Torsten Scholak Date: Tue, 16 Dec 2025 21:53:16 +0000 Subject: [PATCH 09/12] Improve Apriel2 conversion config composition and documentation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Refactor config.py with clearer algebraic structure documentation - Document State (S), Partial Surgery (P), and Transition Spec (T) types - Clarify monoid structure and action laws for config composition - Update activation_distillation_factor from 0.1 to 0.8 in small example πŸ€– Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- .../apriel2/conversion/__init__.py | 159 +++++++----- .../apriel2/conversion/config.py | 233 ++++++++++++------ .../apriel2/conversion/converters.py | 73 ++++-- fast_llm_external_models/apriel2/convert.py | 17 +- .../examples/train_supernet_small.yaml | 2 +- 5 files changed, 323 insertions(+), 161 deletions(-) diff --git a/fast_llm_external_models/apriel2/conversion/__init__.py b/fast_llm_external_models/apriel2/conversion/__init__.py index 60fc0ef0a..c6bad6626 100644 --- a/fast_llm_external_models/apriel2/conversion/__init__.py +++ b/fast_llm_external_models/apriel2/conversion/__init__.py @@ -1,86 +1,122 @@ """Weight conversion system for Apriel2 models. -Architecture Overview -===================== +Overview +======== -This package implements a declarative weight transformation system with two -orthogonal concerns: +This package implements a declarative weight transformation system. The core +abstraction separates config composition (structural) from plan execution (weights). -1. **Config Composition** - Structural transformations of model configs -2. 
**Plan Building & Execution** - Weight transformations between configs +Conceptual Types +================ -These concerns are intentionally separated: -- Config composition determines WHAT the target architecture looks like -- Plan building determines HOW weights are transformed to match -- The `init` field bridges them: it's config metadata consumed by the plan builder +All configs are ``dict``, but we distinguish three conceptual types: -Key Design Decisions -==================== +**State (S)** - A complete model config without ``init`` fields. + What you load from disk or save after conversion. -**Declarative Plans** - Plans are DATA (JSON-serializable expressions), not functions. This enables: - - Inspection and debugging of transformations - - Serialization for distributed execution - - Composition via substitution rather than function composition - -**Separation of Config and Weights** - The `init` field in surgery specs controls weight handling (transfer vs random) - but does NOT affect config composition. Config composition is purely structural. - After composition, `init` fields are stripped from complete configs. - -**Composition Semantics** - Surgery specs use declarative (merge) composition, not operational (function) - composition. For "additive" surgeries (modifying existing structure), the - monoid action law holds. For "replacement" surgeries (defining complete new - structure), sequential application differs from composed application by design. - -**Cross-Type Derivation** - When converting between mixer types (e.g., attention β†’ mamba), geometric - parameters are derived where possible: - - attention.heads β†’ mamba dimensions (MIL conversion) - - attention.heads β†’ gdn heads (DIL conversion) +**Partial Surgery (P)** - An incomplete config specifying changes. + May contain ``init`` fields (``transfer`` or ``random``). -Module Structure -================ +**Transition Spec (T)** - A complete config WITH ``init`` fields. + The result of applying surgery to a state. Describes both target + structure and weight initialization mode. + +Algebraic Structure +=================== + +**Monoid**: Partial surgeries compose via deep merge:: + + compose_configs : P Γ— P β†’ P + +**Action**: Surgeries act on states to produce transition specs:: + + compose_configs : S Γ— P β†’ T + compose_configs : T Γ— P β†’ T + +**Extraction**: Strip init to get a state:: + + strip_init_fields : T β†’ S + +**Planning**: Build weight transformation from source state + transition spec:: + + plan_surgery : S Γ— T β†’ Plan + +The ``init`` Field +================== + +The ``init`` field in surgeries specifies weight initialization: -- `config.py` - Config composition (compose_configs, apply_surgery) -- `converters.py` - Plan builders (plan_surgery, plan_mil_attention_to_mamba, etc.) -- `expr.py` - Expression types and plan class (Ref, Slice, Concat, Init, ExprPlan) -- `executor.py` - Plan execution (StreamingExecutor, execute) -- `io.py` - Streaming I/O (SafetensorLoader, ShardedSafetensorWriter) -- `llava/` - Source-specific converter for Llava β†’ Apriel2 +- ``init: transfer`` β†’ transfer/convert weights from source +- ``init: random`` β†’ randomly initialize weights -Example Usage +This field is preserved through ``compose_configs`` so ``plan_surgery`` can read it. +Use ``strip_init_fields`` before saving configs to disk. + +Typical Usage ============= +:: + from fast_llm_external_models.apriel2.conversion import ( compose_configs, plan_surgery, + strip_init_fields, execute, ) - # 1. 
Compose configs to get target architecture - target_config = compose_configs(source_config, surgery_spec) + # Load source state + source_state = load_config(...) # S - # 2. Build plan for weight transformation - plan = plan_surgery(source_config, target_config) + # Apply surgery + surgery = {"decoder": {"block": {"mixer": {"type": "gdn", "init": "random"}}}} # P + transition = compose_configs(source_state, surgery) # T - # 3. Execute plan to transform weights - target_weights = execute(plan, source_weights, seed=42) + # Build and execute plan + plan = plan_surgery(source_state, transition) + weights = execute(plan, source_weights, seed=42) -For streaming I/O with large models: + # Save (strip init first) + target_state = strip_init_fields(transition) # S + save_config(target_state) - from fast_llm_external_models.apriel2.conversion import ( - StreamingExecutor, - SafetensorLoader, - ShardedSafetensorWriter, - ) +For chained surgeries:: + + current_state = source_state # S + current_plan = identity_plan + + for surgery in surgery_chain: # each P + transition = compose_configs(current_state, surgery) # T + plan = plan_surgery(current_state, transition) + current_plan = compose(current_plan, plan) + current_state = strip_init_fields(transition) # S <- IMPORTANT! + +**Note**: The ``strip_init_fields`` call is critical. It ensures that ``init: random`` +applies only to the surgery that introduces a component. Without stripping, subsequent +surgeries would re-randomize existing components. See ``config.py`` docstring for details. + +Key Design Decisions +==================== + +**Declarative Plans** + Plans are data (expressions), not functions. Enables inspection, + serialization, and composition via substitution. + +**Inheritance Semantics** + When S Γ— P β†’ T, unspecified fields inherit from source. + Cross-type derivation maps geometry (attention.heads β†’ gdn.value_heads). + +**Additive vs Replacement Surgeries** + Additive surgeries (no ``type:`` declaration) satisfy the action law. + Replacement surgeries (explicit ``type:``) use last-write-wins. 
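+
+    For illustration, a minimal sketch of the two kinds of surgery (shapes
+    abbreviated)::
+
+        additive = {"decoder": {"block": {"mixer": {"mixers": {"attention": {"init": "transfer"}}}}}}
+        replacement = {"decoder": {"block": {"mixer": {"type": "gdn", "init": "random"}}}}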
+ +Module Structure +================ - with SafetensorLoader(source_files) as loader: - executor = StreamingExecutor(plan, loader) - with ShardedSafetensorWriter(output_dir) as writer: - for key, tensor in executor.execute(seed=42): - writer.add(key, tensor) +- ``config.py`` - Config composition (compose_configs, strip_init_fields) +- ``converters.py`` - Plan builders (plan_surgery, plan_mil_attention_to_mamba) +- ``expr.py`` - Expression types (Ref, Slice, Concat, Init, ExprPlan) +- ``executor.py`` - Plan execution (StreamingExecutor, execute) +- ``io.py`` - Streaming I/O (SafetensorLoader, ShardedSafetensorWriter) """ # Core types and plan operations @@ -127,7 +163,7 @@ ) # Config composition -from fast_llm_external_models.apriel2.conversion.config import compose_configs +from fast_llm_external_models.apriel2.conversion.config import compose_configs, strip_init_fields # Source-specific converters from fast_llm_external_models.apriel2.conversion.llava import ( @@ -175,6 +211,7 @@ "plan_kil_attention_to_kda", # Config composition "compose_configs", + "strip_init_fields", # Source-specific converters "convert_llava_config", "plan_llava_to_apriel2", diff --git a/fast_llm_external_models/apriel2/conversion/config.py b/fast_llm_external_models/apriel2/conversion/config.py index f5b19e208..3752688c1 100644 --- a/fast_llm_external_models/apriel2/conversion/config.py +++ b/fast_llm_external_models/apriel2/conversion/config.py @@ -1,59 +1,136 @@ """Config composition for Apriel2 architecture transformations. -This module handles STRUCTURAL composition of configs, independent of weight handling. -The `init` field in surgery specs is metadata for plan_surgery(), not for composition. +Conceptual Types +================ + +The system operates on three conceptual types, all represented as ``dict``: + +**State (S)** + A complete structural description of a model. Has ``hidden_size`` and ``decoder``. + Does NOT contain ``init`` fields. Represents WHAT a model looks like. + + Example: A saved config.json, or a model you're about to transform. + +**Partial Surgery (P)** + An incomplete config specifying fields to change. Missing ``hidden_size`` or + ``decoder``. May contain ``init`` fields specifying weight initialization mode. + + Example: ``{"decoder": {"block": {"mixer": {"type": "gdn", "init": "random"}}}}`` + +**Transition Spec (T)** + A complete config WITH ``init`` fields. Describes both the target structure + AND how to initialize weights. This is the output of applying a surgery to + a state - it's a complete specification of the transformation. + + Example: The result of ``compose_configs(state, surgery)`` before stripping. + +The distinction between S and T is semantic (presence of ``init``), not structural. +Both are "complete" in the sense of having ``hidden_size`` and ``decoder``. 
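+
+A minimal sketch of the three types as dicts (fields abbreviated; real configs
+carry many more keys)::
+
+    state = {"hidden_size": 256, "decoder": {"block": {"mixer": {"type": "attention"}}}}  # S
+    surgery = {"decoder": {"block": {"mixer": {"type": "gdn", "init": "random"}}}}        # P
+    transition = compose_configs(state, surgery)                                          # T
+    # transition is complete and still carries init: "random" for the gdn mixer
+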
Algebraic Structure =================== -The system has a precise algebraic structure with two interacting components: +**Partial Surgeries form a Monoid (P, ∘, {})**:: + + compose_configs : P Γ— P β†’ P (deep merge, overlay wins) + + Identity: compose_configs(p, {}) = compose_configs({}, p) = p + Associativity: compose_configs(compose_configs(a, b), c) + = compose_configs(a, compose_configs(b, c)) + +**Surgeries act on States to produce Transition Specs**:: + + compose_configs : S Γ— P β†’ T (apply surgery with inheritance) + compose_configs : T Γ— P β†’ T (extend transition with more surgery) -**Surgery Specs (Monoid)** - Partial config dicts form a monoid under deep merge: - - Identity: {} (empty dict) - - Operation: compose_configs(partial1, partial2) = deep_merge(partial1, partial2) - - Associativity: (a ∘ b) ∘ c = a ∘ (b ∘ c) +**Action Law (for additive surgeries)**:: -**Complete Configs (Monoid Action)** - Surgery specs ACT on complete configs: - - Action: compose_configs(complete, partial) β†’ complete - - For additive surgeries: (s Β· t₁) Β· tβ‚‚ = s Β· (t₁ ∘ tβ‚‚) - - For replacement surgeries: action law intentionally fails (last-write-wins) + compose_configs(compose_configs(s, p₁), pβ‚‚) = compose_configs(s, compose_configs(p₁, pβ‚‚)) -This separation is fundamental: surgery specs compose declaratively (what fields to -merge), while the action on configs interprets those fields with inheritance semantics. +This law holds when surgeries are "additive" (modifying existing structure without +declaring new types). For "replacement" surgeries (explicitly declaring ``type:``), +the action law intentionally fails - this is last-write-wins semantics. -Composition Cases -================= +**State Extraction**:: -compose_configs(base, overlay) dispatches based on completeness: + strip_init_fields : T β†’ S (remove init metadata for saving) -1. **Complete + Partial** β†’ Monoid action (inheritance, cross-type derivation) -2. **Partial + Partial** β†’ Monoid operation (deep merge) -3. **Partial + Complete** β†’ Overlay wins (complete replaces partial) -4. **Complete + Complete** β†’ Deep merge, strip `init` fields +Operations Summary +================== -A config is "complete" if it has `hidden_size` and `decoder`. +``compose_configs(base, overlay)`` dispatches based on completeness: + +1. **S Γ— P β†’ T** : Apply surgery to state (inheritance, cross-type derivation) +2. **T Γ— P β†’ T** : Extend transition spec with more surgery +3. **P Γ— P β†’ P** : Merge partial surgeries (monoid operation) +4. **S Γ— S β†’ S** : Merge states (deep merge, rare) +5. **P Γ— S β†’ S** : Overlay wins (complete replaces partial) + +``strip_init_fields(config)`` removes all ``init`` fields, converting T β†’ S. Inheritance Semantics ===================== -When the monoid action applies a surgery to a complete config: +When applying a surgery (S Γ— P β†’ T): -- Unspecified fields inherit from source -- New blocks inherit from the "default" block +- Unspecified fields inherit from source state +- New decoder blocks inherit from the "default" block - Cross-type derivation maps geometry (attention.heads β†’ gdn.value_heads, etc.) 
-- Stochastic mixers: additive (no type decl) preserves source, replacement replaces +- Stochastic mixers: additive surgery preserves source mixers, replacement replaces -The `init` Field -================ +The ``init`` Field +================== + +The ``init`` field specifies weight initialization mode for ``plan_surgery()``: + +- ``init: transfer`` β†’ transfer weights from source (possibly with conversion) +- ``init: random`` β†’ randomly initialize weights + +**Key invariant**: ``init`` is preserved through composition so ``plan_surgery()`` +can read it. Use ``strip_init_fields()`` to obtain a pure state for: + +- Saving to disk (config.json should not contain ``init``) +- Starting the next surgery iteration (current_state should be S, not T) + +Typical Usage Pattern +===================== + +:: + + current_state: S = load_config(...) + + for surgery: P in surgery_chain: + transition: T = compose_configs(current_state, surgery) # S Γ— P β†’ T + plan = plan_surgery(current_state, transition) # plan reads init from T + current_state: S = strip_init_fields(transition) # T β†’ S for next iteration -The `init` field is metadata for plan_surgery(), NOT for config composition: -- `init: transfer` β†’ plan uses weight transfer/conversion -- `init: random` β†’ plan uses random initialization + save_config(current_state) # S has no init fields -After composition produces a complete config, ALL `init` fields are stripped. -This ensures configs are purely structural and plan creation is Markovian. +Sequential vs Merged Surgery Application +======================================== + +**IMPORTANT**: Applying surgeries sequentially (with stripping) differs from merging +surgeries first then applying once. This affects ``init`` semantics: + +**Sequential** (recommended):: + + t1 = compose_configs(s, p1) # GDN gets init: random + s1 = strip_init_fields(t1) # GDN loses init + t2 = compose_configs(s1, p2) # GDN has init: None β†’ transfer mode + +**Merged**:: + + merged = compose_configs(p1, p2) # GDN keeps init: random from p1 + t = compose_configs(s, merged) # GDN has init: random β†’ random mode + +The sequential approach means ``init: random`` applies **only to the surgery that +introduces a component**. Subsequent surgeries transfer existing weights by default. + +This is the intended behavior: if surgery 1 adds GDN with random init, and surgery 2 +adds sliding window (not mentioning GDN), GDN keeps its weights from surgery 1. + +The merged approach would re-randomize GDN in every execution, which is rarely desired. +Always use the sequential pattern shown in "Typical Usage Pattern" above. """ from __future__ import annotations @@ -68,49 +145,42 @@ def is_complete(config: dict) -> bool: def compose_configs(base: dict, overlay: dict | None) -> dict: - """Compose two configs using monoid or monoid action semantics. + """Compose configs. Dispatches based on completeness of arguments. - This function implements two algebraic operations depending on argument types: + Type Signatures (see module docstring for S, P, T definitions):: - 1. **Monoid Action** (complete + partial): Apply surgery to a complete config. - Unspecified fields inherit from base; `init` fields are stripped from result. + S Γ— P β†’ T Apply surgery to state, get transition spec + T Γ— P β†’ T Extend transition spec with more surgery + P Γ— P β†’ P Merge partial surgeries (monoid operation) + S Γ— S β†’ S Merge states (deep merge) + P Γ— S β†’ S Overlay wins - 2. **Monoid Operation** (partial + partial): Merge two surgery specs. 
- Deep merge with overlay winning on conflicts; `init` fields preserved. + The ``init`` field is preserved in all cases. Use ``strip_init_fields()`` + to convert T β†’ S for saving or iteration. Args: - base: Base config (complete) or surgery spec (partial). - overlay: Surgery spec to apply (partial) or config to merge. + base: State (S), transition spec (T), or partial surgery (P). + overlay: Partial surgery (P) or state (S). Returns: - - If base is complete: Complete config with surgery applied, `init` stripped. - - If both partial: Merged surgery spec with `init` preserved. + Composed config. Type depends on inputs (see signatures above). Algebraic Properties: - Surgery specs form a monoid: (a ∘ b) ∘ c = a ∘ (b ∘ c), identity = {} - - For additive surgeries, the action law holds: - compose(compose(s, t1), t2) == compose(s, compose(t1, t2)) - - For replacement surgeries (declaring type:), action law intentionally fails. + Monoid: ``compose(compose(p1, p2), p3) == compose(p1, compose(p2, p3))`` - Example: - # Apply surgery to complete config (monoid action) - source = {"hidden_size": 256, "decoder": {...}} # complete - surgery = {"decoder": {"block": {"mixer": {"type": "mamba"}}}} # partial + Action law (additive surgeries): + ``compose(compose(s, p1), p2) == compose(s, compose(p1, p2))`` - target = compose_configs(source, surgery) - # target is complete with inherited fields, init stripped + Example:: - # Merge two surgery specs (monoid operation) - s1 = {"decoder": {"block": {"mixer": {"mixers": {"a": {...}}}}}} - s2 = {"decoder": {"block": {"mixer": {"mixers": {"b": {...}}}}}} + # S Γ— P β†’ T (apply surgery to state) + state = {"hidden_size": 256, "decoder": {...}} + surgery = {"decoder": {"block": {"mixer": {"init": "random"}}}} + transition = compose_configs(state, surgery) # T, has init - merged = compose_configs(s1, s2) - # merged has both mixers a and b, init preserved - - # Use composed config with plan_surgery - plan = plan_surgery(source, target) + # Build plan, then extract state + plan = plan_surgery(state, transition) + new_state = strip_init_fields(transition) # S, no init """ if not overlay: return copy.deepcopy(base) @@ -132,9 +202,8 @@ def compose_configs(base: dict, overlay: dict | None) -> dict: if not base_complete and overlay_complete: return copy.deepcopy(overlay) - # Case 4: Both complete -> deep merge + # Case 4: Both complete -> deep merge (init preserved for plan_surgery) result = _deep_merge(base, overlay) - _strip_keys(result, {"init"}) return result @@ -166,6 +235,29 @@ def _strip_keys(config: Any, keys_to_strip: set[str]) -> None: _strip_keys(item, keys_to_strip) +def strip_init_fields(config: dict) -> dict: + """Return a copy of config with all ``init`` fields stripped (T β†’ S). + + Converts a transition spec (T) to a state (S) by removing ``init`` metadata. + Use this: + + 1. Before saving configs to disk (config.json should be purely structural) + 2. Between surgery iterations (so subsequent surgeries don't re-randomize) + + See module docstring section "Sequential vs Merged Surgery Application" for + why stripping between iterations is critical. + + Args: + config: Config dict (not modified). Typically a transition spec (T). + + Returns: + A deep copy with all ``init`` fields recursively removed (a state S). 
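+
+    Example (illustrative, fields abbreviated)::
+
+        transition = {
+            "hidden_size": 256,
+            "decoder": {"block": {"mixer": {"type": "gdn", "init": "random"}}},
+        }
+        state = strip_init_fields(transition)
+        # state == {"hidden_size": 256,
+        #           "decoder": {"block": {"mixer": {"type": "gdn"}}}}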
+ """ + result = copy.deepcopy(config) + _strip_keys(result, {"init"}) + return result + + # ============================================================================= # Surgery application with full semantics # ============================================================================= @@ -182,14 +274,14 @@ def apply_surgery(source_config: dict, surgery_config: dict | None) -> dict: - Unspecified fields inherit from source - Cross-type derivation maps geometry (attention β†’ gdn, etc.) - Stochastic sub-mixers inherit from source's main mixer - - All `init` fields stripped from result + - `init` fields are PRESERVED for plan_surgery() to see Args: source_config: Complete Apriel2 config (the state being acted on). surgery_config: Partial surgery spec (the monoid element acting). Returns: - Complete config with surgery applied, `init` fields stripped. + Complete config with surgery applied. `init` fields preserved. """ if not surgery_config: return copy.deepcopy(source_config) @@ -231,8 +323,9 @@ def apply_surgery(source_config: dict, surgery_config: dict | None) -> dict: surgery_config["vision_encoder"], ) - # Strip init keys from final result - _strip_keys(result, {"init"}) + # NOTE: We do NOT strip init keys here. The `init` field is preserved through + # composition so that plan_surgery() can see it and decide between transfer + # vs random initialization. The caller (convert.py) strips init before saving. return result diff --git a/fast_llm_external_models/apriel2/conversion/converters.py b/fast_llm_external_models/apriel2/conversion/converters.py index b54bb5a87..c8b83f657 100644 --- a/fast_llm_external_models/apriel2/conversion/converters.py +++ b/fast_llm_external_models/apriel2/conversion/converters.py @@ -831,44 +831,69 @@ def plan_surgery( source_config: dict, target_config: dict, ) -> ExprPlan: - """Build a weight conversion plan between two Apriel2 configurations. + """Build a weight conversion plan: S Γ— T β†’ Plan. - This function creates an ExprPlan that maps source weight keys to expressions - defining how to compute target weights. The plan handles same-type passthrough, - cross-type conversions (MIL, DIL, KIL), and stochastic mixer routing. + Creates an ExprPlan mapping target weight keys to expressions over source weights. + Handles same-type passthrough, cross-type conversions (MIL, DIL, KIL), and + stochastic mixer routing. + + Type Signature:: + + plan_surgery : S Γ— T β†’ Plan + + Where S is a state (source) and T is a transition spec (target with ``init`` fields). + + The ``init`` Field + ------------------ + + The ``init`` field in ``target_config`` controls weight initialization: + + - ``init: transfer`` (or absent) β†’ create Ref expressions (transfer from source) + - ``init: random`` β†’ create Init expressions (random initialization) + + This is why ``target_config`` should be a transition spec (T) from ``compose_configs``, + not a stripped state (S). If ``init`` fields are missing, all components default to + transfer mode. Args: - source_config: Complete Apriel2 config dict describing the source architecture. - Must have all structural fields (hidden_size, decoder, etc.) fully specified. - target_config: Complete Apriel2 config dict describing the target architecture. - Must be fully specified with all inherited fields resolved. Use - compose_configs(source_config, surgery_spec) to produce this from a - partial surgery specification. + source_config: State (S) - complete config describing source architecture. + Must have hidden_size, decoder, etc. 
No ``init`` fields expected. + target_config: Transition spec (T) - complete config with ``init`` fields. + Use ``compose_configs(source, surgery)`` to produce this. Returns: ExprPlan mapping target weight keys to expressions over source weights. - Example: + Example:: + # Apply a surgery that wraps attention in a stochastic mixer surgery_spec = { "decoder": {"block": {"mixer": { "type": "stochastic", - "mixers": {"attention": {"type": "attention", "init": "transfer"}} + "mixers": { + "attention": {"init": "transfer"}, + "gdn": {"type": "gdn", "init": "random"}, + } }}} } - # First compose to get complete target config with inherited fields - target_config = compose_configs(source_config, surgery_spec) + # S Γ— P β†’ T + transition = compose_configs(source_config, surgery_spec) + + # S Γ— T β†’ Plan + plan = plan_surgery(source_config, transition) + + # Execute + new_weights = execute(plan, source_weights, seed=42) - # Then build the plan from two complete configs - plan = plan_surgery(source_config, target_config) - new_weights = execute(plan, source_weights, seed=0) + # T β†’ S for saving + target_state = strip_init_fields(transition) Note: - Both arguments must be complete configs. The target_config determines the - full target architecture including all inherited fields (bias settings, - rotary config, etc.). Passing a partial surgery spec directly will result - in missing inherited fields and incorrect plans. + Both arguments must be complete (have hidden_size and decoder). + The target_config should retain ``init`` fields from the surgery spec. + Passing a stripped state as target will cause all components to use + transfer mode, which may not be intended. """ hidden_size = target_config.get("hidden_size", source_config.get("hidden_size")) assert hidden_size is not None, "hidden_size must be specified in source or target config" @@ -922,8 +947,10 @@ def _plan_non_decoder_weights(config: dict) -> ExprPlan: embed = W("model", "embed_tokens", "weight") mappings[embed] = Ref(key=embed) - head = W("lm_head", "weight") - mappings[head] = Ref(key=head) + # lm_head only if not tied to embeddings + if not config.get("tie_word_embeddings", False): + head = W("lm_head", "weight") + mappings[head] = Ref(key=head) norm = W("model", "norm", "weight") mappings[norm] = Ref(key=norm) diff --git a/fast_llm_external_models/apriel2/convert.py b/fast_llm_external_models/apriel2/convert.py index 05c38c7ce..60786d22c 100644 --- a/fast_llm_external_models/apriel2/convert.py +++ b/fast_llm_external_models/apriel2/convert.py @@ -43,6 +43,7 @@ compose, compose_configs, plan_surgery, + strip_init_fields, ) # Import source-specific converters @@ -149,15 +150,19 @@ def build_plan( # Apply surgery chain if requested if surgery_configs: for i, surgery_config in enumerate(surgery_configs, 1): - surgery_plan = plan_surgery(current_config, surgery_config) + # S Γ— P β†’ T: compose state with surgery to get transition spec + target_config = compose_configs(current_config, surgery_config) + + # S Γ— T β†’ Plan: build plan from source state and transition spec + surgery_plan = plan_surgery(current_config, target_config) logger.info(f"Built surgery plan [{i}/{len(surgery_configs)}]: {surgery_plan.summary()['num_targets']} targets") - # Compose: current -> surgery + # Compose plans current_plan = compose(current_plan, surgery_plan) logger.info(f"Composed plan [{i}/{len(surgery_configs)}]: {current_plan.summary()['num_targets']} targets") - # Compose configs: merge surgery spec into current config - current_config = 
compose_configs(current_config, surgery_config) + # T β†’ S: strip init for next iteration (init is consumed by plan_surgery) + current_config = strip_init_fields(target_config) return current_plan, current_config @@ -407,11 +412,11 @@ def main(): show_plan=args.show_plan or args.verbose, ) - # Save config + # Save config (build_plan returns S which has no init, but strip defensively) output_config_file = args.output_dir / "config.json" logger.info(f"Saving config to {output_config_file}") with open(output_config_file, "w") as f: - json.dump(apriel2_config, f, indent=2) + json.dump(strip_init_fields(apriel2_config), f, indent=2) # Copy tokenizer files copy_tokenizer_files(input_dir, args.output_dir) diff --git a/fast_llm_external_models/apriel2/examples/train_supernet_small.yaml b/fast_llm_external_models/apriel2/examples/train_supernet_small.yaml index 78c22e57f..be4d06e0a 100644 --- a/fast_llm_external_models/apriel2/examples/train_supernet_small.yaml +++ b/fast_llm_external_models/apriel2/examples/train_supernet_small.yaml @@ -107,7 +107,7 @@ model: lr_scale: 0.0 # Freeze layer norms (norm_1 and norm_2 in each block) # Activation-level distillation: teach mixers to mimic teacher's attention outputs distillation_model: teacher - activation_distillation_factor: 0.1 + activation_distillation_factor: 0.8 embeddings: lr_scale: 0.0 # Freeze word embeddings head: From 8933953d61c2efa499bdd94d55bfe43e2b613955 Mon Sep 17 00:00:00 2001 From: Torsten Scholak Date: Wed, 17 Dec 2025 17:58:41 +0000 Subject: [PATCH 10/12] Fix RangeSample.from_documents and loss mask distillation bugs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - range.py: Use append() instead of extend() for tuple pairs. The extend() call was flattening tuples into individual integers, causing "cannot unpack non-iterable numpy.int64" errors when iterating over ranges. - model.py: Fix attribute name from output_layer to head. The config uses 'head' for the language model head configuration. 
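
For illustration of the range.py fix (plain ints here; the real code holds
numpy ints): list.extend() of a tuple adds its elements one by one, while
list.append() keeps the (begin, end) pair intact for later unpacking:

    broken, fixed = [], []
    broken.extend((3, 7))   # broken == [3, 7]   -> "for begin, end in ..." fails
    fixed.append((3, 7))    # fixed == [(3, 7)]  -> pairs survive iteration
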
πŸ€– Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- fast_llm/data/sample/range.py | 2 +- fast_llm/models/gpt/model.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fast_llm/data/sample/range.py b/fast_llm/data/sample/range.py index 8dd351e1f..22d5e8992 100644 --- a/fast_llm/data/sample/range.py +++ b/fast_llm/data/sample/range.py @@ -38,7 +38,7 @@ def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: sample_size = 0 for document in documents: for begin, end in document.ranges: - ranges.extend((begin + sample_size, end + sample_size)) + ranges.append((begin + sample_size, end + sample_size)) sample_size += document.sample_size return cls(ranges, sample_size) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index a0c381439..fd8d2af1b 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -247,7 +247,7 @@ def preprocess_batch( for sample_index, loss_masking_spans in enumerate(loss_masking_spans.ranges): for begin, end in loss_masking_spans: loss_mask[sample_index, begin:end] = False - if self._config.output_layer.distillation_model is not None: + if self._config.head.distillation_model is not None: kwargs[LanguageModelKwargs.loss_mask] = loss_mask labels = torch.where(loss_mask, labels, -100) From 1277894a690f94f9183b8d1fc1338328a95b5b23 Mon Sep 17 00:00:00 2001 From: Torsten Scholak Date: Fri, 19 Dec 2025 15:41:19 +0000 Subject: [PATCH 11/12] Skip roundtrip integration tests on CPU-only CI MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Integration tests should run on realistic hardware. Roundtrip tests (Apriel2 -> Fast-LLM -> Apriel2) now skip when CUDA is unavailable. Changes: - Add CUDA check to roundtrip_converted fixture - Lazy-load roundtrip fixture in converted_model to avoid eager evaluation - Apriel2 and supernet tests still run on CPU (16 tests) - Roundtrip tests skip on CPU-only CI (8 tests) πŸ€– Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- .../tests/test_apriel2/test_integration.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/fast_llm_external_models/tests/test_apriel2/test_integration.py b/fast_llm_external_models/tests/test_apriel2/test_integration.py index c11302d22..b90f0774e 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_integration.py +++ b/fast_llm_external_models/tests/test_apriel2/test_integration.py @@ -136,6 +136,9 @@ def supernet_converted(qwen2_source, apriel2_converted): @pytest.fixture(scope="module") def roundtrip_converted(supernet_converted, qwen2_source): """Stage 3: Supernet -> Fast-LLM -> Supernet.""" + if not torch.cuda.is_available(): + pytest.skip("Roundtrip conversion requires CUDA (integration tests need realistic hardware)") + from fast_llm.engine.checkpoint.config import ( CheckpointLoadConfig, CheckpointSaveConfig, @@ -181,18 +184,22 @@ def roundtrip_converted(supernet_converted, qwen2_source): @pytest.fixture(params=["apriel2", "supernet", "roundtrip"]) -def converted_model(request, apriel2_converted, supernet_converted, roundtrip_converted): +def converted_model(request, apriel2_converted, supernet_converted): """Parameterized fixture providing each conversion stage for testing. This allows a single test to run against all stages automatically. 
""" if request.param == "roundtrip": pytest.importorskip("fast_llm") + if not torch.cuda.is_available(): + pytest.skip("Roundtrip tests require CUDA (integration tests need realistic hardware)") + # Lazy-load to avoid fixture evaluation when CUDA unavailable + roundtrip_converted = request.getfixturevalue("roundtrip_converted") + return roundtrip_converted return { "apriel2": apriel2_converted, "supernet": supernet_converted, - "roundtrip": roundtrip_converted, }[request.param] From 9273966076de95ec1dc57154120c355a0b5cb88c Mon Sep 17 00:00:00 2001 From: Torsten Scholak Date: Sat, 20 Dec 2025 22:25:40 +0000 Subject: [PATCH 12/12] Refactor conversation format handling and tokenize_chat MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Split LanguageModelSourceConfig into abstract base + DocumentSourceConfig - Remove has_conversation property, use isinstance checks instead - Move _mask_to_spans to tokenizer module as _train_mask_to_loss_spans - tokenize_chat now returns (tokens, loss_masking_spans) directly - Safer BOS/EOS handling: check anywhere in tokens, not just first/last - Remove unused add_generation_prompt parameter from tokenize_chat πŸ€– Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- fast_llm/data/preparator/gpt_memmap/config.py | 138 ++++++------ .../data/preparator/gpt_memmap/prepare.py | 207 +++++++++--------- fast_llm/data/preprocessing/tokenizer.py | 51 ++++- .../apriel2/examples/prepare_tulu3.yaml | 2 +- tests/data/test_tokenizer.py | 29 ++- 5 files changed, 226 insertions(+), 201 deletions(-) diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 2aa0fbf31..a1aadf40a 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -18,30 +18,78 @@ @config_class(registry=True) class LanguageModelSourceConfig(Config): """ - A schema holding the name of each relevant column in the dataset. - Setting optional entries will enable the associated feature. + Abstract base class for data source schemas. - This is the base class for source schemas. Use `type: text` (default) for - plain text datasets, or `type: conversation` for chat/conversation datasets. + Use `type: document` (default) for documents with text, optional span annotations, and optional images. + Use `type: conversation` for structured chat/conversation datasets. + """ + + @classmethod + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: + if cls is LanguageModelSourceConfig and cls.get_subclass(default.get("type")) is None: + # Default to DocumentSourceConfig when type is not specified + return DocumentSourceConfig._from_dict(default, strict) + return super()._from_dict(default, strict=strict) + + @functools.cached_property + def columns(self) -> list[str]: + """Columns to read from the dataset.""" + raise NotImplementedError + + @functools.cached_property + def has_loss_masking_span(self) -> bool: + return False + + @functools.cached_property + def has_preference_spans(self) -> bool: + return False + + @functools.cached_property + def has_images(self) -> bool: + return False + + +@config_class(dynamic_type={LanguageModelSourceConfig: "document"}) +class DocumentSourceConfig(LanguageModelSourceConfig): + """ + Source schema for document datasets with text, optional span annotations, and optional images. + + The dataset should have a text column containing the document text. 
+ Optionally, it can have additional columns for: + - Loss masking spans: character ranges to mask from loss computation + - Preference spans: chosen/rejected text for DPO training + - Images: image data with character positions for multimodal training """ text: str = Field( default="text", - desc="Field of the dataset to use.", + desc="Field containing the document text.", hint=FieldHint.optional, ) - loss_masking_spans: None | str = Field( - default=None, desc="Field containing character spans to mask for loss computation", hint=FieldHint.optional + loss_masking_spans: str | None = Field( + default=None, + desc="Field containing character spans to mask for loss computation.", + hint=FieldHint.optional, ) - chosen_span: None | str = Field( - default=None, desc="Field containing chosen text for preference optimization", hint=FieldHint.optional + chosen_span: str | None = Field( + default=None, + desc="Field containing chosen text for preference optimization.", + hint=FieldHint.optional, ) - rejected_span: None | str = Field( - default=None, desc="Field containing rejected text for preference optimization", hint=FieldHint.optional + rejected_span: str | None = Field( + default=None, + desc="Field containing rejected text for preference optimization.", + hint=FieldHint.optional, + ) + images: str | None = Field( + default=None, + desc="Field containing images.", + hint=FieldHint.optional, ) - images: None | str = Field(default=None, desc="Field containing images", hint=FieldHint.optional) - image_positions: None | str = Field( - default=None, desc="Field containing image positions in the text.", hint=FieldHint.optional + image_positions: str | None = Field( + default=None, + desc="Field containing image positions in the text.", + hint=FieldHint.optional, ) @functools.cached_property @@ -69,28 +117,10 @@ def has_images(self) -> bool: Assert.eq(self.images is None, self.image_positions is None) return self.images is not None - @functools.cached_property - def has_conversation(self) -> bool: - """Whether this is a conversation source schema.""" - return False - def _validate(self): super()._validate() if self.has_preference_spans and self.has_loss_masking_span: - raise ValueError(f"Can not enable both loss masking and preference spans.") - - -@config_class(dynamic_type={LanguageModelSourceConfig: "text"}) -class TextSourceConfig(LanguageModelSourceConfig): - """ - Source schema for plain text datasets (default). - - The dataset should have a text column containing the document text. - Optionally, it can have additional columns for loss masking spans, - preference spans (for DPO), or images. - """ - - pass + raise ValueError("Cannot enable both loss masking and preference spans.") @config_class(dynamic_type={LanguageModelSourceConfig: "conversation"}) @@ -120,59 +150,21 @@ class ConversationSourceConfig(LanguageModelSourceConfig): } """ - # Override text field - not used directly for conversation format - text: None | str = Field( - default=None, - desc="Not used for conversation format. Text is generated from messages.", - hint=FieldHint.optional, - ) - - # Conversation-specific fields messages: str = Field( default="messages", desc="Field containing the conversation messages list. Each message should have 'role' and 'content' keys.", hint=FieldHint.core, ) - add_generation_prompt: bool = Field( - default=False, - desc="Whether to add a generation prompt at the end of the conversation. 
" - "Typically False for training data.", - hint=FieldHint.optional, - ) - @functools.cached_property def columns(self) -> list[str]: - # For conversation format, we read the messages column, not text - columns = [self.messages] - # Images can still be used with conversation format - if self.has_images: - columns.extend([self.images, self.image_positions]) - return columns - - @functools.cached_property - def has_conversation(self) -> bool: - return True + return [self.messages] @functools.cached_property def has_loss_masking_span(self) -> bool: - # Conversation format always generates loss masking spans + # Conversation format always generates loss masking spans from chat template markers return True - def _validate(self): - # Skip parent validation that checks text field - Config._validate(self) - if self.has_preference_spans: - raise ValueError("Preference spans are not supported with conversation format.") - if self.has_images: - # Images with conversation format would require computing image positions in the - # chat-template-formatted text, which is complex and format-dependent. - # For VLM training with conversations, preprocess the data to plain text format first. - raise ValueError( - "Images are not yet supported with conversation format. " - "For multimodal conversation data, preprocess to plain text format with image positions." - ) - @config_class() class GPTHuggingfaceDatasetConfig(Config): diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index a9beca42f..eeb925591 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -28,7 +28,12 @@ ) from fast_llm.data.dataset.memmap import MemmapDataset from fast_llm.data.preparator.config import DatasetPreparator -from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig, LanguageModelSourceConfig +from fast_llm.data.preparator.gpt_memmap.config import ( + ConversationSourceConfig, + GPTMemmapDatasetPreparatorConfig, + LanguageModelSourceConfig, + DocumentSourceConfig, +) from fast_llm.data.preprocessing.abstract import NullPreprocessingConfig from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.preprocessing.tokenizer import Tokenizer @@ -133,7 +138,7 @@ def run(self) -> None: self._tokenizer = self._config.tokenizer.get_tokenizer() # Validate chat template for conversation format - if self._source_schema.has_conversation: + if isinstance(self._source_schema, ConversationSourceConfig): self._tokenizer.validate_chat_template() # Decide the datatype based on the tokenizer vocabulary size @@ -220,108 +225,110 @@ def _preprocessing_config(self) -> LanguageModelPreprocessingConfig: ) def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample: - if self._source_schema.has_conversation: - tokens, train_mask = self._tokenizer.tokenize_chat( + token_spans_by_type = collections.defaultdict(list) + image_patches = image_token_maps = image_position_ids = patch_counts = None + + if isinstance(self._source_schema, ConversationSourceConfig): + # Conversation format: tokenize messages and get loss masking spans from chat template + tokens, loss_masking_spans = self._tokenizer.tokenize_chat( sample[self._source_schema.messages], - self._source_schema.add_generation_prompt, + True, + True, data_type=self._data_type, ) - return LanguageModelSample( - TokenSample(tokens, [len(tokens)]), - RangeSample(_mask_to_spans(train_mask), 
len(tokens)), - None, - None, - None, - ) + token_spans_by_type[SpanType.loss_masking] = loss_masking_spans + elif isinstance(self._source_schema, DocumentSourceConfig): + # Document format: use the text-spans pipeline + text = sample[self._source_schema.text] + all_spans = [] + + if self._source_schema.has_loss_masking_span: + # Spans are typically stored in the (begin, last) format. We convert to (begin, end) range format. + loss_masking_spans = _sort_spans( + (SpanType.loss_masking, (begin, last + 1)) + for begin, last in np.array(sample[self._source_schema.loss_masking_spans], dtype=np.int32) + .reshape(-1, 2) + .tolist() + ) + all_spans.extend(loss_masking_spans) - # Text format: use the text-spans pipeline - text = sample[self._source_schema.text] - all_spans = [] - - if self._source_schema.has_loss_masking_span: - # Spans are typically stored in the (begin, last) format. We convert to (begin, end) range format. - loss_masking_spans = _sort_spans( - (SpanType.loss_masking, (begin, last + 1)) - for begin, last in np.array(sample[self._source_schema.loss_masking_spans], dtype=np.int32) - .reshape(-1, 2) - .tolist() - ) - all_spans.extend(loss_masking_spans) - - if self._source_schema.has_preference_spans: - full_chosen_text = text + sample[self._source_schema.chosen_span] + self._tokenizer.tokenizer.eos_token - full_rejected_text = self._tokenizer.tokenizer.bos_token + text + sample[self._source_schema.rejected_span] - # compute chosen span - chosen_spans = [(SpanType.chosen, (len(text), len(full_chosen_text)))] - - # compute rejected span - rejected_span = [ - ( - SpanType.rejected, - ( - len(full_chosen_text) + len(self._tokenizer.tokenizer.bos_token) + len(text), - len(full_chosen_text) + len(full_rejected_text), - ), + if self._source_schema.has_preference_spans: + full_chosen_text = text + sample[self._source_schema.chosen_span] + self._tokenizer.tokenizer.eos_token + full_rejected_text = ( + self._tokenizer.tokenizer.bos_token + text + sample[self._source_schema.rejected_span] ) - ] - # pack texts - text = full_chosen_text + full_rejected_text - all_spans.extend(chosen_spans + rejected_span) - - if self._source_schema.has_images: - # Get the images and positions, sorted by position. - images, image_positions = ( - zip( - *sorted( - zip( - sample[self._source_schema.images], - sample[self._source_schema.image_positions], - strict=True, + # compute chosen span + chosen_spans = [(SpanType.chosen, (len(text), len(full_chosen_text)))] + + # compute rejected span + rejected_span = [ + ( + SpanType.rejected, + ( + len(full_chosen_text) + len(self._tokenizer.tokenizer.bos_token) + len(text), + len(full_chosen_text) + len(full_rejected_text), ), - key=lambda x: x[1], ) + ] + # pack texts + text = full_chosen_text + full_rejected_text + all_spans.extend(chosen_spans + rejected_span) + + if self._source_schema.has_images: + # Get the images and positions, sorted by position. + images, image_positions = ( + zip( + *sorted( + zip( + sample[self._source_schema.images], + sample[self._source_schema.image_positions], + strict=True, + ), + key=lambda x: x[1], + ) + ) + if len(sample[self._source_schema.images]) > 0 + else ([], []) ) - if len(sample[self._source_schema.images]) > 0 - else ([], []) - ) - # Get the image patches and associated data. - image_patches, image_position_ids, image_token_maps, image_token_ids, patch_counts = ( - self._config.image_patches.get_patches_from_images(images, self._data_type) + # Get the image patches and associated data. 
+ image_patches, image_position_ids, image_token_maps, image_token_ids, patch_counts = ( + self._config.image_patches.get_patches_from_images(images, self._data_type) + ) + patch_count_cumsum = padded_cumsum(patch_counts).tolist() + # Add an empty "span" at each image position so we know where to insert them in the tokenized sequence. + all_spans.extend([(SpanType.image, (position, position)) for position in image_positions]) + + # Sort the spans by location (begin), keeping track of their type. + # Note: overlapping spans are not supported (explicit assertion in the tokenizer). + span_types, spans = zip(*_sort_spans(all_spans)) if all_spans else ([], []) + # Tokenize the text, and determine the span locations in the tokenized text. + tokens, token_spans = self._tokenizer.tokenize_with_spans( + text, True, True, text_spans=spans, data_type=self._data_type ) - patch_count_cumsum = padded_cumsum(patch_counts).tolist() - # Add an empty "span" at each image position so we know where to insert them in the tokenized sequence. - all_spans.extend([(SpanType.image, (position, position)) for position in image_positions]) - - # Sort the spans by location (begin), keeping track of their type. - # Note: overlapping spans are not supported (explicit assertion in the tokenizer). - span_types, spans = zip(*_sort_spans(all_spans)) if all_spans else ([], []) - # Tokenize the text, and determine the span locations in the tokenized text. - tokens, token_spans = self._tokenizer.tokenize_with_spans( - text, True, True, text_spans=spans, data_type=self._data_type - ) - # Gather token spans by type. - token_spans_by_type = collections.defaultdict(list) - if self._source_schema.has_images: - # Insert the image token ids in the token sequence and shift the spans accordingly. - tokens_shift = 0 - image_index = 0 - for span_type, (begin, end) in zip(span_types, token_spans, strict=True): - # Account for the tokens already inserted. - begin = begin + tokens_shift - end = end + tokens_shift - if span_type == SpanType.image: - # Shift the token map to the image location. - image_token_maps[patch_count_cumsum[image_index] : patch_count_cumsum[image_index + 1]] += begin - # Insert the placeholder and image break tokens. - tokens = torch.cat([tokens[:begin], image_token_ids[image_index], tokens[begin:]]) - tokens_shift += len(image_token_ids[image_index]) - image_index += 1 - else: - token_spans_by_type[span_type].append((begin, end)) + # Gather token spans by type. + if self._source_schema.has_images: + # Insert the image token ids in the token sequence and shift the spans accordingly. + tokens_shift = 0 + image_index = 0 + for span_type, (begin, end) in zip(span_types, token_spans, strict=True): + # Account for the tokens already inserted. + begin = begin + tokens_shift + end = end + tokens_shift + if span_type == SpanType.image: + # Shift the token map to the image location. + image_token_maps[patch_count_cumsum[image_index] : patch_count_cumsum[image_index + 1]] += begin + # Insert the placeholder and image break tokens. 
+ tokens = torch.cat([tokens[:begin], image_token_ids[image_index], tokens[begin:]]) + tokens_shift += len(image_token_ids[image_index]) + image_index += 1 + else: + token_spans_by_type[span_type].append((begin, end)) + else: + for span_type, token_span in zip(span_types, token_spans, strict=True): + token_spans_by_type[span_type].append(token_span) else: - for span_type, token_span in zip(span_types, token_spans, strict=True): - token_spans_by_type[span_type].append(token_span) + raise NotImplementedError(f"Unsupported source schema type: {type(self._source_schema)}") sample_size = len(tokens) @@ -501,17 +508,3 @@ def _get_nearest_split(cumsum: np.ndarray, value: float) -> int: return left.item() + 1 if (value - cumsum[left]) / (cumsum[left + 1] - cumsum[left]) > 0.5 else left.item() -def _mask_to_spans(mask: list[bool]) -> list[tuple[int, int]]: - """Convert a boolean train mask to loss masking spans (where mask[i] == False).""" - spans = [] - start = None - for i, value in enumerate(mask): - if not value: - if start is None: - start = i - elif start is not None: - spans.append((start, i)) - start = None - if start is not None: - spans.append((start, len(mask))) - return spans diff --git a/fast_llm/data/preprocessing/tokenizer.py b/fast_llm/data/preprocessing/tokenizer.py index f3b5a51a8..2d27c3853 100644 --- a/fast_llm/data/preprocessing/tokenizer.py +++ b/fast_llm/data/preprocessing/tokenizer.py @@ -245,12 +245,17 @@ def validate_chat_template(self) -> None: def tokenize_chat( self, messages: list[dict[str, str]], - add_generation_prompt: bool = False, begin: bool = True, end: bool = True, data_type: DataType = DataType.int64, - ) -> tuple["torch.Tensor", list[bool]]: - """Apply chat template and return (tokens, train_mask) where train_mask[i]=True means train on token i.""" + ) -> tuple["torch.Tensor", list[tuple[int, int]]]: + """ + Apply chat template and return (tokens, loss_masking_spans). + + The loss_masking_spans mark token ranges to EXCLUDE from training (where the model + should not learn). These are derived from the chat template's generation markers - + tokens outside {% generation %}...{% endgeneration %} blocks are masked. + """ import torch result = self.tokenizer.apply_chat_template( @@ -258,17 +263,22 @@ def tokenize_chat( tokenize=True, return_assistant_tokens_mask=True, return_dict=True, - add_generation_prompt=add_generation_prompt, + add_generation_prompt=False, ) tokens = result["input_ids"] train_mask = result["assistant_masks"] - # Prepend BOS / append EOS if needed (avoid O(n) insert) - prepend_bos = begin and (not tokens or tokens[0] != self.bod_id) - append_eos = end and (not tokens or tokens[-1] != self.eod_id) + # Prepend BOS / append EOS if not already present anywhere in the sequence. + # We check anywhere (not just first/last) because some chat templates add trailing + # whitespace after the final EOS token, e.g. "<|im_end|>\n". 
+ prepend_bos = begin and self.bod_id not in tokens + append_eos = end and self.eod_id not in tokens tokens = [self.bod_id] * prepend_bos + list(tokens) + [self.eod_id] * append_eos train_mask = [False] * prepend_bos + [bool(m) for m in train_mask] + [False] * append_eos + # Convert boolean train mask to loss masking spans (spans where train_mask[i] == False) + loss_masking_spans = _train_mask_to_loss_spans(train_mask) + if self._config.max_vocab_size is not None: tokens = ( torch.tensor( @@ -279,5 +289,30 @@ def tokenize_chat( ).to(data_type.torch) else: tokens = torch.tensor(tokens, dtype=data_type.torch) - return tokens, train_mask + return tokens, loss_masking_spans + + +def _train_mask_to_loss_spans(train_mask: list[bool]) -> list[tuple[int, int]]: + """ + Convert a boolean train mask to loss masking spans. + + Args: + train_mask: Boolean list where True = train on this token, False = don't train + + Returns: + List of (begin, end) spans marking token ranges to EXCLUDE from training + (i.e., where train_mask[i] == False). + """ + spans = [] + start = None + for i, should_train in enumerate(train_mask): + if not should_train: + if start is None: + start = i + elif start is not None: + spans.append((start, i)) + start = None + if start is not None: + spans.append((start, len(train_mask))) + return spans diff --git a/fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml b/fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml index ba85c1aed..34672916c 100644 --- a/fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml +++ b/fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml @@ -81,7 +81,7 @@ dataset: # Source schema for conversation format source_schema: - # Use conversation type (vs default "text" type) + # Use conversation type (vs default "document" type) type: conversation # Column containing the messages list diff --git a/tests/data/test_tokenizer.py b/tests/data/test_tokenizer.py index 97f16c6d6..f8f07ef0f 100644 --- a/tests/data/test_tokenizer.py +++ b/tests/data/test_tokenizer.py @@ -76,15 +76,17 @@ def test_validate_chat_template_with_markers(common_tokenizer): @pytest.mark.parametrize( - ("messages", "expected_tokens", "expected_trainable_indices"), + ("messages", "expected_tokens", "expected_loss_masking_spans"), ( # Single turn: full assistant turn (Hello) is trainable + # 15 tokens, trainable indices 7-13, loss mask spans cover 0-6 and 14 ( [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello"}], [49152, 27, 789, 29, 16946, 750, 789, 2293, 17822, 29, 7371, 750, 17822, 29, 49152], - [7, 8, 9, 10, 11, 12, 13], + [(0, 7), (14, 15)], ), # Multi-turn: both assistant turns are fully trainable + # 27 tokens, trainable indices 7-13 and 19-25 ( [ {"role": "user", "content": "A"}, @@ -93,9 +95,10 @@ def test_validate_chat_template_with_markers(common_tokenizer): {"role": "assistant", "content": "D"}, ], [49152, 27, 789, 29, 32, 750, 789, 2293, 17822, 29, 33, 750, 17822, 2293, 789, 29, 34, 750, 789, 2293, 17822, 29, 35, 750, 17822, 29, 49152], - [7, 8, 9, 10, 11, 12, 13, 19, 20, 21, 22, 23, 24, 25], + [(0, 7), (14, 19), (26, 27)], ), # System + user + assistant: full assistant turn trainable + # 23 tokens, trainable indices 15-21 ( [ {"role": "system", "content": "You are helpful."}, @@ -103,15 +106,17 @@ def test_validate_chat_template_with_markers(common_tokenizer): {"role": "assistant", "content": "Hello"}, ], [49152, 27, 3144, 29, 5815, 1139, 44569, 6928, 3144, 2293, 789, 29, 16946, 750, 789, 2293, 17822, 29, 7371, 750, 17822, 
29, 49152], - [15, 16, 17, 18, 19, 20, 21], + [(0, 15), (22, 23)], ), # User only: no trainable tokens + # 9 tokens, no trainable indices ( [{"role": "user", "content": "Hi"}], [49152, 27, 789, 29, 16946, 750, 789, 29, 49152], - [], + [(0, 9)], ), # Long multi-turn (85 tokens, 3 assistant responses with tags, tests span machinery) + # Trainable: indices 27-40, 49-62, 70-83 ( [ {"role": "system", "content": "You are a helpful assistant that answers questions."}, @@ -123,15 +128,15 @@ def test_validate_chat_template_with_markers(common_tokenizer): {"role": "assistant", "content": "The capital of Italy is Rome."}, ], [49152, 27, 3144, 29, 5815, 1139, 373, 44569, 2424, 11886, 954, 15737, 14516, 6928, 3144, 2293, 789, 29, 13938, 438, 331, 25016, 457, 12409, 562, 35838, 789, 2293, 17822, 29, 2111, 25016, 457, 12409, 562, 438, 4235, 280, 6928, 17822, 2293, 789, 29, 13938, 5028, 759, 42226, 35838, 789, 2293, 17822, 29, 2111, 25016, 457, 759, 42226, 438, 29784, 3556, 6928, 17822, 2293, 789, 29, 1996, 4413, 3326, 35838, 789, 2293, 17822, 29, 2111, 25016, 457, 4413, 3326, 438, 613, 1361, 6928, 17822, 29, 49152], - list(range(27, 41)) + list(range(49, 63)) + list(range(70, 84)), + [(0, 27), (41, 49), (63, 70), (84, 85)], ), ), ) -def test_tokenize_chat(common_tokenizer, messages, expected_tokens, expected_trainable_indices): +def test_tokenize_chat(common_tokenizer, messages, expected_tokens, expected_loss_masking_spans): common_tokenizer.tokenizer.chat_template = CHAT_TEMPLATE - tokens, train_mask = common_tokenizer.tokenize_chat(messages) + tokens, loss_masking_spans = common_tokenizer.tokenize_chat(messages) Assert.eq(tokens.tolist(), expected_tokens) - Assert.eq([i for i, m in enumerate(train_mask) if m], expected_trainable_indices) + Assert.eq(loss_masking_spans, expected_loss_masking_spans) @pytest.mark.parametrize( @@ -153,7 +158,7 @@ def test_tokenize_chat(common_tokenizer, messages, expected_tokens, expected_tra ([False, True, False, True, False], [(0, 1), (2, 3), (4, 5)]), ), ) -def test_mask_to_spans(train_mask, expected_loss_spans): - from fast_llm.data.preparator.gpt_memmap.prepare import _mask_to_spans +def test_train_mask_to_loss_spans(train_mask, expected_loss_spans): + from fast_llm.data.preprocessing.tokenizer import _train_mask_to_loss_spans - Assert.eq(_mask_to_spans(train_mask), expected_loss_spans) + Assert.eq(_train_mask_to_loss_spans(train_mask), expected_loss_spans)