diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py
index 503b400c3..a1aadf40a 100644
--- a/fast_llm/data/preparator/gpt_memmap/config.py
+++ b/fast_llm/data/preparator/gpt_memmap/config.py
@@ -15,30 +15,81 @@
from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator
-@config_class()
+@config_class(registry=True)
class LanguageModelSourceConfig(Config):
"""
- A schema holding the name of each relevant column in the dataset.
- Setting optional entries will enable the associated feature.
+ Abstract base class for data source schemas.
+
+ Use `type: document` (default) for documents with text, optional span annotations, and optional images.
+ Use `type: conversation` for structured chat/conversation datasets.
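+
+    For example (illustrative column names only), a document source schema might look like
+    ``{text: text, loss_masking_spans: spans}`` in YAML, while a conversation source
+    would be ``{type: conversation, messages: messages}``.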
+ """
+
+ @classmethod
+ def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self:
+ if cls is LanguageModelSourceConfig and cls.get_subclass(default.get("type")) is None:
+ # Default to DocumentSourceConfig when type is not specified
+ return DocumentSourceConfig._from_dict(default, strict)
+ return super()._from_dict(default, strict=strict)
+
+ @functools.cached_property
+ def columns(self) -> list[str]:
+ """Columns to read from the dataset."""
+ raise NotImplementedError
+
+ @functools.cached_property
+ def has_loss_masking_span(self) -> bool:
+ return False
+
+ @functools.cached_property
+ def has_preference_spans(self) -> bool:
+ return False
+
+ @functools.cached_property
+ def has_images(self) -> bool:
+ return False
+
+
+@config_class(dynamic_type={LanguageModelSourceConfig: "document"})
+class DocumentSourceConfig(LanguageModelSourceConfig):
+ """
+ Source schema for document datasets with text, optional span annotations, and optional images.
+
+ The dataset should have a text column containing the document text.
+ Optionally, it can have additional columns for:
+ - Loss masking spans: character ranges to mask from loss computation
+ - Preference spans: chosen/rejected text for DPO training
+ - Images: image data with character positions for multimodal training
"""
text: str = Field(
default="text",
- desc="Field of the dataset to use.",
+ desc="Field containing the document text.",
+ hint=FieldHint.optional,
+ )
+ loss_masking_spans: str | None = Field(
+ default=None,
+ desc="Field containing character spans to mask for loss computation.",
hint=FieldHint.optional,
)
- loss_masking_spans: None | str = Field(
- default=None, desc="Field containing character spans to mask for loss computation", hint=FieldHint.optional
+ chosen_span: str | None = Field(
+ default=None,
+ desc="Field containing chosen text for preference optimization.",
+ hint=FieldHint.optional,
)
- chosen_span: None | str = Field(
- default=None, desc="Field containing chosen text for preference optimization", hint=FieldHint.optional
+ rejected_span: str | None = Field(
+ default=None,
+ desc="Field containing rejected text for preference optimization.",
+ hint=FieldHint.optional,
)
- rejected_span: None | str = Field(
- default=None, desc="Field containing rejected text for preference optimization", hint=FieldHint.optional
+ images: str | None = Field(
+ default=None,
+ desc="Field containing images.",
+ hint=FieldHint.optional,
)
- images: None | str = Field(default=None, desc="Field containing images", hint=FieldHint.optional)
- image_positions: None | str = Field(
- default=None, desc="Field containing image positions in the text.", hint=FieldHint.optional
+ image_positions: str | None = Field(
+ default=None,
+ desc="Field containing image positions in the text.",
+ hint=FieldHint.optional,
)
@functools.cached_property
@@ -48,6 +99,8 @@ def columns(self) -> list[str]:
columns.append(self.loss_masking_spans)
if self.has_preference_spans:
columns.extend([self.chosen_span, self.rejected_span])
+ if self.has_images:
+ columns.extend([self.images, self.image_positions])
return columns
@functools.cached_property
@@ -67,7 +120,50 @@ def has_images(self) -> bool:
def _validate(self):
super()._validate()
if self.has_preference_spans and self.has_loss_masking_span:
- raise ValueError(f"Can not enable both loss masking and preference spans.")
+ raise ValueError("Cannot enable both loss masking and preference spans.")
+
+
+@config_class(dynamic_type={LanguageModelSourceConfig: "conversation"})
+class ConversationSourceConfig(LanguageModelSourceConfig):
+ """
+ Source schema for chat/conversation datasets (e.g., Tulu 3, ShareGPT, OpenAI format).
+
+ The dataset should have a messages column containing a list of message dicts,
+ where each message has 'role' and 'content' keys. Common roles include:
+ - 'system': System prompt
+ - 'user': User input
+ - 'assistant': Model response (trained on by default)
+ - 'tool': Tool/function results
+ - 'ipython': Code execution results
+
+ The conversation is formatted using the tokenizer's chat template, which must
+ contain {% generation %}...{% endgeneration %} markers to define which content
+ to train on. Loss masking spans are automatically computed from these markers.
+
+ Example dataset format:
+ {
+ "messages": [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "Hello!"},
+ {"role": "assistant", "content": "Hi there!"},
+ ]
+ }
+ """
+
+ messages: str = Field(
+ default="messages",
+ desc="Field containing the conversation messages list. Each message should have 'role' and 'content' keys.",
+ hint=FieldHint.core,
+ )
+
+ @functools.cached_property
+ def columns(self) -> list[str]:
+ return [self.messages]
+
+ @functools.cached_property
+ def has_loss_masking_span(self) -> bool:
+ # Conversation format always generates loss masking spans from chat template markers
+ return True
@config_class()
diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py
index 2ea81d8a6..eeb925591 100644
--- a/fast_llm/data/preparator/gpt_memmap/prepare.py
+++ b/fast_llm/data/preparator/gpt_memmap/prepare.py
@@ -28,7 +28,12 @@
)
from fast_llm.data.dataset.memmap import MemmapDataset
from fast_llm.data.preparator.config import DatasetPreparator
-from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig, LanguageModelSourceConfig
+from fast_llm.data.preparator.gpt_memmap.config import (
+    ConversationSourceConfig,
+    DocumentSourceConfig,
+    GPTMemmapDatasetPreparatorConfig,
+    LanguageModelSourceConfig,
+)
from fast_llm.data.preprocessing.abstract import NullPreprocessingConfig
from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig
from fast_llm.data.preprocessing.tokenizer import Tokenizer
@@ -132,6 +137,10 @@ def run(self) -> None:
# Load tokenizer
self._tokenizer = self._config.tokenizer.get_tokenizer()
+ # Validate chat template for conversation format
+ if isinstance(self._source_schema, ConversationSourceConfig):
+ self._tokenizer.validate_chat_template()
+
# Decide the datatype based on the tokenizer vocabulary size
self._data_type = (
get_unsigned_integer_type(self._tokenizer.vocab_size)
@@ -216,92 +225,110 @@ def _preprocessing_config(self) -> LanguageModelPreprocessingConfig:
)
def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample:
- text = sample[self._source_schema.text]
- all_spans = []
- if self._source_schema.has_loss_masking_span:
- # Spans are typically stored in the (begin, last) format. We convert to (begin, end) range format.
- loss_masking_spans = _sort_spans(
- (SpanType.loss_masking, (begin, last + 1))
- for begin, last in np.array(sample[self._source_schema.loss_masking_spans], dtype=np.int32)
- .reshape(-1, 2)
- .tolist()
+ token_spans_by_type = collections.defaultdict(list)
+ image_patches = image_token_maps = image_position_ids = patch_counts = None
+
+ if isinstance(self._source_schema, ConversationSourceConfig):
+ # Conversation format: tokenize messages and get loss masking spans from chat template
+ tokens, loss_masking_spans = self._tokenizer.tokenize_chat(
+ sample[self._source_schema.messages],
+ True,
+ True,
+ data_type=self._data_type,
)
- all_spans.extend(loss_masking_spans)
-
- if self._source_schema.has_preference_spans:
- full_chosen_text = text + sample[self._source_schema.chosen_span] + self._tokenizer.tokenizer.eos_token
- full_rejected_text = self._tokenizer.tokenizer.bos_token + text + sample[self._source_schema.rejected_span]
- # compute chosen span
- chosen_spans = [(SpanType.chosen, (len(text), len(full_chosen_text)))]
-
- # compute rejected span
- rejected_span = [
- (
- SpanType.rejected,
- (
- len(full_chosen_text) + len(self._tokenizer.tokenizer.bos_token) + len(text),
- len(full_chosen_text) + len(full_rejected_text),
- ),
+ token_spans_by_type[SpanType.loss_masking] = loss_masking_spans
+ elif isinstance(self._source_schema, DocumentSourceConfig):
+ # Document format: use the text-spans pipeline
+ text = sample[self._source_schema.text]
+ all_spans = []
+
+ if self._source_schema.has_loss_masking_span:
+ # Spans are typically stored in the (begin, last) format. We convert to (begin, end) range format.
+ loss_masking_spans = _sort_spans(
+ (SpanType.loss_masking, (begin, last + 1))
+ for begin, last in np.array(sample[self._source_schema.loss_masking_spans], dtype=np.int32)
+ .reshape(-1, 2)
+ .tolist()
)
- ]
- # pack texts
- text = full_chosen_text + full_rejected_text
- all_spans.extend(chosen_spans + rejected_span)
-
- if self._source_schema.has_images:
- # Get the images and positions, sorted by position.
- images, image_positions = (
- zip(
- *sorted(
- zip(
- sample[self._source_schema.images],
- sample[self._source_schema.image_positions],
- strict=True,
+ all_spans.extend(loss_masking_spans)
+
+ if self._source_schema.has_preference_spans:
+ full_chosen_text = text + sample[self._source_schema.chosen_span] + self._tokenizer.tokenizer.eos_token
+ full_rejected_text = (
+ self._tokenizer.tokenizer.bos_token + text + sample[self._source_schema.rejected_span]
+ )
+ # compute chosen span
+ chosen_spans = [(SpanType.chosen, (len(text), len(full_chosen_text)))]
+
+ # compute rejected span
+ rejected_span = [
+ (
+ SpanType.rejected,
+ (
+ len(full_chosen_text) + len(self._tokenizer.tokenizer.bos_token) + len(text),
+ len(full_chosen_text) + len(full_rejected_text),
),
- key=lambda x: x[1],
)
+ ]
+ # pack texts
+ text = full_chosen_text + full_rejected_text
+ all_spans.extend(chosen_spans + rejected_span)
+
+ if self._source_schema.has_images:
+ # Get the images and positions, sorted by position.
+ images, image_positions = (
+ zip(
+ *sorted(
+ zip(
+ sample[self._source_schema.images],
+ sample[self._source_schema.image_positions],
+ strict=True,
+ ),
+ key=lambda x: x[1],
+ )
+ )
+ if len(sample[self._source_schema.images]) > 0
+ else ([], [])
)
- if len(sample[self._source_schema.images]) > 0
- else ([], [])
- )
- # Get the image patches and associated data.
- image_patches, image_position_ids, image_token_maps, image_token_ids, patch_counts = (
- self._config.image_patches.get_patches_from_images(images, self._data_type)
+ # Get the image patches and associated data.
+ image_patches, image_position_ids, image_token_maps, image_token_ids, patch_counts = (
+ self._config.image_patches.get_patches_from_images(images, self._data_type)
+ )
+ patch_count_cumsum = padded_cumsum(patch_counts).tolist()
+ # Add an empty "span" at each image position so we know where to insert them in the tokenized sequence.
+ all_spans.extend([(SpanType.image, (position, position)) for position in image_positions])
+
+ # Sort the spans by location (begin), keeping track of their type.
+ # Note: overlapping spans are not supported (explicit assertion in the tokenizer).
+ span_types, spans = zip(*_sort_spans(all_spans)) if all_spans else ([], [])
+ # Tokenize the text, and determine the span locations in the tokenized text.
+ tokens, token_spans = self._tokenizer.tokenize_with_spans(
+ text, True, True, text_spans=spans, data_type=self._data_type
)
- patch_count_cumsum = padded_cumsum(patch_counts).tolist()
- # Add an empty "span" at each image position so we know where to insert them in the tokenized sequence.
- all_spans.extend([(SpanType.image, (position, position)) for position in image_positions])
-
- # Sort the spans by location (begin), keeping track of their type.
- # Note: overlapping spans are not supported (explicit assertion in the tokenizer).
- span_types, spans = zip(*_sort_spans(all_spans)) if all_spans else ([], [])
- # Tokenize the text, and determine the span locations in the tokenized text.
- tokens, token_spans = self._tokenizer.tokenize_with_spans(
- text, True, True, text_spans=spans, data_type=self._data_type
- )
- # Gather token spans by type.
- token_spans_by_type = collections.defaultdict(list)
- if self._source_schema.has_images:
- # Insert the image token ids in the token sequence and shift the spans accordingly.
- tokens_shift = 0
- image_index = 0
- for span_type, (begin, end) in zip(span_types, token_spans, strict=True):
- # Account for the tokens already inserted.
- begin = begin + tokens_shift
- end = end + tokens_shift
- if span_type == SpanType.image:
- # Shift the token map to the image location.
- image_token_maps[patch_count_cumsum[image_index] : patch_count_cumsum[image_index + 1]] += begin
- # Insert the placeholder and image break tokens.
- tokens = torch.cat([tokens[:begin], image_token_ids[image_index], tokens[begin:]])
- tokens_shift += len(image_token_ids[image_index])
- image_index += 1
- else:
- token_spans_by_type[span_type].append((begin, end))
+ # Gather token spans by type.
+ if self._source_schema.has_images:
+ # Insert the image token ids in the token sequence and shift the spans accordingly.
+ tokens_shift = 0
+ image_index = 0
+ for span_type, (begin, end) in zip(span_types, token_spans, strict=True):
+ # Account for the tokens already inserted.
+ begin = begin + tokens_shift
+ end = end + tokens_shift
+ if span_type == SpanType.image:
+ # Shift the token map to the image location.
+ image_token_maps[patch_count_cumsum[image_index] : patch_count_cumsum[image_index + 1]] += begin
+ # Insert the placeholder and image break tokens.
+ tokens = torch.cat([tokens[:begin], image_token_ids[image_index], tokens[begin:]])
+ tokens_shift += len(image_token_ids[image_index])
+ image_index += 1
+ else:
+ token_spans_by_type[span_type].append((begin, end))
+ else:
+ for span_type, token_span in zip(span_types, token_spans, strict=True):
+ token_spans_by_type[span_type].append(token_span)
else:
- for span_type, token_span in zip(span_types, token_spans, strict=True):
- token_spans_by_type[span_type].append(token_span)
+ raise NotImplementedError(f"Unsupported source schema type: {type(self._source_schema)}")
sample_size = len(tokens)
diff --git a/fast_llm/data/preprocessing/tokenizer.py b/fast_llm/data/preprocessing/tokenizer.py
index abfb5b3d2..2d27c3853 100644
--- a/fast_llm/data/preprocessing/tokenizer.py
+++ b/fast_llm/data/preprocessing/tokenizer.py
@@ -213,3 +213,106 @@ def _remove_delimiters(
@property
def eod(self):
return self.eod_id
+
+ @staticmethod
+ def _has_generation_markers(template: str | None) -> bool:
+ """Check if a template has generation markers."""
+ return template is not None and "{% generation %}" in template
+
+ def validate_chat_template(self) -> None:
+ """
+ Validate the tokenizer's chat template has generation markers.
+
+ Raises:
+ ValueError: If the tokenizer lacks a chat template or generation markers.
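+
+        A compatible template wraps the assistant's content in these markers, e.g. an
+        illustrative fragment (not any specific model's template):
+        ``{% generation %}{{ message['content'] }}{% endgeneration %}``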
+ """
+ template = self.tokenizer.chat_template
+
+ if template is None:
+ raise ValueError(
+ "Tokenizer does not have a chat template. "
+ "Conversation format requires a tokenizer with a built-in chat template "
+ "containing {% generation %}...{% endgeneration %} markers."
+ )
+
+ if not self._has_generation_markers(template):
+ raise ValueError(
+ "Tokenizer's chat template does not contain {% generation %}...{% endgeneration %} markers. "
+ "These markers are required to determine which tokens to train on. "
+ "Please use a tokenizer with generation markers in its chat template."
+ )
+
+ def tokenize_chat(
+ self,
+ messages: list[dict[str, str]],
+ begin: bool = True,
+ end: bool = True,
+ data_type: DataType = DataType.int64,
+ ) -> tuple["torch.Tensor", list[tuple[int, int]]]:
+ """
+ Apply chat template and return (tokens, loss_masking_spans).
+
+        The loss_masking_spans mark token ranges to EXCLUDE from training (positions the
+        model should not be trained on). They are derived from the chat template's generation
+        markers: tokens outside {% generation %}...{% endgeneration %} blocks are masked.
+ """
+ import torch
+
+ result = self.tokenizer.apply_chat_template(
+ messages,
+ tokenize=True,
+ return_assistant_tokens_mask=True,
+ return_dict=True,
+ add_generation_prompt=False,
+ )
+ tokens = result["input_ids"]
+ train_mask = result["assistant_masks"]
+
+ # Prepend BOS / append EOS if not already present anywhere in the sequence.
+ # We check anywhere (not just first/last) because some chat templates add trailing
+ # whitespace after the final EOS token, e.g. "<|im_end|>\n".
+ prepend_bos = begin and self.bod_id not in tokens
+ append_eos = end and self.eod_id not in tokens
+ tokens = [self.bod_id] * prepend_bos + list(tokens) + [self.eod_id] * append_eos
+ train_mask = [False] * prepend_bos + [bool(m) for m in train_mask] + [False] * append_eos
+
+ # Convert boolean train mask to loss masking spans (spans where train_mask[i] == False)
+ loss_masking_spans = _train_mask_to_loss_spans(train_mask)
+
+ if self._config.max_vocab_size is not None:
+ tokens = (
+ torch.tensor(
+ tokens,
+ dtype=torch.int64 if len(self.tokenizer) > torch.iinfo(data_type.torch).max else data_type.torch,
+ )
+ % self._config.max_vocab_size
+ ).to(data_type.torch)
+ else:
+ tokens = torch.tensor(tokens, dtype=data_type.torch)
+ return tokens, loss_masking_spans
+
+
+def _train_mask_to_loss_spans(train_mask: list[bool]) -> list[tuple[int, int]]:
+ """
+ Convert a boolean train mask to loss masking spans.
+
+ Args:
+ train_mask: Boolean list where True = train on this token, False = don't train
+
+ Returns:
+ List of (begin, end) spans marking token ranges to EXCLUDE from training
+ (i.e., where train_mask[i] == False).
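+
+    Example (quick check against the loop below):
+        >>> _train_mask_to_loss_spans([False, False, True, True, False])
+        [(0, 2), (4, 5)]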
+ """
+ spans = []
+ start = None
+ for i, should_train in enumerate(train_mask):
+ if not should_train:
+ if start is None:
+ start = i
+ elif start is not None:
+ spans.append((start, i))
+ start = None
+ if start is not None:
+ spans.append((start, len(train_mask)))
+ return spans
diff --git a/fast_llm/models/gpt/conversion/apriel2.py b/fast_llm/models/gpt/conversion/apriel2.py
index 4ed588ed5..dc2d4b4ad 100644
--- a/fast_llm/models/gpt/conversion/apriel2.py
+++ b/fast_llm/models/gpt/conversion/apriel2.py
@@ -39,8 +39,20 @@ def import_config(cls, config: dict) -> dict:
"head_groups": config["head_groups"],
"head_size": config["head_size"],
"rotary": rotary,
- "add_linear_biases": config["add_linear_biases"],
}
+ # Per-layer bias configuration mirroring Fast-LLM structure
+ # If per-layer configs exist, use them; otherwise fall back to add_linear_biases
+ if "query_layer" in config:
+ result["query_layer"] = config["query_layer"]
+ if "key_layer" in config:
+ result["key_layer"] = config["key_layer"]
+ if "value_layer" in config:
+ result["value_layer"] = config["value_layer"]
+ if "dense_layer" in config:
+ result["dense_layer"] = config["dense_layer"]
+ # add_linear_biases serves as default for layers without explicit config
+ if "add_linear_biases" in config:
+ result["add_linear_biases"] = config["add_linear_biases"]
if "window_size" in config:
result["window_size"] = config["window_size"]
return result
@@ -58,18 +70,37 @@ def export_config(cls, config: AttentionConfig) -> dict:
else:
raise NotImplementedError(f"Unsupported rotary type: {type(config.rotary).__name__}")
- return {
+ result = {
"type": "attention",
"heads": config.heads,
"head_groups": config.head_groups,
"head_size": config.head_size,
- "add_linear_biases": config.add_linear_biases,
"rotary": {
"type": rotary_type,
"theta": config.rotary.theta,
},
"window_size": config.window_size,
}
+ # Export per-layer bias configuration
+ # Only include if explicitly set (not None)
+ if config.query_layer.bias.enabled is not None:
+ result["query_layer"] = {"bias": {"enabled": config.query_layer.bias.enabled}}
+ if config.key_layer.bias.enabled is not None:
+ result["key_layer"] = {"bias": {"enabled": config.key_layer.bias.enabled}}
+ if config.value_layer.bias.enabled is not None:
+ result["value_layer"] = {"bias": {"enabled": config.value_layer.bias.enabled}}
+ if config.dense_layer.bias.enabled is not None:
+ result["dense_layer"] = {"bias": {"enabled": config.dense_layer.bias.enabled}}
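+    # The exported per-layer fragment then looks like, e.g. (illustrative values):
+    #   "query_layer": {"bias": {"enabled": True}}, "dense_layer": {"bias": {"enabled": False}}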
+ # add_linear_biases as fallback default
+ result["add_linear_biases"] = config.add_linear_biases
+ return result
+
+ @classmethod
+ def _get_effective_bias(cls, layer_config, default: bool) -> bool:
+ """Get effective bias setting: use layer-specific if set, else default."""
+ if layer_config.bias.enabled is not None:
+ return layer_config.bias.enabled
+ return default
@classmethod
def get_converters(
@@ -79,11 +110,20 @@ def get_converters(
hf_prefix: str,
drop_on_export: bool = False,
) -> list[WeightConverter]:
+ # Determine effective bias for each projection
+ q_bias = cls._get_effective_bias(config.query_layer, config.add_linear_biases)
+ k_bias = cls._get_effective_bias(config.key_layer, config.add_linear_biases)
+ v_bias = cls._get_effective_bias(config.value_layer, config.add_linear_biases)
+ o_bias = cls._get_effective_bias(config.dense_layer, config.add_linear_biases)
+        # Key and value projections are fused into Fast-LLM's key_value layer, so bias
+        # converters are only created when both k_proj and v_proj have biases enabled.
+        kv_bias = k_bias and v_bias
+
return [
*get_weight_and_bias_converters(
f"{fast_llm_prefix}.query",
f"{hf_prefix}.q_proj",
- config.add_linear_biases,
+ q_bias,
QueryWeightConverter,
config,
drop_on_export=drop_on_export,
@@ -91,7 +131,7 @@ def get_converters(
*get_weight_and_bias_converters(
f"{fast_llm_prefix}.key_value",
(f"{hf_prefix}.k_proj", f"{hf_prefix}.v_proj"),
- config.add_linear_biases,
+ kv_bias,
KeyValueWeightConverter,
config,
drop_on_export=drop_on_export,
@@ -99,7 +139,7 @@ def get_converters(
*get_weight_and_bias_converters(
f"{fast_llm_prefix}.dense",
f"{hf_prefix}.o_proj",
- config.add_linear_biases,
+ o_bias,
drop_on_export=drop_on_export,
),
]
@@ -524,6 +564,12 @@ def import_config(cls, config: dict, block_config: dict) -> dict:
"gated": mlp_config["gated"],
"add_linear_biases": mlp_config["add_linear_biases"],
}
+ # Import per-layer MLP bias settings (layer_1, layer_2)
+ for layer_name in ("layer_1", "layer_2"):
+ if layer_name in mlp_config:
+ layer_cfg = mlp_config[layer_name]
+ if "bias" in layer_cfg:
+ mlp[layer_name] = {"bias": layer_cfg["bias"]}
normalization = block_config["normalization"]
@@ -578,6 +624,11 @@ def export_config(cls, config: DecoderBlockConfig) -> dict:
"gated": config.mlp.gated,
"add_linear_biases": config.mlp.add_linear_biases,
}
+ # Export per-layer MLP bias settings (layer_1, layer_2)
+ if config.mlp.layer_1.bias.enabled is not None:
+ mlp["layer_1"] = {"bias": {"enabled": config.mlp.layer_1.bias.enabled}}
+ if config.mlp.layer_2.bias.enabled is not None:
+ mlp["layer_2"] = {"bias": {"enabled": config.mlp.layer_2.bias.enabled}}
normalization = {"type": norm_type_str, "epsilon": config.normalization.epsilon}
@@ -624,24 +675,52 @@ def get_converters(
)
)
- converters.extend(
- [
+ # Per-layer MLP bias: use layer-specific setting if set, else default
+ def get_mlp_layer_bias(layer_config, default: bool) -> bool:
+ if layer_config.bias.enabled is not None:
+ return layer_config.bias.enabled
+ return default
+
+ layer_1_bias = get_mlp_layer_bias(config.mlp.layer_1, config.mlp.add_linear_biases)
+ layer_2_bias = get_mlp_layer_bias(config.mlp.layer_2, config.mlp.add_linear_biases)
+
+ if config.mlp.gated:
+ # Gated MLP: gate_proj + up_proj -> layer_1 (split), down_proj -> layer_2
+ converters.extend([
*get_weight_and_bias_converters(
f"{fast_llm_prefix}.mlp.layer_1",
(f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"),
- config.mlp.add_linear_biases,
+ layer_1_bias,
SplitWeightConverter,
drop_on_export=drop_on_export,
),
*get_weight_and_bias_converters(
f"{fast_llm_prefix}.mlp.layer_2",
f"{hf_prefix}.mlp.down_proj",
- config.mlp.add_linear_biases,
+ layer_2_bias,
MLPLayer2Converter,
drop_on_export=drop_on_export,
),
- ]
- )
+ ])
+ else:
+ # Non-gated MLP: up_proj -> layer_1, down_proj -> layer_2
+ # Note: layer_2 still needs MLPLayer2Converter for the transpose
+ converters.extend([
+ *get_weight_and_bias_converters(
+ f"{fast_llm_prefix}.mlp.layer_1",
+ f"{hf_prefix}.mlp.up_proj",
+ layer_1_bias,
+ WeightConverter,
+ drop_on_export=drop_on_export,
+ ),
+ *get_weight_and_bias_converters(
+ f"{fast_llm_prefix}.mlp.layer_2",
+ f"{hf_prefix}.mlp.down_proj",
+ layer_2_bias,
+ MLPLayer2Converter,
+ drop_on_export=drop_on_export,
+ ),
+ ])
converters.extend(
[
diff --git a/fast_llm/models/gpt/conversion/qwen2.py b/fast_llm/models/gpt/conversion/qwen2.py
index a8bc33454..4ebf18c3a 100644
--- a/fast_llm/models/gpt/conversion/qwen2.py
+++ b/fast_llm/models/gpt/conversion/qwen2.py
@@ -1,15 +1,21 @@
import typing
from fast_llm.engine.checkpoint.config import CheckpointFormat
+from fast_llm.engine.checkpoint.external import WeightConverter
from fast_llm.layers.attention.config import AttentionConfig
+from fast_llm.layers.decoder.mlp.config import MLPConfig
from fast_llm.models.gpt.conversion.config import Qwen2CheckpointFormat
from fast_llm.models.gpt.conversion.llama import (
+ KeyValueWeightConverter,
LlamaAttentionConverter,
LlamaBaseModelConverter,
LlamaBlockConverter,
LlamaDecoderConverter,
LlamaHeadConverter,
LlamaHuggingfaceCheckpointHandler,
+ LlamaMLPConverter,
+ QueryWeightConverter,
+ get_weight_and_bias_converters,
)
from fast_llm.utils import Assert
@@ -17,6 +23,22 @@
class Qwen2AttentionConverter(LlamaAttentionConverter):
# TODO: Support sliding window with max_window_layers (need 2 kinds of block?)
+ @classmethod
+ def import_config(cls, config: dict) -> dict:
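+        # Qwen2 uses biases on the Q/K/V projections but not on o_proj; set the flag the
+        # Llama importer reads, then pin the per-layer bias settings explicitly below.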
+ config["attention_bias"] = True
+ out = super().import_config(config)
+ out["query_layer"] = {"bias": {"enabled": True}}
+ out["key_layer"] = {"bias": {"enabled": True}}
+ out["value_layer"] = {"bias": {"enabled": True}}
+ out["dense_layer"] = {"bias": {"enabled": False}}
+ return out
+
+ @classmethod
+ def export_config(cls, config: AttentionConfig) -> dict:
+ out = super().export_config(config)
+ del out["attention_bias"]
+ return out
+
@classmethod
def _check_config(cls, config: AttentionConfig) -> None:
Assert.is_(type(config), AttentionConfig)
@@ -32,9 +54,56 @@ def _check_config(cls, config: AttentionConfig) -> None:
Assert.is_(config.value_layer.bias.enabled, True)
Assert.incl(config.dense_layer.bias.enabled, (None, False))
+ @classmethod
+ def get_converters(
+ cls,
+ config: AttentionConfig,
+ fast_llm_prefix: str,
+ hf_prefix: str,
+ drop_on_export: bool = False,
+ ) -> list[WeightConverter]:
+ return [
+ *get_weight_and_bias_converters(
+ f"{fast_llm_prefix}.query",
+ f"{hf_prefix}.q_proj",
+ True,
+ QueryWeightConverter,
+ config,
+ drop_on_export=drop_on_export,
+ ),
+ *get_weight_and_bias_converters(
+ f"{fast_llm_prefix}.key_value",
+ (f"{hf_prefix}.k_proj", f"{hf_prefix}.v_proj"),
+ True,
+ KeyValueWeightConverter,
+ config,
+ drop_on_export=drop_on_export,
+ ),
+ *get_weight_and_bias_converters(
+ f"{fast_llm_prefix}.dense",
+ f"{hf_prefix}.o_proj",
+ False,
+ drop_on_export=drop_on_export,
+ ),
+ ]
+
+
+class Qwen2MLPConverter(LlamaMLPConverter):
+ @classmethod
+ def import_config(cls, config: dict) -> dict:
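+        # Qwen2 MLP projections have no biases; supply the flag the Llama importer reads.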
+ config["mlp_bias"] = False
+ return super().import_config(config)
+
+ @classmethod
+ def export_config(cls, config: MLPConfig) -> dict:
+ out = super().export_config(config)
+ del out["mlp_bias"]
+ return out
+
class Qwen2BlockConverter(LlamaBlockConverter):
mixer_converter_class: typing.ClassVar[type[Qwen2AttentionConverter]] = Qwen2AttentionConverter
+ mlp_converter_class: typing.ClassVar[type[Qwen2MLPConverter]] = Qwen2MLPConverter
class Qwen2DecoderConverter(LlamaDecoderConverter):
diff --git a/fast_llm_external_models/apriel2/cache.py b/fast_llm_external_models/apriel2/cache.py
index 86c67a085..32db547b9 100644
--- a/fast_llm_external_models/apriel2/cache.py
+++ b/fast_llm_external_models/apriel2/cache.py
@@ -4,14 +4,18 @@
class _AttentionCache:
- __slots__ = ["key", "value", "window"]
+ __slots__ = ["key", "value", "window", "cumulative_length"]
def __init__(self, window=None):
self.key = None
self.value = None
self.window = window
+ self.cumulative_length = 0
def update(self, key, value):
+ new_tokens = key.shape[-2]
+ self.cumulative_length += new_tokens
+
if self.key is None:
if self.window and key.shape[-2] > self.window:
self.key = key[..., -self.window :, :].contiguous()
@@ -35,6 +39,40 @@ def _window(self, cache, new):
return cache
return torch.cat([cache, new], -2)[..., -self.window :, :].contiguous()
+ def reset(self):
+ self.key = None
+ self.value = None
+ self.cumulative_length = 0
+
+ def reorder(self, beam_idx):
+ if self.key is not None:
+ self.key = self.key.index_select(0, beam_idx.to(self.key.device))
+ self.value = self.value.index_select(0, beam_idx.to(self.value.device))
+
+ def crop(self, max_length):
+ if self.key is not None:
+ self.key = self.key[..., :max_length, :]
+ self.value = self.value[..., :max_length, :]
+ self.cumulative_length = self.key.shape[-2]
+
+ def batch_repeat(self, repeats):
+ if self.key is not None:
+ self.key = self.key.repeat_interleave(repeats, dim=0)
+ self.value = self.value.repeat_interleave(repeats, dim=0)
+
+ def batch_select(self, indices):
+ if self.key is not None:
+ self.key = self.key.index_select(0, indices.to(self.key.device))
+ self.value = self.value.index_select(0, indices.to(self.value.device))
+
+ @property
+ def is_initialized(self):
+ return self.key is not None
+
+ @property
+ def batch_size(self):
+ return self.key.shape[0] if self.key is not None else None
+
class _SSMCache:
__slots__ = ["conv", "recurrent"]
@@ -43,6 +81,52 @@ def __init__(self):
self.conv = None
self.recurrent = None
+ def reset(self):
+ self.conv = None
+ self.recurrent = None
+
+ def reorder(self, beam_idx):
+ if self.conv is not None:
+ if isinstance(self.conv, tuple):
+ self.conv = tuple(c.index_select(0, beam_idx.to(c.device)) for c in self.conv)
+ else:
+ self.conv = self.conv.index_select(0, beam_idx.to(self.conv.device))
+ if self.recurrent is not None:
+ self.recurrent = self.recurrent.index_select(0, beam_idx.to(self.recurrent.device))
+
+ def crop(self, max_length):
+        pass  # SSM caches have no sequence dimension to crop
+
+ def batch_repeat(self, repeats):
+ if self.conv is not None:
+ if isinstance(self.conv, tuple):
+ self.conv = tuple(c.repeat_interleave(repeats, dim=0) for c in self.conv)
+ else:
+ self.conv = self.conv.repeat_interleave(repeats, dim=0)
+ if self.recurrent is not None:
+ self.recurrent = self.recurrent.repeat_interleave(repeats, dim=0)
+
+ def batch_select(self, indices):
+ if self.conv is not None:
+ if isinstance(self.conv, tuple):
+ self.conv = tuple(c.index_select(0, indices.to(c.device)) for c in self.conv)
+ else:
+ self.conv = self.conv.index_select(0, indices.to(self.conv.device))
+ if self.recurrent is not None:
+ self.recurrent = self.recurrent.index_select(0, indices.to(self.recurrent.device))
+
+ @property
+ def is_initialized(self):
+ return self.conv is not None
+
+ @property
+ def batch_size(self):
+ if self.conv is None:
+ return None
+ if isinstance(self.conv, tuple):
+ return self.conv[0].shape[0]
+ return self.conv.shape[0]
+
class _DummyCacheLayer:
pass
@@ -93,14 +177,19 @@ def set_active_mixer(self, layer_idx, mixer_name):
self.active_mixers[layer_idx] = mixer_name
def get_seq_length(self, layer_idx=0):
+ """Returns the cumulative sequence length of tokens seen by the cache.
+
+ For sliding window caches, this returns the total tokens seen (not just cached).
+ This matches HuggingFace's DynamicSlidingWindowLayer behavior.
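+
+        For example (illustrative numbers), a window-4 cache that has processed 10 tokens
+        reports 10 here even though only the last 4 key/value entries are retained.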
+ """
layer = self.layers[layer_idx]
if isinstance(layer, dict):
mixer = self.active_mixers[layer_idx]
if mixer and isinstance(layer[mixer], _AttentionCache):
- return layer[mixer].key.shape[-2] if layer[mixer].key is not None else 0
+ return layer[mixer].cumulative_length
return 0
if isinstance(layer, _AttentionCache):
- return layer.key.shape[-2] if layer.key is not None else 0
+ return layer.cumulative_length
return 0
def get_max_cache_shape(self, layer_idx=0):
@@ -114,22 +203,61 @@ def get_max_cache_shape(self, layer_idx=0):
return None
def get_mask_sizes(self, cache_position, layer_idx):
+ """Return the length and offset of the cache, used to generate the attention mask.
+
+ For standard (non-sliding) attention:
+ kv_offset = 0 (KV[0] corresponds to sequence position 0)
+ kv_length = cumulative_length + query_length
+
+ For sliding window attention:
+ kv_offset = max(cumulative_length - window + 1, 0)
+ kv_length = min(cumulative_length, window - 1) + query_length
+
+ For SSM/linear layers:
+ kv_offset = 0, kv_length = query_length (no KV cache to attend to)
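+
+        Worked example (illustrative numbers): for a sliding-window layer with window=4,
+        cumulative_length=10 and query_length=1, kv_offset = max(10 - 4 + 1, 0) = 7 and
+        kv_length = (4 - 1) + 1 = 4.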
+ """
query_length = cache_position.shape[0]
- past_seen_tokens = self.get_seq_length(layer_idx)
- kv_length = query_length + past_seen_tokens
- kv_offset = past_seen_tokens
- return kv_length, kv_offset
+ layer = self.layers[layer_idx]
+
+ # Handle stochastic layers by getting the active mixer's cache
+ if isinstance(layer, dict):
+ mixer = self.active_mixers[layer_idx]
+ if mixer is None:
+ # No active mixer set, return defaults
+ return query_length, 0
+ cache = layer[mixer]
+ else:
+ cache = layer
+
+ # SSM layers don't have KV cache for attention mask purposes
+ if isinstance(cache, _SSMCache):
+ return query_length, 0
+
+ # Attention cache - check if sliding window
+ if isinstance(cache, _AttentionCache):
+ cumulative = cache.cumulative_length
+ window = cache.window
+
+ if window is not None:
+ # Sliding window attention
+ kv_offset = max(cumulative - window + 1, 0)
+ if cumulative >= window:
+ kv_length = window - 1 + query_length
+ else:
+ kv_length = cumulative + query_length
+ else:
+ # Full attention
+ kv_offset = 0
+ kv_length = cumulative + query_length
+
+ return kv_length, kv_offset
+
+ # Fallback
+ return query_length, 0
@property
def has_previous_state(self):
- for layer in self.layers:
- if isinstance(layer, dict):
- for cache in layer.values():
- if isinstance(cache, _SSMCache) and cache.conv is not None:
- return True
- elif isinstance(layer, _SSMCache) and layer.conv is not None:
- return True
- return False
+ return any(isinstance(cache, _SSMCache) and cache.conv is not None for cache in self._iter_caches())
@property
def key_cache(self):
@@ -147,101 +275,33 @@ def conv_states(self):
def recurrent_states(self):
return _LayerListAccessor(self, "recurrent")
- def reorder_cache(self, beam_idx):
- for i, layer in enumerate(self.layers):
+ def _iter_caches(self):
+ """Iterate over all leaf cache objects (flattening stochastic layer dicts)."""
+ for layer in self.layers:
if isinstance(layer, dict):
- for cache in layer.values():
- self._reorder_cache_obj(cache, beam_idx)
+ yield from layer.values()
else:
- self._reorder_cache_obj(layer, beam_idx)
+ yield layer
- def _reorder_cache_obj(self, cache, beam_idx):
- if isinstance(cache, _AttentionCache):
- if cache.key is not None:
- cache.key = cache.key.index_select(0, beam_idx.to(cache.key.device))
- cache.value = cache.value.index_select(0, beam_idx.to(cache.value.device))
- elif isinstance(cache, _SSMCache):
- if cache.conv is not None:
- # Handle both single tensor (GDN/Mamba) and tuple (KDA) conv states
- if isinstance(cache.conv, tuple):
- cache.conv = tuple(c.index_select(0, beam_idx.to(c.device)) for c in cache.conv)
- else:
- cache.conv = cache.conv.index_select(0, beam_idx.to(cache.conv.device))
- if cache.recurrent is not None:
- cache.recurrent = cache.recurrent.index_select(0, beam_idx.to(cache.recurrent.device))
+ def reorder_cache(self, beam_idx):
+ for cache in self._iter_caches():
+ cache.reorder(beam_idx)
def reset(self):
- for layer in self.layers:
- if isinstance(layer, dict):
- for cache in layer.values():
- self._reset_cache_obj(cache)
- else:
- self._reset_cache_obj(layer)
-
- def _reset_cache_obj(self, cache):
- if isinstance(cache, _AttentionCache):
- cache.key = None
- cache.value = None
- elif isinstance(cache, _SSMCache):
- cache.conv = None
- cache.recurrent = None
+ for cache in self._iter_caches():
+ cache.reset()
def crop(self, max_length):
- for layer in self.layers:
- if isinstance(layer, dict):
- for cache in layer.values():
- if isinstance(cache, _AttentionCache) and cache.key is not None:
- cache.key = cache.key[..., :max_length, :]
- cache.value = cache.value[..., :max_length, :]
- elif isinstance(layer, _AttentionCache) and layer.key is not None:
- layer.key = layer.key[..., :max_length, :]
- layer.value = layer.value[..., :max_length, :]
+ for cache in self._iter_caches():
+ cache.crop(max_length)
def batch_repeat_interleave(self, repeats):
- for layer in self.layers:
- if isinstance(layer, dict):
- for cache in layer.values():
- self._batch_repeat_cache_obj(cache, repeats)
- else:
- self._batch_repeat_cache_obj(layer, repeats)
-
- def _batch_repeat_cache_obj(self, cache, repeats):
- if isinstance(cache, _AttentionCache):
- if cache.key is not None:
- cache.key = cache.key.repeat_interleave(repeats, dim=0)
- cache.value = cache.value.repeat_interleave(repeats, dim=0)
- elif isinstance(cache, _SSMCache):
- if cache.conv is not None:
- # Handle both single tensor (GDN/Mamba) and tuple (KDA) conv states
- if isinstance(cache.conv, tuple):
- cache.conv = tuple(c.repeat_interleave(repeats, dim=0) for c in cache.conv)
- else:
- cache.conv = cache.conv.repeat_interleave(repeats, dim=0)
- if cache.recurrent is not None:
- cache.recurrent = cache.recurrent.repeat_interleave(repeats, dim=0)
+ for cache in self._iter_caches():
+ cache.batch_repeat(repeats)
def batch_select_indices(self, indices):
- for layer in self.layers:
- if isinstance(layer, dict):
- for cache in layer.values():
- self._batch_select_cache_obj(cache, indices)
- else:
- self._batch_select_cache_obj(layer, indices)
-
- def _batch_select_cache_obj(self, cache, indices):
- if isinstance(cache, _AttentionCache):
- if cache.key is not None:
- cache.key = cache.key.index_select(0, indices.to(cache.key.device))
- cache.value = cache.value.index_select(0, indices.to(cache.value.device))
- elif isinstance(cache, _SSMCache):
- if cache.conv is not None:
- # Handle both single tensor (GDN/Mamba) and tuple (KDA) conv states
- if isinstance(cache.conv, tuple):
- cache.conv = tuple(c.index_select(0, indices.to(c.device)) for c in cache.conv)
- else:
- cache.conv = cache.conv.index_select(0, indices.to(cache.conv.device))
- if cache.recurrent is not None:
- cache.recurrent = cache.recurrent.index_select(0, indices.to(cache.recurrent.device))
+ for cache in self._iter_caches():
+ cache.batch_select(indices)
@property
def is_compileable(self):
@@ -249,19 +309,7 @@ def is_compileable(self):
@property
def is_initialized(self):
- for layer in self.layers:
- if isinstance(layer, dict):
- for cache in layer.values():
- if isinstance(cache, _AttentionCache) and cache.key is not None:
- return True
- if isinstance(cache, _SSMCache) and cache.conv is not None:
- return True
- else:
- if isinstance(layer, _AttentionCache) and layer.key is not None:
- return True
- if isinstance(layer, _SSMCache) and layer.conv is not None:
- return True
- return False
+ return any(cache.is_initialized for cache in self._iter_caches())
@property
def is_sliding(self):
@@ -280,39 +328,20 @@ def is_sliding(self):
@property
def max_batch_size(self):
- for layer in self.layers:
- if isinstance(layer, dict):
- for cache in layer.values():
- if isinstance(cache, _AttentionCache) and cache.key is not None:
- return cache.key.shape[0]
- if isinstance(cache, _SSMCache) and cache.conv is not None:
- # Handle both single tensor and tuple conv states
- if isinstance(cache.conv, tuple):
- return cache.conv[0].shape[0]
- return cache.conv.shape[0]
- else:
- if isinstance(layer, _AttentionCache) and layer.key is not None:
- return layer.key.shape[0]
- if isinstance(layer, _SSMCache) and layer.conv is not None:
- # Handle both single tensor and tuple conv states
- if isinstance(layer.conv, tuple):
- return layer.conv[0].shape[0]
- return layer.conv.shape[0]
+ for cache in self._iter_caches():
+ bs = cache.batch_size
+ if bs is not None:
+ return bs
return None
@property
def max_cache_len(self):
- max_len = None
- for layer in self.layers:
- if isinstance(layer, dict):
- for cache in layer.values():
- if isinstance(cache, _AttentionCache):
- if cache.window is not None:
- max_len = cache.window if max_len is None else min(max_len, cache.window)
- elif isinstance(layer, _AttentionCache):
- if layer.window is not None:
- max_len = layer.window if max_len is None else min(max_len, layer.window)
- return max_len
+ windows = [
+ cache.window
+ for cache in self._iter_caches()
+ if isinstance(cache, _AttentionCache) and cache.window is not None
+ ]
+ return min(windows) if windows else None
def __len__(self):
return len(self.layers)
diff --git a/fast_llm_external_models/apriel2/conversion/__init__.py b/fast_llm_external_models/apriel2/conversion/__init__.py
index 983a632e0..c6bad6626 100644
--- a/fast_llm_external_models/apriel2/conversion/__init__.py
+++ b/fast_llm_external_models/apriel2/conversion/__init__.py
@@ -1,86 +1,122 @@
"""Weight conversion system for Apriel2 models.
-Architecture Overview
-=====================
+Overview
+========
-This package implements a declarative weight transformation system with two
-orthogonal concerns:
+This package implements a declarative weight transformation system. The core
+abstraction separates config composition (structural) from plan execution (weights).
-1. **Config Composition** - Structural transformations of model configs
-2. **Plan Building & Execution** - Weight transformations between configs
+Conceptual Types
+================
-These concerns are intentionally separated:
-- Config composition determines WHAT the target architecture looks like
-- Plan building determines HOW weights are transformed to match
-- The `init` field bridges them: it's config metadata consumed by the plan builder
+All configs are ``dict``, but we distinguish three conceptual types:
-Key Design Decisions
-====================
+**State (S)** - A complete model config without ``init`` fields.
+ What you load from disk or save after conversion.
-**Declarative Plans**
- Plans are DATA (JSON-serializable expressions), not functions. This enables:
- - Inspection and debugging of transformations
- - Serialization for distributed execution
- - Composition via substitution rather than function composition
-
-**Separation of Config and Weights**
- The `init` field in surgery specs controls weight handling (transfer vs random)
- but does NOT affect config composition. Config composition is purely structural.
- After composition, `init` fields are stripped from complete configs.
-
-**Composition Semantics**
- Surgery specs use declarative (merge) composition, not operational (function)
- composition. For "additive" surgeries (modifying existing structure), the
- monoid action law holds. For "replacement" surgeries (defining complete new
- structure), sequential application differs from composed application by design.
-
-**Cross-Type Derivation**
- When converting between mixer types (e.g., attention → mamba), geometric
- parameters are derived where possible:
- - attention.heads → mamba dimensions (MIL conversion)
- - attention.heads → gdn heads (DIL conversion)
+**Partial Surgery (P)** - An incomplete config specifying changes.
+ May contain ``init`` fields (``transfer`` or ``random``).
-Module Structure
-================
+**Transition Spec (T)** - A complete config WITH ``init`` fields.
+ The result of applying surgery to a state. Describes both target
+ structure and weight initialization mode.
+
+Algebraic Structure
+===================
+
+**Monoid**: Partial surgeries compose via deep merge::
+
+ compose_configs : P × P → P
+
+**Action**: Surgeries act on states to produce transition specs::
+
+ compose_configs : S × P → T
+ compose_configs : T × P → T
+
+**Extraction**: Strip init to get a state::
+
+ strip_init_fields : T → S
+
+**Planning**: Build weight transformation from source state + transition spec::
+
+ plan_surgery : S × T → Plan
+
+The ``init`` Field
+==================
+
+The ``init`` field in surgeries specifies weight initialization:
-- `config.py` - Config composition (compose_configs, apply_surgery)
-- `converters.py` - Plan builders (plan_surgery, plan_mil_attention_to_mamba, etc.)
-- `expr.py` - Expression types and plan class (Ref, Slice, Concat, Init, ExprPlan)
-- `executor.py` - Plan execution (StreamingExecutor, execute)
-- `io.py` - Streaming I/O (SafetensorLoader, ShardedSafetensorWriter)
-- `llava/` - Source-specific converter for Llava → Apriel2
+- ``init: transfer`` → transfer/convert weights from source
+- ``init: random`` → randomly initialize weights
-Example Usage
+This field is preserved through ``compose_configs`` so ``plan_surgery`` can read it.
+Use ``strip_init_fields`` before saving configs to disk.
+
+Typical Usage
=============
+::
+
from fast_llm_external_models.apriel2.conversion import (
compose_configs,
plan_surgery,
+ strip_init_fields,
execute,
)
- # 1. Compose configs to get target architecture
- target_config = compose_configs(source_config, surgery_spec)
+ # Load source state
+ source_state = load_config(...) # S
- # 2. Build plan for weight transformation
- plan = plan_surgery(source_config, surgery_spec)
+ # Apply surgery
+ surgery = {"decoder": {"block": {"mixer": {"type": "gdn", "init": "random"}}}} # P
+ transition = compose_configs(source_state, surgery) # T
- # 3. Execute plan to transform weights
- target_weights = execute(plan, source_weights, seed=42)
+ # Build and execute plan
+ plan = plan_surgery(source_state, transition)
+ weights = execute(plan, source_weights, seed=42)
-For streaming I/O with large models:
+ # Save (strip init first)
+ target_state = strip_init_fields(transition) # S
+ save_config(target_state)
- from fast_llm_external_models.apriel2.conversion import (
- StreamingExecutor,
- SafetensorLoader,
- ShardedSafetensorWriter,
- )
+For chained surgeries::
+
+ current_state = source_state # S
+ current_plan = identity_plan
+
+ for surgery in surgery_chain: # each P
+ transition = compose_configs(current_state, surgery) # T
+ plan = plan_surgery(current_state, transition)
+ current_plan = compose(current_plan, plan)
+ current_state = strip_init_fields(transition) # S <- IMPORTANT!
+
+**Note**: The ``strip_init_fields`` call is critical. It ensures that ``init: random``
+applies only to the surgery that introduces a component. Without stripping, subsequent
+surgeries would re-randomize existing components. See ``config.py`` docstring for details.
+
+Key Design Decisions
+====================
+
+**Declarative Plans**
+    Plans are data (expressions), not functions. This enables inspection,
+    serialization, and composition via substitution.
+
+**Inheritance Semantics**
+    When applying a surgery (S × P → T), unspecified fields inherit from the source state.
+    Cross-type derivation maps geometry (e.g., attention.heads → gdn.value_heads).
+
+**Additive vs Replacement Surgeries**
+ Additive surgeries (no ``type:`` declaration) satisfy the action law.
+ Replacement surgeries (explicit ``type:``) use last-write-wins.
+
+Module Structure
+================
- with SafetensorLoader(source_files) as loader:
- executor = StreamingExecutor(plan, loader)
- with ShardedSafetensorWriter(output_dir) as writer:
- for key, tensor in executor.execute(seed=42):
- writer.add(key, tensor)
+- ``config.py`` - Config composition (compose_configs, strip_init_fields)
+- ``converters.py`` - Plan builders (plan_surgery, plan_mil_attention_to_mamba)
+- ``expr.py`` - Expression types (Ref, Slice, Concat, Init, ExprPlan)
+- ``executor.py`` - Plan execution (StreamingExecutor, execute)
+- ``io.py`` - Streaming I/O (SafetensorLoader, ShardedSafetensorWriter)
"""
# Core types and plan operations
@@ -127,7 +163,7 @@
)
# Config composition
-from fast_llm_external_models.apriel2.conversion.config import compose_configs
+from fast_llm_external_models.apriel2.conversion.config import compose_configs, strip_init_fields
# Source-specific converters
from fast_llm_external_models.apriel2.conversion.llava import (
@@ -175,6 +211,7 @@
"plan_kil_attention_to_kda",
# Config composition
"compose_configs",
+ "strip_init_fields",
# Source-specific converters
"convert_llava_config",
"plan_llava_to_apriel2",
diff --git a/fast_llm_external_models/apriel2/conversion/config.py b/fast_llm_external_models/apriel2/conversion/config.py
index 48f8ff44b..3752688c1 100644
--- a/fast_llm_external_models/apriel2/conversion/config.py
+++ b/fast_llm_external_models/apriel2/conversion/config.py
@@ -1,56 +1,136 @@
"""Config composition for Apriel2 architecture transformations.
-This module handles STRUCTURAL composition of configs, independent of weight handling.
-The `init` field in surgery specs is preserved as metadata for the plan builder but
-does not affect how configs are composed.
+Conceptual Types
+================
-Composition Cases
-=================
+The system operates on three conceptual types, all represented as ``dict``:
-compose_configs(base, overlay) handles four cases based on completeness:
+**State (S)**
+ A complete structural description of a model. Has ``hidden_size`` and ``decoder``.
+ Does NOT contain ``init`` fields. Represents WHAT a model looks like.
-1. **Complete + Partial** → Apply surgery semantics (inheritance, cross-type derivation)
-2. **Partial + Partial** → Deep merge (monoid operation on surgery specs)
-3. **Partial + Complete** → Overlay wins (complete config replaces partial)
-4. **Complete + Complete** → Deep merge, then strip `init` fields
+ Example: A saved config.json, or a model you're about to transform.
-A config is "complete" if it has `hidden_size` and `decoder` (i.e., it's a full model
-config, not a surgery spec).
+**Partial Surgery (P)**
+ An incomplete config specifying fields to change. Missing ``hidden_size`` or
+ ``decoder``. May contain ``init`` fields specifying weight initialization mode.
-Surgery Semantics
-=================
+ Example: ``{"decoder": {"block": {"mixer": {"type": "gdn", "init": "random"}}}}``
-When applying a surgery spec to a complete config:
+**Transition Spec (T)**
+ A complete config WITH ``init`` fields. Describes both the target structure
+ AND how to initialize weights. This is the output of applying a surgery to
+ a state - it's a complete specification of the transformation.
-**Inheritance**
- Unspecified parameters inherit from the source config. New blocks inherit
- from the "default" block (first block in pattern, or the single fixed block).
+ Example: The result of ``compose_configs(state, surgery)`` before stripping.
-**Cross-Type Derivation**
- When changing mixer types, geometric parameters are derived where possible:
- - attention → sliding_window: preserve heads, head_groups, head_size
- - attention → gdn: heads → value_heads, head_groups → key_heads
- - attention → mamba: derive d_inner, d_xb, dt_rank from hidden_size
- - attention → kda: preserve heads, head_size → head_dim
+The distinction between S and T is semantic (presence of ``init``), not structural.
+Both are "complete" in the sense of having ``hidden_size`` and ``decoder``.
-**Stochastic Mixer Composition**
- Two semantics based on whether surgery declares `type: stochastic`:
- - Replacement: surgery declares type → only surgery's sub-mixers included
- - Additive: surgery omits type → source sub-mixers preserved, surgery adds/modifies
+Algebraic Structure
+===================
- This distinction means the monoid action law holds for additive surgeries but
- intentionally fails for replacement surgeries (they have "last-write-wins" semantics).
+**Partial Surgeries form a Monoid (P, ∘, {})**::
-The `init` Field
-================
+ compose_configs : P × P → P (deep merge, overlay wins)
+
+ Identity: compose_configs(p, {}) = compose_configs({}, p) = p
+ Associativity: compose_configs(compose_configs(a, b), c)
+ = compose_configs(a, compose_configs(b, c))
+
+**Surgeries act on States to produce Transition Specs**::
+
+ compose_configs : S × P → T (apply surgery with inheritance)
+ compose_configs : T × P → T (extend transition with more surgery)
+
+**Action Law (for additive surgeries)**::
+
+ compose_configs(compose_configs(s, p₁), p₂) = compose_configs(s, compose_configs(p₁, p₂))
+
+This law holds when surgeries are "additive" (modifying existing structure without
+declaring new types). For "replacement" surgeries (explicitly declaring ``type:``),
+the action law intentionally fails - this is last-write-wins semantics.
+
+**State Extraction**::
+
+ strip_init_fields : T → S (remove init metadata for saving)
+
+Operations Summary
+==================
+
+``compose_configs(base, overlay)`` dispatches based on completeness:
+
+1. **S × P → T** : Apply surgery to state (inheritance, cross-type derivation)
+2. **T × P → T** : Extend transition spec with more surgery
+3. **P × P → P** : Merge partial surgeries (monoid operation)
+4. **S × S → S** : Merge states (deep merge, rare)
+5. **P × S → S** : Overlay wins (complete replaces partial)
+
+``strip_init_fields(config)`` removes all ``init`` fields, converting T → S.
+
+Inheritance Semantics
+=====================
+
+When applying a surgery (S × P → T):
+
+- Unspecified fields inherit from source state
+- New decoder blocks inherit from the "default" block
+- Cross-type derivation maps geometry (attention.heads → gdn.value_heads, etc.)
+- Stochastic mixers: additive surgery preserves source mixers, replacement replaces
+
+The ``init`` Field
+==================
+
+The ``init`` field specifies weight initialization mode for ``plan_surgery()``:
+
+- ``init: transfer`` → transfer weights from source (possibly with conversion)
+- ``init: random`` → randomly initialize weights
+
+**Key invariant**: ``init`` is preserved through composition so ``plan_surgery()``
+can read it. Use ``strip_init_fields()`` to obtain a pure state for:
+
+- Saving to disk (config.json should not contain ``init``)
+- Starting the next surgery iteration (current_state should be S, not T)
+
+Typical Usage Pattern
+=====================
+
+::
+
+    current_state = load_config(...)                              # S
+
+    for surgery in surgery_chain:                                 # each P
+        transition = compose_configs(current_state, surgery)      # S × P → T
+        plan = plan_surgery(current_state, transition)            # plan reads init from T
+        current_state = strip_init_fields(transition)             # T → S for next iteration
+
+    save_config(current_state)                                    # S, no init fields
+
+Sequential vs Merged Surgery Application
+========================================
+
+**IMPORTANT**: Applying surgeries sequentially (with stripping) differs from merging
+surgeries first then applying once. This affects ``init`` semantics:
+
+**Sequential** (recommended)::
-The `init` field is metadata for the plan builder, NOT for config composition:
-- `init: transfer` → plan builder creates weight transfer mappings
-- `init: random` → plan builder creates random initialization
+ t1 = compose_configs(s, p1) # GDN gets init: random
+ s1 = strip_init_fields(t1) # GDN loses init
+ t2 = compose_configs(s1, p2) # GDN has init: None → transfer mode
-After surgery is applied to produce a complete config, ALL `init` fields are stripped.
-This ensures configs are purely structural and plan creation is Markovian (depends only
-on current config + surgery, not on history).
+**Merged**::
+
+ merged = compose_configs(p1, p2) # GDN keeps init: random from p1
+ t = compose_configs(s, merged) # GDN has init: random → random mode
+
+The sequential approach means ``init: random`` applies **only to the surgery that
+introduces a component**. Subsequent surgeries transfer existing weights by default.
+
+This is the intended behavior: if surgery 1 adds GDN with random init, and surgery 2
+adds sliding window (not mentioning GDN), GDN keeps its weights from surgery 1.
+
+The merged approach would re-randomize GDN in every execution, which is rarely desired.
+Always use the sequential pattern shown in "Typical Usage Pattern" above.
"""
from __future__ import annotations
@@ -65,14 +145,42 @@ def is_complete(config: dict) -> bool:
def compose_configs(base: dict, overlay: dict | None) -> dict:
- """Compose two configs.
+ """Compose configs. Dispatches based on completeness of arguments.
+
+ Type Signatures (see module docstring for S, P, T definitions)::
+
+ S × P → T Apply surgery to state, get transition spec
+ T × P → T Extend transition spec with more surgery
+ P × P → P Merge partial surgeries (monoid operation)
+ S × S → S Merge states (deep merge)
+ P × S → S Overlay wins
+
+ The ``init`` field is preserved in all cases. Use ``strip_init_fields()``
+    to convert T → S for saving or iteration.
+
Args:
- base: Base config (complete or partial surgery spec).
- overlay: Overlay config (complete or partial surgery spec).
+ base: State (S), transition spec (T), or partial surgery (P).
+ overlay: Partial surgery (P) or state (S).
Returns:
- Composed config.
+ Composed config. Type depends on inputs (see signatures above).
+
+ Algebraic Properties:
+        Associativity (monoid law): ``compose(compose(p1, p2), p3) == compose(p1, compose(p2, p3))``
+
+ Action law (additive surgeries):
+ ``compose(compose(s, p1), p2) == compose(s, compose(p1, p2))``
+
+ Example::
+
+ # S × P → T (apply surgery to state)
+ state = {"hidden_size": 256, "decoder": {...}}
+ surgery = {"decoder": {"block": {"mixer": {"init": "random"}}}}
+ transition = compose_configs(state, surgery) # T, has init
+
+ # Build plan, then extract state
+ plan = plan_surgery(state, transition)
+ new_state = strip_init_fields(transition) # S, no init
"""
if not overlay:
return copy.deepcopy(base)
@@ -94,9 +202,8 @@ def compose_configs(base: dict, overlay: dict | None) -> dict:
if not base_complete and overlay_complete:
return copy.deepcopy(overlay)
- # Case 4: Both complete -> deep merge
+ # Case 4: Both complete -> deep merge (init preserved for plan_surgery)
result = _deep_merge(base, overlay)
- _strip_keys(result, {"init"})
return result
@@ -128,26 +235,53 @@ def _strip_keys(config: Any, keys_to_strip: set[str]) -> None:
_strip_keys(item, keys_to_strip)
+def strip_init_fields(config: dict) -> dict:
+ """Return a copy of config with all ``init`` fields stripped (T → S).
+
+ Converts a transition spec (T) to a state (S) by removing ``init`` metadata.
+ Use this:
+
+ 1. Before saving configs to disk (config.json should be purely structural)
+ 2. Between surgery iterations (so subsequent surgeries don't re-randomize)
+
+ See module docstring section "Sequential vs Merged Surgery Application" for
+ why stripping between iterations is critical.
+
+ Args:
+ config: Config dict (not modified). Typically a transition spec (T).
+
+ Returns:
+ A deep copy with all ``init`` fields recursively removed (a state S).
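+
+    Example (a minimal illustrative config)::
+
+        t = {"decoder": {"block": {"mixer": {"type": "gdn", "init": "random"}}}}
+        s = strip_init_fields(t)
+        # s == {"decoder": {"block": {"mixer": {"type": "gdn"}}}}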
+ """
+ result = copy.deepcopy(config)
+ _strip_keys(result, {"init"})
+ return result
+
+
# =============================================================================
# Surgery application with full semantics
# =============================================================================
def apply_surgery(source_config: dict, surgery_config: dict | None) -> dict:
- """Apply surgery specification to a complete source config.
+ """Apply surgery spec to complete config (the monoid action).
- This handles:
- - Top-level scalar overrides
- - Decoder composition (fixed vs pattern)
- - Stochastic mixer sub-mixer inheritance
- - Cross-type derivation (attention → gdn, attention → mamba)
+ This is the internal implementation of the monoid action: surgery specs
+ acting on complete configs. Called by compose_configs when base is complete
+ and overlay is partial.
+
+ Implements inheritance semantics:
+ - Unspecified fields inherit from source
+ - Cross-type derivation maps geometry (attention → gdn, etc.)
+ - Stochastic sub-mixers inherit from source's main mixer
+ - `init` fields are PRESERVED for plan_surgery() to see
Args:
- source_config: Complete Apriel2 config.
- surgery_config: Partial surgery specification.
+ source_config: Complete Apriel2 config (the state being acted on).
+ surgery_config: Partial surgery spec (the monoid element acting).
Returns:
- Complete Apriel2 config with surgery applied.
+ Complete config with surgery applied. `init` fields preserved.
"""
if not surgery_config:
return copy.deepcopy(source_config)
@@ -189,8 +323,9 @@ def apply_surgery(source_config: dict, surgery_config: dict | None) -> dict:
surgery_config["vision_encoder"],
)
- # Strip init keys from final result
- _strip_keys(result, {"init"})
+ # NOTE: We do NOT strip init keys here. The `init` field is preserved through
+ # composition so that plan_surgery() can see it and decide between transfer
+ # vs random initialization. The caller (convert.py) strips init before saving.
return result
@@ -392,6 +527,12 @@ def _compose_single_mixer(source: dict, surgery: dict, hidden_size: int) -> dict
result[key] = surgery[key]
elif key in source:
result[key] = source[key]
+    # Copy per-layer bias settings (query_layer, key_layer, value_layer, dense_layer) and add_linear_biases
+ for key in ["query_layer", "key_layer", "value_layer", "dense_layer", "add_linear_biases"]:
+ if key in surgery:
+ result[key] = surgery[key]
+ elif key in source:
+ result[key] = copy.deepcopy(source[key])
# Preserve init
if "init" in surgery:
result["init"] = surgery["init"]
diff --git a/fast_llm_external_models/apriel2/conversion/converters.py b/fast_llm_external_models/apriel2/conversion/converters.py
index 6d1350c54..c8b83f657 100644
--- a/fast_llm_external_models/apriel2/conversion/converters.py
+++ b/fast_llm_external_models/apriel2/conversion/converters.py
@@ -79,6 +79,21 @@
# This is the single source of truth for each mixer's weight schema.
+def _get_attention_bias_enabled(config: dict, layer_name: str) -> bool:
+ """Get whether bias is enabled for an attention layer.
+
+ Checks per-layer bias config (e.g., query_layer.bias.enabled).
+ Falls back to add_linear_biases if not set.
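+
+    Example (illustrative config values)::
+
+        cfg = {"query_layer": {"bias": {"enabled": True}}, "add_linear_biases": False}
+        _get_attention_bias_enabled(cfg, "query_layer")  # True (per-layer setting)
+        _get_attention_bias_enabled(cfg, "dense_layer")  # False (falls back to add_linear_biases)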
+ """
+ layer_cfg = config.get(layer_name, {})
+ bias_cfg = layer_cfg.get("bias", {})
+ enabled = bias_cfg.get("enabled")
+ if enabled is not None:
+ return enabled
+ # Fall back to add_linear_biases
+ return config.get("add_linear_biases", False)
+
+
def _plan_attention_mixer(
*,
prefix: W,
@@ -90,9 +105,13 @@ def _plan_attention_mixer(
Weight schema:
- q_proj.weight: (q_size, hidden_size)
+ - q_proj.bias: (q_size,) [if query_layer.bias.enabled]
- k_proj.weight: (kv_size, hidden_size)
+ - k_proj.bias: (kv_size,) [if key_layer.bias.enabled]
- v_proj.weight: (kv_size, hidden_size)
+ - v_proj.bias: (kv_size,) [if value_layer.bias.enabled]
- o_proj.weight: (hidden_size, q_size)
+ - o_proj.bias: (hidden_size,) [if dense_layer.bias.enabled]
Args:
prefix: Target weight path prefix.
@@ -100,12 +119,28 @@ def _plan_attention_mixer(
hidden_size: Model hidden size.
source_prefix: If provided, passthrough from source. If None, random init.
"""
+ # Check per-layer bias configuration
+ q_bias = _get_attention_bias_enabled(config, "query_layer")
+ k_bias = _get_attention_bias_enabled(config, "key_layer")
+ v_bias = _get_attention_bias_enabled(config, "value_layer")
+ o_bias = _get_attention_bias_enabled(config, "dense_layer")
+
if source_prefix is not None:
- # Passthrough
- return ExprPlan(mappings={
+ # Passthrough weights
+ mappings: dict[W, Expr] = {
prefix / proj / "weight": Ref(key=source_prefix / proj / "weight")
for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]
- })
+ }
+ # Passthrough biases if enabled
+ if q_bias:
+ mappings[prefix / "q_proj" / "bias"] = Ref(key=source_prefix / "q_proj" / "bias")
+ if k_bias:
+ mappings[prefix / "k_proj" / "bias"] = Ref(key=source_prefix / "k_proj" / "bias")
+ if v_bias:
+ mappings[prefix / "v_proj" / "bias"] = Ref(key=source_prefix / "v_proj" / "bias")
+ if o_bias:
+ mappings[prefix / "o_proj" / "bias"] = Ref(key=source_prefix / "o_proj" / "bias")
+ return ExprPlan(mappings=mappings)
# Random init
heads = config["heads"]
@@ -114,12 +149,22 @@ def _plan_attention_mixer(
q_size = heads * head_size
kv_size = head_groups * head_size
- return ExprPlan(mappings={
+ mappings = {
prefix / "q_proj" / "weight": Init(shape=(q_size, hidden_size), init_type="kaiming"),
prefix / "k_proj" / "weight": Init(shape=(kv_size, hidden_size), init_type="kaiming"),
prefix / "v_proj" / "weight": Init(shape=(kv_size, hidden_size), init_type="kaiming"),
prefix / "o_proj" / "weight": Init(shape=(hidden_size, q_size), init_type="kaiming"),
- })
+ }
+ # Random init biases if enabled
+ if q_bias:
+ mappings[prefix / "q_proj" / "bias"] = Init(shape=(q_size,), init_type="zeros")
+ if k_bias:
+ mappings[prefix / "k_proj" / "bias"] = Init(shape=(kv_size,), init_type="zeros")
+ if v_bias:
+ mappings[prefix / "v_proj" / "bias"] = Init(shape=(kv_size,), init_type="zeros")
+ if o_bias:
+ mappings[prefix / "o_proj" / "bias"] = Init(shape=(hidden_size,), init_type="zeros")
+ return ExprPlan(mappings=mappings)
def _plan_mamba_mixer(
@@ -786,7 +831,70 @@ def plan_surgery(
source_config: dict,
target_config: dict,
) -> ExprPlan:
- """Build plan for Apriel2→Apriel2 surgery (MIL, DIL, KIL, stochastic mixers, etc.)."""
+ """Build a weight conversion plan: S × T → Plan.
+
+ Creates an ExprPlan mapping target weight keys to expressions over source weights.
+ Handles same-type passthrough, cross-type conversions (MIL, DIL, KIL), and
+ stochastic mixer routing.
+
+ Type Signature::
+
+ plan_surgery : S × T → Plan
+
+ Where S is a state (source) and T is a transition spec (target with ``init`` fields).
+
+ The ``init`` Field
+ ------------------
+
+ The ``init`` field in ``target_config`` controls weight initialization:
+
+ - ``init: transfer`` (or absent) → create Ref expressions (transfer from source)
+ - ``init: random`` → create Init expressions (random initialization)
+
+ This is why ``target_config`` should be a transition spec (T) from ``compose_configs``,
+ not a stripped state (S). If ``init`` fields are missing, all components default to
+ transfer mode.
+
+ Args:
+ source_config: State (S) - complete config describing source architecture.
+ Must have hidden_size, decoder, etc. No ``init`` fields expected.
+ target_config: Transition spec (T) - complete config with ``init`` fields.
+ Use ``compose_configs(source, surgery)`` to produce this.
+
+ Returns:
+ ExprPlan mapping target weight keys to expressions over source weights.
+
+ Example::
+
+ # Apply a surgery that wraps attention in a stochastic mixer
+ surgery_spec = {
+ "decoder": {"block": {"mixer": {
+ "type": "stochastic",
+ "mixers": {
+ "attention": {"init": "transfer"},
+ "gdn": {"type": "gdn", "init": "random"},
+ }
+ }}}
+ }
+
+ # S × P → T
+ transition = compose_configs(source_config, surgery_spec)
+
+ # S × T → Plan
+ plan = plan_surgery(source_config, transition)
+
+ # Execute
+ new_weights = execute(plan, source_weights, seed=42)
+
+ # T → S for saving
+ target_state = strip_init_fields(transition)
+
+ Note:
+ Both arguments must be complete (have hidden_size and decoder).
+ The target_config should retain ``init`` fields from the surgery spec.
+ Passing a stripped state as target will cause all components to use
+ transfer mode, which may not be intended.
+ """
hidden_size = target_config.get("hidden_size", source_config.get("hidden_size"))
assert hidden_size is not None, "hidden_size must be specified in source or target config"
@@ -839,14 +947,16 @@ def _plan_non_decoder_weights(config: dict) -> ExprPlan:
embed = W("model", "embed_tokens", "weight")
mappings[embed] = Ref(key=embed)
- head = W("lm_head", "weight")
- mappings[head] = Ref(key=head)
+ # lm_head only if not tied to embeddings
+ if not config.get("tie_word_embeddings", False):
+ head = W("lm_head", "weight")
+ mappings[head] = Ref(key=head)
norm = W("model", "norm", "weight")
mappings[norm] = Ref(key=norm)
- if "vision_encoder" in config:
- vision_config = config["vision_encoder"]
+ vision_config = config.get("vision_encoder")
+ if vision_config:
vision = W("model", "vision_encoder")
patch_emb = vision / "embeddings" / "patch_embeddings" / "weight"
@@ -986,6 +1096,24 @@ def _plan_mixer(
)
+def _get_mlp_bias_enabled(config: dict, layer_name: str) -> bool:
+ """Get whether bias is enabled for an MLP layer.
+
+ Checks per-layer bias config (e.g., layer_1.bias.enabled, layer_2.bias.enabled).
+ Falls back to add_linear_biases if not set.
+
+ Note: layer_1 corresponds to gate_proj and up_proj (gated MLP) or just up_proj (non-gated)
+ layer_2 corresponds to down_proj
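+
+    Example (illustrative config values)::
+
+        mlp = {"layer_1": {"bias": {"enabled": True}}}
+        _get_mlp_bias_enabled(mlp, "layer_1")  # True (per-layer setting)
+        _get_mlp_bias_enabled(mlp, "layer_2")  # False (falls back to add_linear_biases default)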
+ """
+ layer_cfg = config.get(layer_name, {})
+ bias_cfg = layer_cfg.get("bias", {})
+ enabled = bias_cfg.get("enabled")
+ if enabled is not None:
+ return enabled
+ # Fall back to add_linear_biases
+ return config.get("add_linear_biases", False)
+
+
def _plan_mlp(
target_layer_idx: int,
source_layer_idx: int,
@@ -1006,7 +1134,7 @@ def _plan_mlp_transfer(
target_mlp: dict,
hidden_size: int,
) -> ExprPlan:
- """Passthrough for MLP weights."""
+ """Passthrough for MLP weights and biases."""
source_mlp_path = W("model", "decoder", "blocks", source_layer_idx, "mlp")
target_mlp_path = W("model", "decoder", "blocks", target_layer_idx, "mlp")
@@ -1019,10 +1147,37 @@ def _plan_mlp_transfer(
f"Use 'init: random' to initialize randomly."
)
- return ExprPlan(mappings={
+ # Check per-layer bias configuration
+ layer_1_bias = _get_mlp_bias_enabled(target_mlp, "layer_1")
+ layer_2_bias = _get_mlp_bias_enabled(target_mlp, "layer_2")
+
+ # Check if gated MLP (has gate_proj) or non-gated (only up_proj)
+ gated = target_mlp.get("gated", True) # Default to gated for backwards compatibility
+
+ # Passthrough weights
+ # layer_1 = gate_proj + up_proj (gated) or just up_proj (non-gated)
+ # layer_2 = down_proj
+ if gated:
+ weight_projs = ["gate_proj", "up_proj", "down_proj"]
+ else:
+ weight_projs = ["up_proj", "down_proj"]
+
+ mappings: dict[W, Expr] = {
target_mlp_path / proj / "weight": Ref(key=source_mlp_path / proj / "weight")
- for proj in ["gate_proj", "up_proj", "down_proj"]
- })
+ for proj in weight_projs
+ }
+
+ # Passthrough biases if enabled
+ if layer_1_bias:
+ if gated:
+ mappings[target_mlp_path / "gate_proj" / "bias"] = Ref(key=source_mlp_path / "gate_proj" / "bias")
+ mappings[target_mlp_path / "up_proj" / "bias"] = Ref(key=source_mlp_path / "up_proj" / "bias")
+
+ # layer_2 = down_proj
+ if layer_2_bias:
+ mappings[target_mlp_path / "down_proj" / "bias"] = Ref(key=source_mlp_path / "down_proj" / "bias")
+
+ return ExprPlan(mappings=mappings)
def _plan_random_mlp(
@@ -1030,20 +1185,41 @@ def _plan_random_mlp(
target_mlp: dict,
hidden_size: int,
) -> ExprPlan:
- """Random initialization for MLP."""
+ """Random initialization for MLP weights and biases."""
target_mlp_path = W("model", "decoder", "blocks", target_layer_idx, "mlp")
intermediate_size = target_mlp["intermediate_size"]
- return ExprPlan(mappings={
- target_mlp_path / "gate_proj" / "weight": Init(
- shape=(intermediate_size, hidden_size), init_type="kaiming"
- ),
- target_mlp_path / "up_proj" / "weight": Init(
+
+ # Check per-layer bias configuration
+ layer_1_bias = _get_mlp_bias_enabled(target_mlp, "layer_1")
+ layer_2_bias = _get_mlp_bias_enabled(target_mlp, "layer_2")
+
+ # Check if gated MLP (has gate_proj) or non-gated (only up_proj)
+ gated = target_mlp.get("gated", True) # Default to gated for backwards compatibility
+
+ # Random init weights
+ mappings: dict[W, Expr] = {}
+ if gated:
+ mappings[target_mlp_path / "gate_proj" / "weight"] = Init(
shape=(intermediate_size, hidden_size), init_type="kaiming"
- ),
- target_mlp_path / "down_proj" / "weight": Init(
- shape=(hidden_size, intermediate_size), init_type="kaiming"
- ),
- })
+ )
+ mappings[target_mlp_path / "up_proj" / "weight"] = Init(
+ shape=(intermediate_size, hidden_size), init_type="kaiming"
+ )
+ mappings[target_mlp_path / "down_proj" / "weight"] = Init(
+ shape=(hidden_size, intermediate_size), init_type="kaiming"
+ )
+
+ # Random init biases if enabled
+ if layer_1_bias:
+ if gated:
+ mappings[target_mlp_path / "gate_proj" / "bias"] = Init(shape=(intermediate_size,), init_type="zeros")
+ mappings[target_mlp_path / "up_proj" / "bias"] = Init(shape=(intermediate_size,), init_type="zeros")
+
+ # layer_2 = down_proj
+ if layer_2_bias:
+ mappings[target_mlp_path / "down_proj" / "bias"] = Init(shape=(hidden_size,), init_type="zeros")
+
+ return ExprPlan(mappings=mappings)
def _plan_norms(
diff --git a/fast_llm_external_models/apriel2/conversion/qwen2/__init__.py b/fast_llm_external_models/apriel2/conversion/qwen2/__init__.py
new file mode 100644
index 000000000..d0a0b8e6e
--- /dev/null
+++ b/fast_llm_external_models/apriel2/conversion/qwen2/__init__.py
@@ -0,0 +1,6 @@
+"""Qwen2/Qwen2.5 to Apriel2 conversion module."""
+
+from fast_llm_external_models.apriel2.conversion.qwen2.config import convert_config
+from fast_llm_external_models.apriel2.conversion.qwen2.plan import plan_qwen2_to_apriel2
+
+__all__ = ["convert_config", "plan_qwen2_to_apriel2"]
diff --git a/fast_llm_external_models/apriel2/conversion/qwen2/config.py b/fast_llm_external_models/apriel2/conversion/qwen2/config.py
new file mode 100644
index 000000000..70629fe0e
--- /dev/null
+++ b/fast_llm_external_models/apriel2/conversion/qwen2/config.py
@@ -0,0 +1,79 @@
+"""Qwen2/Qwen2.5 to Apriel2 config conversion."""
+
+
+def convert_config(qwen2_config: dict) -> dict:
+ """Convert Qwen2/Qwen2.5 config to Apriel2TextConfig format.
+
+ Qwen2.5 architecture:
+ - Standard transformer with GQA (grouped query attention)
+ - QKV bias enabled, O bias disabled
+ - MLP bias disabled
+ - Gated SwiGLU MLP
+ - RMSNorm
+ - RoPE embeddings
+
+ Args:
+ qwen2_config: HuggingFace Qwen2Config as dict
+
+ Returns:
+ Apriel2TextConfig-compatible dict
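+
+    Example (sketch; the values below are illustrative, not taken from a real checkpoint)::
+
+        qwen2_cfg = {
+            "hidden_size": 896, "num_attention_heads": 14, "num_key_value_heads": 2,
+            "vocab_size": 151936, "num_hidden_layers": 24, "intermediate_size": 4864,
+        }
+        apriel2_cfg = convert_config(qwen2_cfg)
+        # apriel2_cfg["decoder"]["block"]["mixer"]["head_groups"] == 2
+        # QKV bias enabled, O bias disabled via the per-layer bias settings.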
+ """
+ hidden_size = qwen2_config["hidden_size"]
+ num_attention_heads = qwen2_config["num_attention_heads"]
+ num_key_value_heads = qwen2_config.get("num_key_value_heads", num_attention_heads)
+ head_dim = hidden_size // num_attention_heads
+
+ # Qwen2 uses QKV bias but not O bias - mirror Fast-LLM's per-layer config
+ return {
+ "model_type": "apriel2_text",
+ "architectures": ["Apriel2ForCausalLM"],
+ "auto_map": {
+ "AutoConfig": "configuration_apriel2.Apriel2TextConfig",
+ "AutoModel": "modeling_apriel2.Apriel2TextModel",
+ "AutoModelForCausalLM": "modeling_apriel2.Apriel2ForCausalLM",
+ },
+ "hidden_size": hidden_size,
+ "vocab_size": qwen2_config["vocab_size"],
+ "tie_word_embeddings": qwen2_config.get("tie_word_embeddings", False),
+ "decoder": {
+ "type": "fixed",
+ "num_blocks": qwen2_config["num_hidden_layers"],
+ "block": {
+ "mixer": {
+ "type": "attention",
+ "heads": num_attention_heads,
+ "head_groups": num_key_value_heads,
+ "head_size": head_dim,
+ # Per-layer bias config matching Fast-LLM structure
+ "query_layer": {"bias": {"enabled": True}},
+ "key_layer": {"bias": {"enabled": True}},
+ "value_layer": {"bias": {"enabled": True}},
+ "dense_layer": {"bias": {"enabled": False}},
+ "rotary": {
+ "type": "mistral_1d",
+ "theta": qwen2_config.get("rope_theta", 1000000.0),
+ },
+ },
+ "mlp": {
+ "type": "mlp",
+ "intermediate_size": qwen2_config["intermediate_size"],
+ "activation": qwen2_config.get("hidden_act", "silu"),
+ "gated": True,
+ "add_linear_biases": False,
+ },
+ "normalization": {
+ "type": "rms_norm",
+ "epsilon": qwen2_config.get("rms_norm_eps", 1e-6),
+ },
+ },
+ },
+ "head": {
+ "normalization": {
+ "type": "rms_norm",
+ "epsilon": qwen2_config.get("rms_norm_eps", 1e-6),
+ }
+ },
+ "embeddings": {
+ "max_position_embeddings": qwen2_config.get("max_position_embeddings", 32768),
+ },
+ }
diff --git a/fast_llm_external_models/apriel2/conversion/qwen2/plan.py b/fast_llm_external_models/apriel2/conversion/qwen2/plan.py
new file mode 100644
index 000000000..7752d37c9
--- /dev/null
+++ b/fast_llm_external_models/apriel2/conversion/qwen2/plan.py
@@ -0,0 +1,109 @@
+"""Qwen2/Qwen2.5 to Apriel2 weight conversion plan."""
+
+from fast_llm_external_models.apriel2.conversion.expr import (
+ Expr,
+ ExprPlan,
+ Ref,
+ W,
+)
+
+
+def plan_qwen2_to_apriel2(qwen2_config: dict) -> ExprPlan:
+ """Build an expression plan for Qwen2/Qwen2.5 to Apriel2 conversion.
+
+ This is a pure mapping (all Ref expressions) since Qwen2→Apriel2
+ is just renaming keys. The weight tensors are identical.
+
+ Key mapping (source keys have "model." prefix in safetensors):
+ Qwen2 (safetensor key) Apriel2
+ ---------------------- -------
+ model.embed_tokens.weight -> model.embed_tokens.weight
+ model.norm.weight -> model.norm.weight
+ model.layers.{i}.input_layernorm.weight -> model.decoder.blocks.{i}.input_layernorm.weight
+ model.layers.{i}.post_attention_layernorm.weight -> model.decoder.blocks.{i}.post_attention_layernorm.weight
+ model.layers.{i}.self_attn.q_proj.weight -> model.decoder.blocks.{i}.mixer.q_proj.weight
+ model.layers.{i}.self_attn.q_proj.bias -> model.decoder.blocks.{i}.mixer.q_proj.bias
+ model.layers.{i}.self_attn.k_proj.weight -> model.decoder.blocks.{i}.mixer.k_proj.weight
+ model.layers.{i}.self_attn.k_proj.bias -> model.decoder.blocks.{i}.mixer.k_proj.bias
+ model.layers.{i}.self_attn.v_proj.weight -> model.decoder.blocks.{i}.mixer.v_proj.weight
+ model.layers.{i}.self_attn.v_proj.bias -> model.decoder.blocks.{i}.mixer.v_proj.bias
+ model.layers.{i}.self_attn.o_proj.weight -> model.decoder.blocks.{i}.mixer.o_proj.weight
+ model.layers.{i}.mlp.gate_proj.weight -> model.decoder.blocks.{i}.mlp.gate_proj.weight
+ model.layers.{i}.mlp.up_proj.weight -> model.decoder.blocks.{i}.mlp.up_proj.weight
+ model.layers.{i}.mlp.down_proj.weight -> model.decoder.blocks.{i}.mlp.down_proj.weight
+
+ Note: Qwen2 has QKV biases but no O bias. The Apriel2 config uses per-layer
+ bias settings (query_layer.bias.enabled=True, dense_layer.bias.enabled=False)
+ to match this exactly - no workarounds needed.
+
+ Args:
+ qwen2_config: HuggingFace Qwen2Config as dict
+
+ Returns:
+ ExprPlan with Ref mappings
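+
+    Example (sketch)::
+
+        plan = plan_qwen2_to_apriel2(qwen2_config)
+        # e.g. the target key model.decoder.blocks.0.mixer.q_proj.bias maps to
+        # Ref(key=W("model", "layers", 0, "self_attn", "q_proj", "bias"))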
+ """
+    mappings: dict[W, Expr] = {}
+
+ num_layers = qwen2_config["num_hidden_layers"]
+
+ # Static mappings (embeddings and final norm)
+ # Note: Qwen2 safetensor keys have "model." prefix
+ static_mappings = [
+ (W("model", "embed_tokens", "weight"), W("model", "embed_tokens", "weight")),
+ (W("model", "norm", "weight"), W("model", "norm", "weight")),
+ ]
+
+ # lm_head - only if not tied
+ if not qwen2_config.get("tie_word_embeddings", False):
+ static_mappings.append(
+ (W("lm_head", "weight"), W("lm_head", "weight"))
+ )
+
+ for src, tgt in static_mappings:
+ mappings[tgt] = Ref(key=src)
+
+ # Layer mappings
+ for layer in range(num_layers):
+ # Source has "model.layers.{i}" prefix
+ qwen_layer = W("model", "layers", layer)
+ apriel_layer = W("model", "decoder", "blocks", layer)
+
+ # Attention projection weights
+ for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]:
+ src = qwen_layer / "self_attn" / proj / "weight"
+ tgt = apriel_layer / "mixer" / proj / "weight"
+ mappings[tgt] = Ref(key=src)
+
+ # QKV biases (Qwen2 has these, but not O bias)
+ for proj in ["q_proj", "k_proj", "v_proj"]:
+ src = qwen_layer / "self_attn" / proj / "bias"
+ tgt = apriel_layer / "mixer" / proj / "bias"
+ mappings[tgt] = Ref(key=src)
+
+ # Note: o_proj has no bias in Qwen2, and Apriel2 config has dense_layer.bias.enabled=False
+
+ # MLP projections
+ for proj in ["gate_proj", "up_proj", "down_proj"]:
+ src = qwen_layer / "mlp" / proj / "weight"
+ tgt = apriel_layer / "mlp" / proj / "weight"
+ mappings[tgt] = Ref(key=src)
+
+ # Layer norms
+ mappings[apriel_layer / "input_layernorm" / "weight"] = Ref(
+ key=qwen_layer / "input_layernorm" / "weight"
+ )
+ mappings[apriel_layer / "post_attention_layernorm" / "weight"] = Ref(
+ key=qwen_layer / "post_attention_layernorm" / "weight"
+ )
+
+ return ExprPlan(
+ mappings=mappings,
+ source_format="qwen2",
+ target_format="apriel2",
+ metadata={
+ "num_layers": num_layers,
+ "hidden_size": qwen2_config["hidden_size"],
+ "num_attention_heads": qwen2_config["num_attention_heads"],
+ "num_key_value_heads": qwen2_config.get("num_key_value_heads", qwen2_config["num_attention_heads"]),
+ },
+ )
diff --git a/fast_llm_external_models/apriel2/convert.py b/fast_llm_external_models/apriel2/convert.py
index cbf921b31..60786d22c 100644
--- a/fast_llm_external_models/apriel2/convert.py
+++ b/fast_llm_external_models/apriel2/convert.py
@@ -15,6 +15,7 @@
Supported source formats:
- llava: Llava/Pixtral models
+- qwen2: Qwen2/Qwen2.5 models
- apriel2: Apriel2 models (surgery-only mode - no conversion, just apply surgeries)
"""
@@ -42,10 +43,12 @@
compose,
compose_configs,
plan_surgery,
+ strip_init_fields,
)
# Import source-specific converters
from fast_llm_external_models.apriel2.conversion import llava as llava_converter
+from fast_llm_external_models.apriel2.conversion import qwen2 as qwen2_converter
logger = logging.getLogger(__name__)
@@ -73,6 +76,7 @@ def _identity_plan(config: dict) -> ExprPlan:
# Each entry maps format name to (config_converter, plan_builder)
SOURCE_FORMATS: dict[str, tuple[Callable[[dict], dict], Callable[[dict], ExprPlan]]] = {
"llava": (llava_converter.convert_config, llava_converter.plan_llava_to_apriel2),
+ "qwen2": (qwen2_converter.convert_config, qwen2_converter.plan_qwen2_to_apriel2),
"apriel2": (_identity_config, _identity_plan),
}
@@ -88,8 +92,12 @@ def detect_source_format(config: dict) -> str | None:
if model_type in ("llava", "pixtral") or "text_config" in config:
return "llava"
+ # Qwen2/Qwen2.5 detection
+ if model_type == "qwen2":
+ return "qwen2"
+
# Apriel2 detection - check for Apriel2-specific structure
- if model_type == "apriel2" or "decoder" in config:
+ if model_type in ("apriel2", "apriel2_text") or "decoder" in config:
return "apriel2"
return None
@@ -142,15 +150,19 @@ def build_plan(
# Apply surgery chain if requested
if surgery_configs:
for i, surgery_config in enumerate(surgery_configs, 1):
- surgery_plan = plan_surgery(current_config, surgery_config)
+ # S × P → T: compose state with surgery to get transition spec
+ target_config = compose_configs(current_config, surgery_config)
+
+ # S × T → Plan: build plan from source state and transition spec
+ surgery_plan = plan_surgery(current_config, target_config)
logger.info(f"Built surgery plan [{i}/{len(surgery_configs)}]: {surgery_plan.summary()['num_targets']} targets")
- # Compose: current -> surgery
+ # Compose plans
current_plan = compose(current_plan, surgery_plan)
logger.info(f"Composed plan [{i}/{len(surgery_configs)}]: {current_plan.summary()['num_targets']} targets")
- # Compose configs: merge surgery spec into current config
- current_config = compose_configs(current_config, surgery_config)
+ # T → S: strip init for next iteration (init is consumed by plan_surgery)
+ current_config = strip_init_fields(target_config)
return current_plan, current_config
@@ -400,11 +412,11 @@ def main():
show_plan=args.show_plan or args.verbose,
)
- # Save config
+ # Save config (build_plan returns S which has no init, but strip defensively)
output_config_file = args.output_dir / "config.json"
logger.info(f"Saving config to {output_config_file}")
with open(output_config_file, "w") as f:
- json.dump(apriel2_config, f, indent=2)
+ json.dump(strip_init_fields(apriel2_config), f, indent=2)
# Copy tokenizer files
copy_tokenizer_files(input_dir, args.output_dir)
diff --git a/fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml b/fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml
new file mode 100644
index 000000000..34672916c
--- /dev/null
+++ b/fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml
@@ -0,0 +1,103 @@
+# Dataset preparation config for Tulu 3 SFT mixture with Qwen2 tokenizer
+#
+# This config converts the Tulu 3 SFT dataset (conversation format) to
+# Fast-LLM's memmap format, with automatic loss masking span computation
+# to train only on assistant responses.
+#
+# =============================================================================
+# TOKENIZER SETUP (one-time)
+# =============================================================================
+#
+# The tokenizer must have a chat template with {% generation %} markers.
+# Qwen2's default template doesn't have these, so we need to patch it.
+#
+# IMPORTANT: The entire assistant turn (opening tag + content + closing tag)
+# must be inside the {% generation %} block. This ensures the model learns to
+# produce the full assistant response including special tokens like <|im_end|>.
+# Reference: https://huggingface.co/HuggingFaceTB/SmolLM3-3B/blob/main/chat_template.jinja
+#
+# Run this Python script to create a patched tokenizer:
+#
+# from transformers import AutoTokenizer
+#
+# tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
+#
+# # Patch chat template: wrap ENTIRE assistant turn in generation markers
+# tokenizer.chat_template = '''{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system
+# You are a helpful assistant.<|im_end|>
+# ' }}{% endif %}{% if message['role'] == 'assistant' %}{% generation %}{{ '<|im_start|>assistant
+# ' + message['content'] + '<|im_end|>
+# ' }}{% endgeneration %}{% else %}{{ '<|im_start|>' + message['role'] + '
+# ' + message['content'] + '<|im_end|>
+# ' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant
+# ' }}{% endif %}'''
+#
+# tokenizer.save_pretrained("/path/to/qwen2-instruct-with-markers")
+#
+# =============================================================================
+# DATA PREPARATION
+# =============================================================================
+#
+# Small dataset (for testing):
+#
+# fast-llm prepare gpt_memmap \
+# -c fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml \
+# dataset.split=train[:1000] \
+# output_path=/path/to/tulu3-prepared-small
+#
+# Full dataset (~939K samples, ~6 minutes):
+#
+# fast-llm prepare gpt_memmap \
+# -c fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml
+#
+# =============================================================================
+# VERIFICATION
+# =============================================================================
+#
+# To verify the prepared dataset has loss masking spans:
+#
+# import pathlib
+# from fast_llm.data.dataset.memmap import MemmapDataset
+# from fast_llm.data.sample.language_model import LanguageModelSample
+# from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig
+#
+# dataset = MemmapDataset[LanguageModelSample](
+# 'tulu3',
+# pathlib.Path('/path/to/tulu3-prepared/shard_0_0.fast_llm_dataset'),
+# LanguageModelPreprocessingConfig(use_loss_masking_spans=True)
+# )
+#
+# doc = dataset.get_document(0)
+# print(f'Tokens: {len(doc.tokens.tokens)}')
+# print(f'Loss masking spans: {doc.loss_masking_spans.ranges}')
+#
+# =============================================================================
+
+# Dataset configuration
+dataset:
+ # Tulu 3 SFT mixture from AllenAI
+ path: allenai/tulu-3-sft-mixture
+ split: train
+
+ # Source schema for conversation format
+ source_schema:
+ # Use conversation type (vs default "document" type)
+ type: conversation
+
+ # Column containing the messages list
+ messages: messages
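+    # Each entry in that column is expected to be a list of chat turns, e.g. (illustrative):
+    #   [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]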
+
+# Tokenizer configuration
+# IMPORTANT: Must use a tokenizer with {% generation %} markers in its chat template.
+# See instructions above to create a patched Qwen2 tokenizer.
+tokenizer:
+ path: /path/to/qwen2-instruct-with-markers
+  # Qwen2 doesn't have a BOS token by default; use <|endoftext|> as BOS
+ bos_token: "<|endoftext|>"
+
+# Output configuration
+output_path: /path/to/tulu3-prepared
+
+# Processing configuration
+num_workers: 8
+documents_per_shard: 100000
diff --git a/fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml b/fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml
new file mode 100644
index 000000000..5b190955f
--- /dev/null
+++ b/fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml
@@ -0,0 +1,193 @@
+# Training config for Qwen2-based Apriel2 stochastic supernet on Tulu 3 SFT data
+#
+# This config trains a stochastic supernet where each layer can sample from
+# multiple mixer types (attention, sliding window, gated delta net, KDA).
+# Only the mixer weights are trained; all other weights are frozen.
+# Activation-level distillation from a teacher model guides the training.
+#
+# =============================================================================
+# PREREQUISITES
+# =============================================================================
+#
+# 1. TOKENIZER SETUP
+#
+# Qwen2's default chat template doesn't have generation markers needed for
+# loss masking. Create a patched tokenizer following the SmolLM3 pattern:
+# https://huggingface.co/HuggingFaceTB/SmolLM3-3B/blob/main/chat_template.jinja
+#
+# IMPORTANT: The ENTIRE assistant turn (opening tag + content + closing tag)
+# must be inside {% generation %}...{% endgeneration %} markers.
+#
+# from transformers import AutoTokenizer
+# tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
+# # Wrap entire assistant turn in generation markers (NOT just content!)
+# tokenizer.chat_template = '''{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system
+# You are a helpful assistant.<|im_end|>
+# ' }}{% endif %}{% if message['role'] == 'assistant' %}{% generation %}{{ '<|im_start|>assistant
+# ' + message['content'] + '<|im_end|>
+# ' }}{% endgeneration %}{% else %}{{ '<|im_start|>' + message['role'] + '
+# ' + message['content'] + '<|im_end|>
+# ' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant
+# ' }}{% endif %}'''
+# tokenizer.save_pretrained("/path/to/qwen2-instruct-with-markers")
+#
+# 2. PREPARE TULU 3 DATASET
+#
+# Small dataset (for testing):
+#
+# fast-llm prepare gpt_memmap \
+# -c fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml \
+# tokenizer.path=/path/to/qwen2-instruct-with-markers \
+# dataset.split=train[:1000] \
+# output_path=/path/to/tulu3-prepared-small
+#
+# Full dataset (~939K samples, ~6 minutes):
+#
+# fast-llm prepare gpt_memmap \
+# -c fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml \
+# tokenizer.path=/path/to/qwen2-instruct-with-markers \
+# output_path=/path/to/tulu3-prepared
+#
+# 3. CONVERT QWEN2 TO APRIEL2 SUPERNET (student model)
+#
+# This creates a stochastic supernet with multiple mixer types per layer:
+#
+# python fast_llm_external_models/apriel2/convert.py \
+# Qwen/Qwen2.5-0.5B-Instruct \
+# /path/to/qwen2-supernet \
+# --surgery fast_llm_external_models/apriel2/examples/stochastic_supernet.yaml
+#
+# 4. CONVERT QWEN2 TO APRIEL2 (teacher model)
+#
+# The teacher is the original model without surgery, used for distillation:
+#
+# python fast_llm_external_models/apriel2/convert.py \
+# Qwen/Qwen2.5-0.5B-Instruct \
+# /path/to/qwen2-teacher
+#
+# 5. RUN TRAINING
+#
+# Update paths below and run:
+#
+# fast-llm train gpt \
+# -c fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml
+#
+# For long runs, use nohup:
+#
+# nohup fast-llm train gpt \
+# -c fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml \
+# > training.log 2>&1 &
+# tail -f training.log
+#
+# =============================================================================
+# PERFORMANCE TUNING
+# =============================================================================
+#
+# The default config uses seq=4096, micro_batch=2, batch=16, which gives:
+# - ~8k tokens/s/gpu throughput
+# - ~61GB GPU memory usage
+# - ~25 hours for 1B tokens on a single GPU
+#
+# Adjust batch settings based on your GPU memory:
+# - Reduce micro_batch_size if OOM
+# - Increase micro_batch_size/batch_size if memory is available
+#
+# =============================================================================
+# OUTPUT
+# =============================================================================
+#
+# Checkpoints: /path/to/qwen2-supernet-trained/checkpoints/{iteration}/
+# Exports: /path/to/qwen2-supernet-trained/export/apriel2_text/{iteration}/
+#
+# =============================================================================
+
+# Load pretrained model (Qwen2 converted to Apriel2 supernet)
+pretrained:
+ path: /path/to/qwen2-supernet
+ format: apriel2_text
+ model_weights: true
+ load_config: model
+
+# Model config
+model:
+ base_model:
+ # Freeze all components except the mixer
+ decoder:
+ block:
+ mlp:
+ lr_scale: 0.0 # Freeze MLP
+ normalization:
+ lr_scale: 0.0 # Freeze layer norms
+ # Activation-level distillation from teacher
+ distillation_model: teacher
+ activation_distillation_factor: 0.8
+ embeddings:
+ lr_scale: 0.0 # Freeze word embeddings
+ head:
+ lr_scale: 0.0 # Freeze output head
+ cross_entropy_implementation: torch
+ multi_stage:
+ zero_stage: 2
+ distributed:
+ compute_dtype: bf16
+ seed: 42
+
+# Teacher model for activation-level distillation
+reference_models:
+ teacher:
+ model:
+ type: gpt
+ pretrained:
+ path: /path/to/qwen2-teacher
+ format: apriel2_text
+ model_weights: true
+ load_config: model
+
+# Batch configuration (tuned for ~61GB GPU memory, ~8k tokens/s)
+batch:
+ sequence_length: 4096
+ micro_batch_size: 2
+ batch_size: 16
+
+# Data configuration (prepared Tulu 3 dataset)
+data:
+ datasets:
+ training:
+ type: file
+ path: /path/to/tulu3-prepared/fast_llm_config.yaml
+
+# Optimizer configuration
+optimizer:
+ learning_rate:
+ base: 1.0e-05
+ decay_style: cosine
+ warmup_iterations: 100
+ decay_iterations: 10000
+ minimum: 1.0e-06
+ weight_decay: 0.1
+ beta_1: 0.9
+ beta_2: 0.95
+
+# Training configuration
+# At seq=4096, batch=16: ~65k tokens/iter, ~280 iters/hour
+# 10000 iters ≈ 650M tokens ≈ 35 hours
+training:
+ train_iters: 10000
+ num_workers: 4
+ logs:
+ interval: 10
+ checkpoint:
+ interval: 280 # ~hourly
+ export:
+ interval: 280 # ~hourly (useful for development/testing during training)
+ format: apriel2_text
+ test_iters: 0
+ evaluators: {}
+ # Weights & Biases configuration (optional, uncomment to enable)
+ # wandb:
+ # entity_name: your-entity
+ # project_name: your-project
+
+# Experiment directory
+run:
+ experiment_dir: /path/to/qwen2-supernet-trained
diff --git a/fast_llm_external_models/apriel2/examples/train_supernet_small.yaml b/fast_llm_external_models/apriel2/examples/train_supernet_small.yaml
index 78c22e57f..be4d06e0a 100644
--- a/fast_llm_external_models/apriel2/examples/train_supernet_small.yaml
+++ b/fast_llm_external_models/apriel2/examples/train_supernet_small.yaml
@@ -107,7 +107,7 @@ model:
lr_scale: 0.0 # Freeze layer norms (norm_1 and norm_2 in each block)
# Activation-level distillation: teach mixers to mimic teacher's attention outputs
distillation_model: teacher
- activation_distillation_factor: 0.1
+ activation_distillation_factor: 0.8
embeddings:
lr_scale: 0.0 # Freeze word embeddings
head:
diff --git a/fast_llm_external_models/apriel2/modeling_apriel2.py b/fast_llm_external_models/apriel2/modeling_apriel2.py
index 4c263b4e2..878677653 100644
--- a/fast_llm_external_models/apriel2/modeling_apriel2.py
+++ b/fast_llm_external_models/apriel2/modeling_apriel2.py
@@ -24,8 +24,8 @@
is_torch_flex_attn_available,
)
-from fast_llm_external_models.apriel2.cache import Apriel2Cache
-from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config, Apriel2TextConfig
+from .cache import Apriel2Cache
+from .configuration_apriel2 import Apriel2Config, Apriel2TextConfig
# GDN implementation - matches Fast-LLM's gdn.py exactly
try:
@@ -395,14 +395,30 @@ def __init__(self, d_model: int, mixer_config: dict, layer_idx: int, config):
# cross_document_attention: if False, use cu_seqlens to isolate sequences (e.g., images)
self.cross_document_attention = mixer_config.get("cross_document_attention", True)
- # Whether to add biases to linear projections
- add_bias = mixer_config.get("add_linear_biases", False)
-
- # Projections (Fast-LLM weight names: q_proj, k_proj, v_proj, o_proj)
- self.q_proj = nn.Linear(d_model, self.num_heads * self.head_dim, bias=add_bias)
- self.k_proj = nn.Linear(d_model, self.num_key_value_heads * self.head_dim, bias=add_bias)
- self.v_proj = nn.Linear(d_model, self.num_key_value_heads * self.head_dim, bias=add_bias)
- self.o_proj = nn.Linear(self.num_heads * self.head_dim, d_model, bias=add_bias)
+ # Bias configuration mirroring Fast-LLM's structure:
+ # - add_linear_biases: bool (default for all projections)
+ # - query_layer: {"bias": {"enabled": bool}} (per-layer override)
+ # - key_layer: {"bias": {"enabled": bool}}
+ # - value_layer: {"bias": {"enabled": bool}}
+ # - dense_layer: {"bias": {"enabled": bool}}
+ default_bias = mixer_config.get("add_linear_biases", False)
+
+ def get_layer_bias(layer_name: str) -> bool:
+ layer_cfg = mixer_config.get(layer_name, {})
+ bias_cfg = layer_cfg.get("bias", {})
+ enabled = bias_cfg.get("enabled")
+ return default_bias if enabled is None else enabled
+
+ q_bias = get_layer_bias("query_layer")
+ k_bias = get_layer_bias("key_layer")
+ v_bias = get_layer_bias("value_layer")
+ o_bias = get_layer_bias("dense_layer")
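+        # For example (illustrative): a Qwen-style mixer_config that enables bias on
+        # query_layer/key_layer/value_layer and disables it on dense_layer yields
+        # q_bias = k_bias = v_bias = True and o_bias = False; layers without an explicit
+        # setting fall back to add_linear_biases.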
+
+ # Projections
+ self.q_proj = nn.Linear(d_model, self.num_heads * self.head_dim, bias=q_bias)
+ self.k_proj = nn.Linear(d_model, self.num_key_value_heads * self.head_dim, bias=k_bias)
+ self.v_proj = nn.Linear(d_model, self.num_key_value_heads * self.head_dim, bias=v_bias)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, d_model, bias=o_bias)
@classmethod
def setup(
@@ -1828,16 +1844,36 @@ def __init__(
self.post_attention_layernorm = self._create_norm(norm_config, hidden_size, rms_norm_eps)
def _create_mlp(self, mlp_config: dict, hidden_size: int):
- """Create MLP based on config."""
+ """Create MLP based on config.
+
+ Supports per-layer bias configuration mirroring Fast-LLM:
+ - add_linear_biases: default bias setting for all layers
+ - layer_1.bias.enabled: override for up_proj/gate_proj
+ - layer_2.bias.enabled: override for down_proj
+ """
mlp_type = mlp_config.get("type", "mlp")
if mlp_type == "mlp":
intermediate_size = mlp_config["intermediate_size"]
activation = mlp_config.get("activation", "silu")
- gated = mlp_config["gated"]
- bias = mlp_config.get("add_linear_biases", False)
+ gated = mlp_config.get("gated", False)
+
+ # Per-layer bias configuration (mirrors Fast-LLM structure)
+ default_bias = mlp_config.get("add_linear_biases", False)
+
+ def get_layer_bias(layer_name: str) -> bool:
+ layer_cfg = mlp_config.get(layer_name, {})
+ bias_cfg = layer_cfg.get("bias", {})
+ enabled = bias_cfg.get("enabled")
+ return default_bias if enabled is None else enabled
+
+ layer_1_bias = get_layer_bias("layer_1")
+ layer_2_bias = get_layer_bias("layer_2")
if gated:
+ # MistralMLP uses gate_proj, up_proj, down_proj (all bias controlled together)
+ # For now, we use the default bias setting for gated MLPs
+ # TODO: Add per-layer bias support to gated MLP
mlp_cfg = SimpleNamespace(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
@@ -1845,7 +1881,13 @@ def _create_mlp(self, mlp_config: dict, hidden_size: int):
)
return MistralMLP(mlp_cfg)
else:
- return SimpleMLP(hidden_size, intermediate_size, activation, bias)
+ return SimpleMLP(
+ hidden_size,
+ intermediate_size,
+ activation,
+ layer_1_bias=layer_1_bias,
+ layer_2_bias=layer_2_bias,
+ )
else:
raise ValueError(f"Unknown MLP type: {mlp_type}")
@@ -2179,6 +2221,8 @@ def forward(
class Apriel2ForCausalLM(Apriel2PreTrainedModel, GenerationMixin):
"""Apriel2 model with a language modeling head (text-only)."""
+ _tied_weights_keys = ["lm_head.weight"]
+
def __init__(self, config: Apriel2TextConfig):
super().__init__(config)
self.model = Apriel2TextModel(config)
@@ -2186,6 +2230,7 @@ def __init__(self, config: Apriel2TextConfig):
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
+ # post_init() calls init_weights() which calls tie_weights() if config.tie_word_embeddings
self.post_init()
def get_input_embeddings(self):
@@ -2583,14 +2628,26 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
class SimpleMLP(nn.Module):
- """Non-gated MLP: up_proj -> activation -> down_proj."""
+ """Non-gated MLP: up_proj -> activation -> down_proj.
- def __init__(self, hidden_size: int, intermediate_size: int, activation: str = "silu", bias: bool = False):
+ Supports per-layer bias configuration mirroring Fast-LLM:
+ - layer_1_bias: bias for up_proj (layer_1 in Fast-LLM naming)
+ - layer_2_bias: bias for down_proj (layer_2 in Fast-LLM naming)
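+
+    Example (illustrative sizes)::
+
+        mlp = SimpleMLP(hidden_size=64, intermediate_size=256, layer_1_bias=True)
+        # up_proj gets a bias parameter; down_proj does not (layer_2_bias defaults to False).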
+ """
+
+ def __init__(
+ self,
+ hidden_size: int,
+ intermediate_size: int,
+ activation: str = "silu",
+ layer_1_bias: bool = False,
+ layer_2_bias: bool = False,
+ ):
super().__init__()
from transformers.activations import ACT2FN
- self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
- self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=bias)
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=layer_1_bias)
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=layer_2_bias)
self.act_fn = ACT2FN[activation]
def forward(self, x):
diff --git a/fast_llm_external_models/tests/test_apriel2/conftest.py b/fast_llm_external_models/tests/test_apriel2/conftest.py
index 8585aec65..320813747 100644
--- a/fast_llm_external_models/tests/test_apriel2/conftest.py
+++ b/fast_llm_external_models/tests/test_apriel2/conftest.py
@@ -7,6 +7,22 @@
import torch
from transformers import LlavaConfig, LlavaForConditionalGeneration, MistralConfig
+from fast_llm_external_models.apriel2.cache import _AttentionCache, _SSMCache
+
+
+# Register custom marks
+def pytest_configure(config):
+ config.addinivalue_line("markers", "slow: marks tests as slow (deselect with '-m \"not slow\"')")
+
+
+def _can_import_fast_llm():
+ """Check if Fast-LLM is available."""
+ try:
+ from fast_llm.engine.checkpoint.convert import ConvertConfig
+ return True
+ except ImportError:
+ return False
+
# Skip marker for tests that require CUDA for Mamba forward pass
requires_cuda = pytest.mark.skipif(
@@ -14,10 +30,20 @@
reason="SSM mixers (Mamba) require CUDA for forward pass"
)
+# Skip marker for tests that require Fast-LLM
+requires_fastllm = pytest.mark.skipif(
+ not _can_import_fast_llm(),
+ reason="Fast-LLM not available"
+)
+
-@pytest.fixture(autouse=True)
+@pytest.fixture(scope="module", autouse=True)
def set_default_device():
- """Set default device to CUDA for all tests (Mamba requires CUDA)."""
+ """Set default device to CUDA for all tests (Mamba requires CUDA).
+
+ Module-scoped to ensure it runs before any module-scoped fixtures
+ that load models (e.g., qwen2_model_and_tokenizer).
+ """
if torch.cuda.is_available():
old_device = torch.get_default_device()
torch.set_default_device("cuda")
@@ -27,9 +53,12 @@ def set_default_device():
yield
-@pytest.fixture(autouse=True)
+@pytest.fixture(scope="module", autouse=True)
def set_default_dtype():
- """Set default dtype to float32 for numerical comparison tests."""
+ """Set default dtype to float32 for numerical comparison tests.
+
+ Module-scoped to ensure it runs before any module-scoped fixtures.
+ """
old_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.float32)
yield
@@ -761,6 +790,52 @@ def apriel2_config_comprehensive():
)
+@pytest.fixture
+def apriel2_config_with_bias():
+ """Apriel2 config with Qwen-style per-layer bias and non-gated MLP.
+
+ This config exercises:
+ - Per-layer attention bias (QKV bias enabled, O bias disabled)
+ - Non-gated MLP with per-layer bias (layer_1 enabled, layer_2 disabled)
+ - Config structure parity with Fast-LLM's AffineLinearConfig
+
+ Critical for testing bias inheritance through surgery operations.
+ """
+ from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
+
+ return Apriel2Config(
+ vocab_size=100,
+ hidden_size=64,
+ decoder={
+ "type": "fixed",
+ "num_blocks": 2,
+ "block": {
+ "mixer": {
+ "type": "attention",
+ "heads": 4,
+ "head_groups": 2,
+ "head_size": 16,
+ "rotary": {"type": "mistral_1d", "theta": 10000.0},
+ # Qwen-style: QKV bias enabled, O bias disabled
+ "query_layer": {"bias": {"enabled": True}},
+ "key_layer": {"bias": {"enabled": True}},
+ "value_layer": {"bias": {"enabled": True}},
+ "dense_layer": {"bias": {"enabled": False}},
+ },
+ "mlp": {
+ "type": "mlp",
+ "intermediate_size": 256,
+ "gated": False, # Non-gated MLP (SimpleMLP)
+ # Per-layer MLP bias
+ "layer_1": {"bias": {"enabled": True}},
+ "layer_2": {"bias": {"enabled": False}},
+ },
+ "normalization": {"type": "rms_norm", "epsilon": 1e-5},
+ },
+ },
+ )
+
+
@pytest.fixture
def apriel2_cache(apriel2_config_tiny):
"""Create empty Apriel2Cache from tiny config."""
@@ -863,6 +938,77 @@ def additive_surgery_chain():
]
+@pytest.fixture
+def bias_surgery_chain():
+ """Surgery chain that exercises bias inheritance through surgery operations.
+
+ Designed to be used with apriel2_config_with_bias as the source config.
+ Tests that per-layer bias settings (Qwen-style QKV bias, non-gated MLP bias)
+ are correctly inherited through:
+ - Stochastic wrapper creation
+ - Adding new sub-mixers that inherit from source
+ - Cross-type derivation (attention → sliding_window)
+
+ Source config has:
+ - Attention: query/key/value bias enabled, dense bias disabled
+ - MLP: layer_1 bias enabled, layer_2 bias disabled (non-gated)
+ """
+ return [
+ # S1: Wrap in stochastic - bias should transfer to attention sub-mixer
+ {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "type": "stochastic",
+ "main_mixer_name": "attention",
+ "mixers": {
+ "attention": {"init": "transfer"},
+ },
+ },
+ "mlp": {"init": "transfer"},
+ "normalization": {"init": "transfer"},
+ },
+ },
+ },
+ # S2: Add sliding_window - should inherit bias from source attention
+ {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "mixers": {
+ "sliding_window": {
+ "type": "attention",
+ "init": "transfer",
+ "window_size": 512,
+ },
+ },
+ },
+ },
+ },
+ },
+ # S3: Add new attention with DIFFERENT bias config (random init)
+ {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "mixers": {
+ "full_bias_attn": {
+ "type": "attention",
+ "init": "random",
+ "heads": 4,
+ "head_groups": 2,
+ "head_size": 16,
+ "rotary": {"type": "mistral_1d", "theta": 10000.0},
+ "add_linear_biases": True, # All biases enabled
+ },
+ },
+ },
+ },
+ },
+ },
+ ]
+
+
@pytest.fixture
def comprehensive_torture_chain():
"""Comprehensive torture chain exercising ALL conversion paths.
@@ -1532,3 +1678,330 @@ def torture_surgery_chain():
},
},
]
+
+
+# =============================================================================
+# Shared Config Dict Fixtures (for compose_configs / plan_surgery tests)
+# =============================================================================
+
+
+@pytest.fixture
+def base_config_dict():
+ """Complete Apriel2 config dict without biases (Llama-style).
+
+ Use this as the base config for testing compose_configs and plan_surgery.
+ Returns a dict (not Apriel2Config) for direct use with compose_configs.
+ """
+ return {
+ "model_type": "apriel2",
+ "hidden_size": 256,
+ "vocab_size": 1000,
+ "bos_token_id": 1,
+ "eos_token_id": 2,
+ "tie_word_embeddings": False,
+ "decoder": {
+ "type": "fixed",
+ "num_blocks": 2,
+ "block": {
+ "mixer": {
+ "type": "attention",
+ "heads": 8,
+ "head_groups": 4,
+ "head_size": 32,
+ "rotary": {"type": "mistral_1d", "theta": 10000.0},
+ },
+ "mlp": {"type": "mlp", "intermediate_size": 512, "gated": True},
+ "normalization": {"type": "rms_norm", "epsilon": 1e-5},
+ },
+ },
+ }
+
+
+@pytest.fixture
+def base_config_with_bias_dict():
+ """Complete Apriel2 config dict with Qwen-style biases.
+
+ - QKV bias enabled, O bias disabled
+ - Gated MLP (no per-layer bias control in this style)
+
+ Use this for testing bias inheritance through surgery operations.
+ Returns a dict (not Apriel2Config) for direct use with compose_configs.
+ """
+ return {
+ "model_type": "apriel2",
+ "hidden_size": 256,
+ "vocab_size": 1000,
+ "bos_token_id": 1,
+ "eos_token_id": 2,
+ "tie_word_embeddings": False,
+ "decoder": {
+ "type": "fixed",
+ "num_blocks": 2,
+ "block": {
+ "mixer": {
+ "type": "attention",
+ "heads": 8,
+ "head_groups": 4,
+ "head_size": 32,
+ "rotary": {"type": "mistral_1d", "theta": 10000.0},
+ "query_layer": {"bias": {"enabled": True}},
+ "key_layer": {"bias": {"enabled": True}},
+ "value_layer": {"bias": {"enabled": True}},
+ "dense_layer": {"bias": {"enabled": False}},
+ },
+ "mlp": {"type": "mlp", "intermediate_size": 512, "gated": True},
+ "normalization": {"type": "rms_norm", "epsilon": 1e-5},
+ },
+ },
+ }
+
+
+def make_weights_for_config(config: dict) -> dict:
+ """Create random weights matching a config's expected schema.
+
+ This is a helper function (not a fixture) for creating test weights.
+ Use it in tests that need weights for plan execution.
+
+ Args:
+ config: Complete Apriel2 config dict
+
+ Returns:
+        Dict mapping weight keys (W objects) to torch tensors
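+
+    Example (sketch; surgery_spec is any partial surgery dict):
+
+        weights = make_weights_for_config(config)
+        plan = plan_surgery(config, compose_configs(config, surgery_spec))
+        new_weights = execute(plan, weights, seed=0)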
+ """
+ from fast_llm_external_models.apriel2.conversion import W
+
+ hidden = config["hidden_size"]
+ vocab = config["vocab_size"]
+ decoder = config["decoder"]
+ num_blocks = decoder["num_blocks"]
+ block = decoder["block"]
+ mixer = block["mixer"]
+ mlp = block["mlp"]
+
+ heads = mixer["heads"]
+ head_groups = mixer["head_groups"]
+ head_size = mixer["head_size"]
+ inter = mlp["intermediate_size"]
+
+ # Check bias settings
+ has_q_bias = mixer.get("query_layer", {}).get("bias", {}).get("enabled", False)
+ has_k_bias = mixer.get("key_layer", {}).get("bias", {}).get("enabled", False)
+ has_v_bias = mixer.get("value_layer", {}).get("bias", {}).get("enabled", False)
+
+ weights = {}
+ weights["model.embed_tokens.weight"] = torch.randn(vocab, hidden)
+
+ for i in range(num_blocks):
+ p = f"model.decoder.blocks.{i}"
+
+ # Attention
+ weights[f"{p}.mixer.q_proj.weight"] = torch.randn(heads * head_size, hidden)
+ weights[f"{p}.mixer.k_proj.weight"] = torch.randn(head_groups * head_size, hidden)
+ weights[f"{p}.mixer.v_proj.weight"] = torch.randn(head_groups * head_size, hidden)
+ weights[f"{p}.mixer.o_proj.weight"] = torch.randn(hidden, heads * head_size)
+
+ if has_q_bias:
+ weights[f"{p}.mixer.q_proj.bias"] = torch.randn(heads * head_size)
+ if has_k_bias:
+ weights[f"{p}.mixer.k_proj.bias"] = torch.randn(head_groups * head_size)
+ if has_v_bias:
+ weights[f"{p}.mixer.v_proj.bias"] = torch.randn(head_groups * head_size)
+
+ # MLP
+ weights[f"{p}.mlp.up_proj.weight"] = torch.randn(inter, hidden)
+ weights[f"{p}.mlp.gate_proj.weight"] = torch.randn(inter, hidden)
+ weights[f"{p}.mlp.down_proj.weight"] = torch.randn(hidden, inter)
+
+ # Norms
+ weights[f"{p}.input_layernorm.weight"] = torch.randn(hidden)
+ weights[f"{p}.post_attention_layernorm.weight"] = torch.randn(hidden)
+
+ weights["model.norm.weight"] = torch.randn(hidden)
+ weights["lm_head.weight"] = torch.randn(vocab, hidden)
+
+ return {W(k): v for k, v in weights.items()}
+
+
+# =============================================================================
+# Cache Test Fixtures - Tensor Dimensions
+# =============================================================================
+
+
+@pytest.fixture
+def batch_size():
+ """Default batch size for cache tests."""
+ return 2
+
+
+@pytest.fixture
+def num_heads():
+ """Default number of attention heads for cache tests."""
+ return 4
+
+
+@pytest.fixture
+def head_dim():
+ """Default head dimension for cache tests."""
+ return 16
+
+
+@pytest.fixture
+def make_kv(batch_size, num_heads, head_dim):
+ """Factory fixture for creating KV tensors."""
+
+ def _make_kv(seq_len):
+ return (
+ torch.randn(batch_size, num_heads, seq_len, head_dim),
+ torch.randn(batch_size, num_heads, seq_len, head_dim),
+ )
+
+ return _make_kv
+
+
+# =============================================================================
+# Cache Test Fixtures - HuggingFace Cache Layers
+# =============================================================================
+
+
+@pytest.fixture
+def hf_dynamic_layer():
+ """HuggingFace DynamicLayer for full attention contract testing."""
+ from transformers.cache_utils import DynamicLayer
+
+ return DynamicLayer()
+
+
+@pytest.fixture
+def hf_sliding_layer(window_size):
+ """HuggingFace DynamicSlidingWindowLayer for sliding window contract testing."""
+ from transformers.cache_utils import DynamicSlidingWindowLayer
+
+ return DynamicSlidingWindowLayer(sliding_window=window_size)
+
+
+# =============================================================================
+# Cache Test Fixtures - Apriel2 Low-level Caches
+# =============================================================================
+
+
+@pytest.fixture
+def apriel_attention_cache():
+ """Apriel2 attention cache without window (full attention)."""
+ return _AttentionCache(window=None)
+
+
+@pytest.fixture
+def apriel_sliding_cache(window_size):
+ """Apriel2 attention cache with sliding window."""
+ return _AttentionCache(window=window_size)
+
+
+@pytest.fixture
+def ssm_cache():
+ """Apriel2 SSM cache for Mamba/GDN/KDA layers."""
+ return _SSMCache()
+
+
+# =============================================================================
+# Cache Test Fixtures - Apriel2 Configs (Simple Versions)
+# =============================================================================
+
+
+@pytest.fixture
+def attention_config():
+ """Pure attention config (2 layers, no sliding window)."""
+ from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
+
+ return Apriel2Config(
+ vocab_size=100,
+ hidden_size=64,
+ decoder={
+ "type": "fixed",
+ "num_blocks": 2,
+ "block": {
+ "mixer": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16},
+ "mlp": {"type": "mlp", "intermediate_size": 256},
+ "normalization": {"type": "rms_norm", "epsilon": 1e-5},
+ },
+ },
+ )
+
+
+@pytest.fixture
+def swa_config():
+ """Sliding window attention config (2 layers, window=8)."""
+ from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
+
+ return Apriel2Config(
+ vocab_size=100,
+ hidden_size=64,
+ decoder={
+ "type": "fixed",
+ "num_blocks": 2,
+ "block": {
+ "mixer": {
+ "type": "attention",
+ "heads": 4,
+ "head_groups": 2,
+ "head_size": 16,
+ "window_size": 8,
+ },
+ "mlp": {"type": "mlp", "intermediate_size": 256},
+ "normalization": {"type": "rms_norm", "epsilon": 1e-5},
+ },
+ },
+ )
+
+
+@pytest.fixture
+def ssm_config():
+ """Pure SSM config (2 layers)."""
+ from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
+
+ return Apriel2Config(
+ vocab_size=100,
+ hidden_size=64,
+ decoder={
+ "type": "fixed",
+ "num_blocks": 2,
+ "block": {
+ "mixer": {"type": "mamba", "state_size": 16},
+ "mlp": {"type": "mlp", "intermediate_size": 256},
+ "normalization": {"type": "rms_norm", "epsilon": 1e-5},
+ },
+ },
+ )
+
+
+@pytest.fixture
+def stochastic_config():
+ """Stochastic mixer config with attention and mamba (2 layers)."""
+ from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
+
+ return Apriel2Config(
+ vocab_size=100,
+ hidden_size=64,
+ decoder={
+ "type": "fixed",
+ "num_blocks": 2,
+ "block": {
+ "mixer": {
+ "type": "stochastic",
+ "main_mixer_name": "attention",
+ "mixers": {
+ "attention": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16},
+ "mamba": {"type": "mamba", "state_size": 16},
+ },
+ },
+ "mlp": {"type": "mlp", "intermediate_size": 256},
+ "normalization": {"type": "rms_norm", "epsilon": 1e-5},
+ },
+ },
+ )
+
+
+# Parameterized window size fixture (used by hf_sliding_layer and apriel_sliding_cache)
+@pytest.fixture(params=[4, 8, 16, 32])
+def window_size(request):
+ """Parameterized window sizes for sliding window tests."""
+ return request.param
diff --git a/fast_llm_external_models/tests/test_apriel2/test_cache.py b/fast_llm_external_models/tests/test_apriel2/test_cache.py
deleted file mode 100644
index ca8158b4f..000000000
--- a/fast_llm_external_models/tests/test_apriel2/test_cache.py
+++ /dev/null
@@ -1,1258 +0,0 @@
-"""Comprehensive tests for Apriel2Cache.
-
-Architecture Overview
-=====================
-Apriel2Cache manages state for autoregressive generation across different mixer types:
-
-1. **Attention Cache** (_AttentionCache): Stores key/value states
- - Supports sliding window (window_size) for SWA
- - Efficient roll optimization for single-token decode
-
-2. **SSM Cache** (_SSMCache): Stores conv and recurrent states
- - Used by Mamba, GDN, KDA
- - KDA uses tuple conv states (q, k, v), others use single tensor
-
-3. **Stochastic Mixer Routing**: For layers with multiple mixer options
- - Each mixer has independent cache (no sharing)
- - active_mixer pointer routes operations to correct sub-cache
- - Switching mixers preserves each mixer's independent state
-
-Cache Invalidation Semantics
-============================
-When switching between mixers in a stochastic layer:
-- Each mixer maintains its OWN independent history
-- Switching does NOT invalidate the previous mixer's cache
-- Switching does NOT copy state between mixers
-- To invalidate: call reset() explicitly
-
-This is intentional for training with stochastic sampling where each mixer
-should learn from its own history. For inference, main_mixer_name is fixed.
-
-Test Organization
-=================
-1. CREATION & PROPERTIES - Cache initialization, config parsing
-2. ATTENTION CACHE - Updates, sliding window, concatenation
-3. SSM CACHE - Conv states, recurrent states, KDA tuples
-4. STOCHASTIC ROUTING - Active mixer, isolation, switching
-5. CACHE INVALIDATION - Reset, per-mixer reset, coherence
-6. BEAM SEARCH - batch_repeat, reorder, select
-7. HF INTEGRATION - get_mask_sizes, indexing, properties
-8. GENERATION PATTERNS - Prefill→decode, crop→continue
-9. ERROR HANDLING - Guards, bounds, invalid operations
-"""
-
-import pytest
-import torch
-
-from fast_llm_external_models.apriel2.cache import (
- Apriel2Cache,
- _AttentionCache,
- _SSMCache,
-)
-
-
-# =============================================================================
-# FIXTURES - Configs and Sample Data
-# =============================================================================
-
-
-@pytest.fixture
-def tiny_attention_config():
- """Minimal config with pure attention layers."""
- from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
-
- return Apriel2Config(
- vocab_size=100,
- hidden_size=64,
- decoder={
- "type": "fixed",
- "num_blocks": 2,
- "block": {
- "mixer": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16},
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- },
- )
-
-
-@pytest.fixture
-def swa_config():
- """Config with sliding window attention."""
- from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
-
- return Apriel2Config(
- vocab_size=100,
- hidden_size=64,
- decoder={
- "type": "fixed",
- "num_blocks": 2,
- "block": {
- "mixer": {
- "type": "attention",
- "heads": 4,
- "head_groups": 2,
- "head_size": 16,
- "window_size": 8, # Small for testing
- },
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- },
- )
-
-
-@pytest.fixture
-def ssm_config():
- """Config with pure SSM layers (mamba)."""
- from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
-
- return Apriel2Config(
- vocab_size=100,
- hidden_size=64,
- decoder={
- "type": "fixed",
- "num_blocks": 2,
- "block": {
- "mixer": {
- "type": "mamba",
- "d_inner": 128,
- "d_state": 16,
- "dt_rank": 4,
- "d_conv": 4,
- },
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- },
- )
-
-
-@pytest.fixture
-def kda_config():
- """Config with pure KDA layers."""
- from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
-
- return Apriel2Config(
- vocab_size=100,
- hidden_size=64,
- decoder={
- "type": "fixed",
- "num_blocks": 2,
- "block": {
- "mixer": {
- "type": "kda",
- "heads": 4,
- "head_dim": 16,
- "convolution_layer": {"kernel_size": 4},
- "normalization": {"epsilon": 1e-5},
- },
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- },
- )
-
-
-@pytest.fixture
-def stochastic_config():
- """Config with stochastic mixer (attention + mamba)."""
- from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
-
- return Apriel2Config(
- vocab_size=100,
- hidden_size=64,
- decoder={
- "type": "pattern",
- "num_blocks": 2,
- "pattern": ["attn", "stochastic"],
- "blocks": {
- "attn": {
- "mixer": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16},
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- "stochastic": {
- "mixer": {
- "type": "stochastic",
- "main_mixer_name": "attention",
- "mixers": {
- "attention": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16},
- "mamba": {"type": "mamba", "d_inner": 128, "d_state": 16, "dt_rank": 4, "d_conv": 4},
- },
- },
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- },
- },
- )
-
-
-@pytest.fixture
-def all_mixers_config():
- """Config with stochastic mixer containing all 5 mixer types."""
- from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
-
- return Apriel2Config(
- vocab_size=100,
- hidden_size=64,
- decoder={
- "type": "pattern",
- "num_blocks": 2,
- "pattern": ["attn", "all_mixers"],
- "blocks": {
- "attn": {
- "mixer": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16},
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- "all_mixers": {
- "mixer": {
- "type": "stochastic",
- "main_mixer_name": "attention",
- "mixers": {
- "attention": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16},
- "swa": {
- "type": "attention",
- "heads": 4,
- "head_groups": 2,
- "head_size": 16,
- "window_size": 1024,
- },
- "mamba": {"type": "mamba", "d_inner": 128, "d_state": 16, "dt_rank": 4, "d_conv": 4},
- "gdn": {
- "type": "gdn",
- "value_heads": 4,
- "key_heads": 2,
- "key_head_dim": 16,
- "value_head_dim": 16,
- "convolution_layer": {"kernel_size": 4},
- },
- "kda": {
- "type": "kda",
- "heads": 4,
- "head_dim": 16,
- "convolution_layer": {"kernel_size": 4},
- "normalization": {"epsilon": 1e-5},
- },
- },
- },
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- },
- },
- )
-
-
-@pytest.fixture
-def multi_window_config():
- """Config with multiple different window sizes."""
- from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
-
- return Apriel2Config(
- vocab_size=100,
- hidden_size=64,
- decoder={
- "type": "pattern",
- "num_blocks": 3,
- "pattern": ["full", "small_window", "large_window"],
- "blocks": {
- "full": {
- "mixer": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16},
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- "small_window": {
- "mixer": {
- "type": "attention",
- "heads": 4,
- "head_groups": 2,
- "head_size": 16,
- "window_size": 512,
- },
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- "large_window": {
- "mixer": {
- "type": "attention",
- "heads": 4,
- "head_groups": 2,
- "head_size": 16,
- "window_size": 2048,
- },
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- },
- },
- )
-
-
-@pytest.fixture
-def sample_kv():
- """Sample key/value tensors: [batch=2, heads=4, seq=10, head_dim=16]."""
- return torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16)
-
-
-@pytest.fixture
-def sample_conv_single():
- """Sample single-tensor conv state: [batch=2, d_inner=128, kernel=4]."""
- return torch.randn(2, 128, 4)
-
-
-@pytest.fixture
-def sample_conv_tuple():
- """Sample tuple conv state for KDA: (q, k, v) each [batch=2, d=64, kernel=3]."""
- return (torch.randn(2, 64, 3), torch.randn(2, 64, 3), torch.randn(2, 64, 3))
-
-
-@pytest.fixture
-def sample_recurrent():
- """Sample recurrent state: [batch=2, heads=4, head_dim=16, d_state=16]."""
- return torch.randn(2, 4, 16, 16)
-
-
-# =============================================================================
-# SECTION 1: CACHE CREATION & PROPERTIES
-# =============================================================================
-
-
-class TestCacheCreation:
- """Test cache initialization from config."""
-
- def test_attention_cache_creation(self, tiny_attention_config):
- """Create cache for pure attention config."""
- cache = Apriel2Cache(tiny_attention_config)
-
- assert len(cache) == 2
- assert cache.mixer_types == ["attention", "attention"]
- assert all(isinstance(l, _AttentionCache) for l in cache.layers)
-
- def test_ssm_cache_creation(self, ssm_config):
- """Create cache for pure SSM config."""
- cache = Apriel2Cache(ssm_config)
-
- assert len(cache) == 2
- assert cache.mixer_types == ["mamba", "mamba"]
- assert all(isinstance(l, _SSMCache) for l in cache.layers)
-
- def test_kda_cache_creation(self, kda_config):
- """Create cache for pure KDA config."""
- cache = Apriel2Cache(kda_config)
-
- assert len(cache) == 2
- assert cache.mixer_types == ["kda", "kda"]
- assert all(isinstance(l, _SSMCache) for l in cache.layers)
-
- def test_stochastic_cache_creation(self, stochastic_config):
- """Create cache for stochastic mixer config."""
- cache = Apriel2Cache(stochastic_config)
-
- assert len(cache) == 2
- # Layer 0: pure attention, Layer 1: stochastic (dict)
- assert isinstance(cache.layers[0], _AttentionCache)
- assert isinstance(cache.layers[1], dict)
- assert set(cache.layers[1].keys()) == {"attention", "mamba"}
-
- def test_swa_window_captured(self, swa_config):
- """Verify sliding window size is captured."""
- cache = Apriel2Cache(swa_config)
-
- assert cache.layers[0].window == 8
- assert cache.is_sliding == [True, True]
-
- def test_active_mixers_initialized_none(self, stochastic_config):
- """Verify active_mixers starts as None for all layers."""
- cache = Apriel2Cache(stochastic_config)
-
- assert cache.active_mixers == [None, None]
-
-
-class TestCacheProperties:
- """Test cache property accessors."""
-
- def test_empty_cache_properties(self, tiny_attention_config):
- """Test properties of uninitialized cache."""
- cache = Apriel2Cache(tiny_attention_config)
-
- assert cache.is_initialized == False
- assert cache.has_previous_state == False
- assert cache.max_batch_size is None
- assert cache.max_cache_len is None
- assert cache.is_compileable == False
-
- def test_is_initialized_attention(self, tiny_attention_config, sample_kv):
- """is_initialized detects attention cache."""
- cache = Apriel2Cache(tiny_attention_config)
- cache.update(*sample_kv, layer_idx=0)
-
- assert cache.is_initialized == True
-
- def test_is_initialized_ssm(self, ssm_config, sample_conv_single):
- """is_initialized detects SSM cache."""
- cache = Apriel2Cache(ssm_config)
- cache.conv_states[0] = sample_conv_single
-
- assert cache.is_initialized == True
-
- def test_has_previous_state_ssm_only(self, ssm_config, sample_conv_single):
- """has_previous_state only looks at SSM conv states."""
- cache = Apriel2Cache(ssm_config)
-
- assert cache.has_previous_state == False
- cache.conv_states[0] = sample_conv_single
- assert cache.has_previous_state == True
-
- def test_has_previous_state_ignores_attention(self, tiny_attention_config, sample_kv):
- """has_previous_state ignores attention cache."""
- cache = Apriel2Cache(tiny_attention_config)
- cache.update(*sample_kv, layer_idx=0)
-
- # Attention cache is set, but has_previous_state only checks SSM
- assert cache.has_previous_state == False
-
- def test_max_batch_size_from_attention(self, tiny_attention_config, sample_kv):
- """max_batch_size from attention cache."""
- cache = Apriel2Cache(tiny_attention_config)
- cache.update(*sample_kv, layer_idx=0)
-
- assert cache.max_batch_size == 2
-
- def test_max_batch_size_from_ssm(self, ssm_config, sample_conv_single):
- """max_batch_size from SSM cache."""
- cache = Apriel2Cache(ssm_config)
- cache.conv_states[0] = sample_conv_single
-
- assert cache.max_batch_size == 2
-
- def test_max_batch_size_from_kda_tuple(self, kda_config, sample_conv_tuple):
- """max_batch_size from KDA tuple conv state."""
- cache = Apriel2Cache(kda_config)
- cache.conv_states[0] = sample_conv_tuple
-
- assert cache.max_batch_size == 2
-
- def test_max_cache_len_single_window(self, swa_config):
- """max_cache_len with single window size."""
- cache = Apriel2Cache(swa_config)
- assert cache.max_cache_len == 8
-
- def test_max_cache_len_multiple_windows(self, multi_window_config):
- """max_cache_len returns minimum window."""
- cache = Apriel2Cache(multi_window_config)
- assert cache.max_cache_len == 512 # min(512, 2048)
-
- def test_max_cache_len_no_windows(self, tiny_attention_config):
- """max_cache_len is None when no windows."""
- cache = Apriel2Cache(tiny_attention_config)
- assert cache.max_cache_len is None
-
- def test_is_sliding_mixed(self, multi_window_config):
- """is_sliding reflects per-layer window presence."""
- cache = Apriel2Cache(multi_window_config)
- assert cache.is_sliding == [False, True, True]
-
-
-# =============================================================================
-# SECTION 2: ATTENTION CACHE OPERATIONS
-# =============================================================================
-
-
-class TestAttentionCacheBasics:
- """Test basic attention cache operations."""
-
- def test_update_stores_kv(self, tiny_attention_config, sample_kv):
- """update() stores key/value states."""
- cache = Apriel2Cache(tiny_attention_config)
- key, value = sample_kv
-
- k_out, v_out = cache.update(key, value, layer_idx=0)
-
- torch.testing.assert_close(k_out, key)
- torch.testing.assert_close(v_out, value)
- assert cache.get_seq_length(0) == 10
-
- def test_update_concatenates(self, tiny_attention_config, sample_kv):
- """Subsequent updates concatenate."""
- cache = Apriel2Cache(tiny_attention_config)
- key, value = sample_kv
-
- cache.update(key, value, layer_idx=0)
- k_out, v_out = cache.update(key, value, layer_idx=0)
-
- assert k_out.shape[-2] == 20
- assert cache.get_seq_length(0) == 20
-
- def test_key_value_cache_accessors(self, tiny_attention_config, sample_kv):
- """Test key_cache and value_cache accessors."""
- cache = Apriel2Cache(tiny_attention_config)
- cache.update(*sample_kv, layer_idx=0)
-
- assert cache.key_cache[0] is not None
- assert cache.value_cache[0] is not None
- torch.testing.assert_close(cache.key_cache[0], sample_kv[0])
-
-
-class TestSlidingWindowAttention:
- """Test sliding window attention behavior."""
-
- def test_initial_within_window(self, swa_config):
- """Initial sequence within window is kept."""
- cache = Apriel2Cache(swa_config)
- key = torch.randn(2, 4, 5, 16) # seq=5 < window=8
- value = torch.randn(2, 4, 5, 16)
-
- cache.update(key, value, layer_idx=0)
-
- assert cache.get_seq_length(0) == 5
-
- def test_initial_exceeds_window(self, swa_config):
- """Initial sequence > window is truncated to last window tokens."""
- cache = Apriel2Cache(swa_config)
- key = torch.arange(12).float().view(1, 1, 12, 1).expand(2, 4, 12, 16)
- value = key.clone()
-
- k_out, v_out = cache.update(key, value, layer_idx=0)
-
- assert cache.get_seq_length(0) == 8
- # Should keep tokens 4-11 (last 8)
- assert k_out[0, 0, 0, 0].item() == 4.0
-
- def test_single_token_roll_path(self, swa_config):
- """Single token decode with full window uses efficient roll."""
- cache = Apriel2Cache(swa_config)
-
- # Fill window exactly
- key1 = torch.arange(8).float().view(1, 1, 8, 1).expand(2, 4, 8, 16)
- cache.update(key1, key1.clone(), layer_idx=0)
-
- # Decode single token
- key2 = torch.full((2, 4, 1, 16), 8.0)
- k_out, _ = cache.update(key2, key2.clone(), layer_idx=0)
-
- assert cache.get_seq_length(0) == 8
- assert k_out[0, 0, 0, 0].item() == 1.0 # Token 0 rolled out
- assert k_out[0, 0, 7, 0].item() == 8.0 # New token at end
-
- def test_multi_token_cat_slice_path(self, swa_config):
- """Multiple tokens use cat+slice path."""
- cache = Apriel2Cache(swa_config)
-
- # Fill window
- key1 = torch.randn(2, 4, 8, 16)
- cache.update(key1, key1.clone(), layer_idx=0)
-
- # Add 3 tokens
- key2 = torch.randn(2, 4, 3, 16)
- k_out, _ = cache.update(key2, key2.clone(), layer_idx=0)
-
- assert cache.get_seq_length(0) == 8
- torch.testing.assert_close(k_out[..., -3:, :], key2)
-
- def test_partial_then_fill_then_overflow(self, swa_config):
- """Progressive filling: partial → full → overflow."""
- cache = Apriel2Cache(swa_config)
-
- cache.update(torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16), layer_idx=0)
- assert cache.get_seq_length(0) == 5
-
- cache.update(torch.randn(2, 4, 3, 16), torch.randn(2, 4, 3, 16), layer_idx=0)
- assert cache.get_seq_length(0) == 8
-
- cache.update(torch.randn(2, 4, 2, 16), torch.randn(2, 4, 2, 16), layer_idx=0)
- assert cache.get_seq_length(0) == 8
-
- def test_contiguous_output(self, swa_config):
- """Outputs are contiguous after windowing."""
- cache = Apriel2Cache(swa_config)
-
- cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0)
- cache.update(torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16), layer_idx=0)
-
- assert cache.layers[0].key.is_contiguous()
- assert cache.layers[0].value.is_contiguous()
-
-
-# =============================================================================
-# SECTION 3: SSM CACHE OPERATIONS
-# =============================================================================
-
-
-class TestSSMCacheBasics:
- """Test basic SSM cache operations."""
-
- def test_conv_states_accessor(self, ssm_config, sample_conv_single):
- """Test conv_states accessor."""
- cache = Apriel2Cache(ssm_config)
-
- cache.conv_states[0] = sample_conv_single
- torch.testing.assert_close(cache.conv_states[0], sample_conv_single)
-
- def test_recurrent_states_accessor(self, ssm_config, sample_recurrent):
- """Test recurrent_states accessor."""
- cache = Apriel2Cache(ssm_config)
-
- cache.recurrent_states[0] = sample_recurrent
- torch.testing.assert_close(cache.recurrent_states[0], sample_recurrent)
-
- def test_ssm_seq_length_always_zero(self, ssm_config, sample_conv_single):
- """get_seq_length returns 0 for SSM (no KV cache)."""
- cache = Apriel2Cache(ssm_config)
- cache.conv_states[0] = sample_conv_single
-
- assert cache.get_seq_length(0) == 0
-
-
-class TestKDACache:
- """Test KDA-specific cache operations with tuple conv states."""
-
- def test_tuple_conv_storage(self, kda_config, sample_conv_tuple):
- """KDA stores tuple conv states."""
- cache = Apriel2Cache(kda_config)
-
- cache.conv_states[0] = sample_conv_tuple
-
- assert isinstance(cache.conv_states[0], tuple)
- assert len(cache.conv_states[0]) == 3
- for i in range(3):
- torch.testing.assert_close(cache.conv_states[0][i], sample_conv_tuple[i])
-
- def test_tuple_with_recurrent(self, kda_config, sample_conv_tuple, sample_recurrent):
- """KDA can have both tuple conv and recurrent states."""
- cache = Apriel2Cache(kda_config)
-
- cache.conv_states[0] = sample_conv_tuple
- cache.recurrent_states[0] = sample_recurrent
-
- assert isinstance(cache.conv_states[0], tuple)
- assert cache.recurrent_states[0] is not None
-
- def test_has_previous_state_detects_tuple(self, kda_config, sample_conv_tuple):
- """has_previous_state works with tuple conv states."""
- cache = Apriel2Cache(kda_config)
-
- assert cache.has_previous_state == False
- cache.conv_states[0] = sample_conv_tuple
- assert cache.has_previous_state == True
-
-
-# =============================================================================
-# SECTION 4: STOCHASTIC ROUTING
-# =============================================================================
-
-
-class TestStochasticRouting:
- """Test stochastic mixer cache routing."""
-
- def test_set_active_mixer(self, stochastic_config):
- """set_active_mixer sets the pointer."""
- cache = Apriel2Cache(stochastic_config)
-
- cache.set_active_mixer(1, "attention")
- assert cache.active_mixers[1] == "attention"
-
- cache.set_active_mixer(1, "mamba")
- assert cache.active_mixers[1] == "mamba"
-
- def test_operations_route_to_active(self, stochastic_config, sample_kv):
- """Operations route to currently active mixer."""
- cache = Apriel2Cache(stochastic_config)
-
- cache.set_active_mixer(1, "attention")
- cache.update(*sample_kv, layer_idx=1)
- attn_len = cache.get_seq_length(1)
-
- cache.set_active_mixer(1, "mamba")
- mamba_len = cache.get_seq_length(1)
-
- assert attn_len == 10
- assert mamba_len == 0 # Mamba cache is separate and empty
-
- def test_each_mixer_independent_cache(self, stochastic_config, sample_kv, sample_conv_single):
- """Each mixer maintains independent cache."""
- cache = Apriel2Cache(stochastic_config)
-
- # Fill attention cache
- cache.set_active_mixer(1, "attention")
- cache.update(*sample_kv, layer_idx=1)
-
- # Fill mamba cache
- cache.set_active_mixer(1, "mamba")
- cache.conv_states[1] = sample_conv_single
-
- # Both preserved
- cache.set_active_mixer(1, "attention")
- assert cache.get_seq_length(1) == 10
-
- cache.set_active_mixer(1, "mamba")
- torch.testing.assert_close(cache.conv_states[1], sample_conv_single)
-
-
-class TestMixerSwitching:
- """Test behavior when switching between mixers mid-generation."""
-
- def test_switch_preserves_previous_state(self, stochastic_config, sample_kv):
- """Switching mixers preserves previous mixer's state."""
- cache = Apriel2Cache(stochastic_config)
-
- cache.set_active_mixer(1, "attention")
- cache.update(*sample_kv, layer_idx=1)
- original_key = cache.layers[1]["attention"].key.clone()
-
- # Switch to mamba, do something
- cache.set_active_mixer(1, "mamba")
- cache.conv_states[1] = torch.randn(2, 128, 4)
-
- # Switch back - attention unchanged
- cache.set_active_mixer(1, "attention")
- torch.testing.assert_close(cache.layers[1]["attention"].key, original_key)
-
- def test_switch_does_not_copy_state(self, stochastic_config, sample_kv):
- """Switching does NOT copy state between mixers."""
- cache = Apriel2Cache(stochastic_config)
-
- # Fill attention with 10 tokens
- cache.set_active_mixer(1, "attention")
- cache.update(*sample_kv, layer_idx=1)
-
- # Switch to mamba - it has NO history from attention
- cache.set_active_mixer(1, "mamba")
- assert cache.conv_states[1] is None
- assert cache.recurrent_states[1] is None
-
- def test_has_previous_state_checks_all_sub_caches(self, stochastic_config):
- """has_previous_state checks ALL sub-caches, not just active."""
- cache = Apriel2Cache(stochastic_config)
-
- cache.set_active_mixer(1, "mamba")
- cache.conv_states[1] = torch.randn(2, 128, 4)
-
- # Even if we switch away, has_previous_state still detects it
- cache.set_active_mixer(1, "attention")
- assert cache.has_previous_state == True
-
-
-class TestAllMixerTypes:
- """Test cache isolation across all 5 mixer types."""
-
- def test_all_five_mixer_types_isolated(self, all_mixers_config):
- """All 5 mixer types maintain isolated caches."""
- cache = Apriel2Cache(all_mixers_config)
- layer_idx = 1 # Stochastic layer
-
- # Fill each mixer's cache
- cache.set_active_mixer(layer_idx, "attention")
- attn_kv = (torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16))
- cache.update(*attn_kv, layer_idx=layer_idx)
-
- cache.set_active_mixer(layer_idx, "swa")
- swa_kv = (torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16))
- cache.update(*swa_kv, layer_idx=layer_idx)
-
- cache.set_active_mixer(layer_idx, "mamba")
- mamba_conv = torch.randn(2, 128, 4)
- cache.conv_states[layer_idx] = mamba_conv
-
- cache.set_active_mixer(layer_idx, "gdn")
- gdn_conv = torch.randn(2, 64, 3)
- cache.conv_states[layer_idx] = gdn_conv
-
- cache.set_active_mixer(layer_idx, "kda")
- kda_conv = (torch.randn(2, 64, 3), torch.randn(2, 64, 3), torch.randn(2, 64, 3))
- cache.conv_states[layer_idx] = kda_conv
-
- # Verify all preserved
- cache.set_active_mixer(layer_idx, "attention")
- assert cache.get_seq_length(layer_idx) == 10
-
- cache.set_active_mixer(layer_idx, "swa")
- assert cache.get_seq_length(layer_idx) == 5
-
- cache.set_active_mixer(layer_idx, "mamba")
- torch.testing.assert_close(cache.conv_states[layer_idx], mamba_conv)
-
- cache.set_active_mixer(layer_idx, "gdn")
- torch.testing.assert_close(cache.conv_states[layer_idx], gdn_conv)
-
- cache.set_active_mixer(layer_idx, "kda")
- assert isinstance(cache.conv_states[layer_idx], tuple)
-
-
-# =============================================================================
-# SECTION 5: CACHE INVALIDATION
-# =============================================================================
-
-
-class TestCacheInvalidation:
- """Test cache invalidation and reset semantics.
-
- Key principle: Each mixer maintains independent state. To invalidate:
- - reset() clears ALL caches across ALL layers and mixers
- - There is no per-mixer reset (by design - each mixer is independent)
- """
-
- def test_reset_clears_attention(self, tiny_attention_config, sample_kv):
- """reset() clears attention cache."""
- cache = Apriel2Cache(tiny_attention_config)
- cache.update(*sample_kv, layer_idx=0)
-
- cache.reset()
-
- assert cache.is_initialized == False
- assert cache.get_seq_length(0) == 0
-
- def test_reset_clears_ssm(self, ssm_config, sample_conv_single, sample_recurrent):
- """reset() clears SSM cache."""
- cache = Apriel2Cache(ssm_config)
- cache.conv_states[0] = sample_conv_single
- cache.recurrent_states[0] = sample_recurrent
-
- cache.reset()
-
- assert cache.has_previous_state == False
- assert cache.conv_states[0] is None
- assert cache.recurrent_states[0] is None
-
- def test_reset_clears_kda_tuple(self, kda_config, sample_conv_tuple):
- """reset() clears KDA tuple conv states."""
- cache = Apriel2Cache(kda_config)
- cache.conv_states[0] = sample_conv_tuple
-
- cache.reset()
-
- assert cache.conv_states[0] is None
-
- def test_reset_clears_all_stochastic_mixers(self, all_mixers_config):
- """reset() clears ALL mixer caches in stochastic layer."""
- cache = Apriel2Cache(all_mixers_config)
- layer_idx = 1
-
- # Fill all mixers
- cache.set_active_mixer(layer_idx, "attention")
- cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=layer_idx)
-
- cache.set_active_mixer(layer_idx, "mamba")
- cache.conv_states[layer_idx] = torch.randn(2, 128, 4)
-
- cache.set_active_mixer(layer_idx, "kda")
- cache.conv_states[layer_idx] = (torch.randn(2, 64, 3),) * 3
-
- cache.reset()
-
- # All cleared
- assert cache.layers[layer_idx]["attention"].key is None
- assert cache.layers[layer_idx]["mamba"].conv is None
- assert cache.layers[layer_idx]["kda"].conv is None
-
- def test_crop_truncates_attention(self, tiny_attention_config, sample_kv):
- """crop() truncates attention cache to max_length."""
- cache = Apriel2Cache(tiny_attention_config)
- cache.update(*sample_kv, layer_idx=0)
-
- cache.crop(5)
-
- assert cache.get_seq_length(0) == 5
-
- def test_crop_affects_all_layers(self, tiny_attention_config, sample_kv):
- """crop() affects all layers."""
- cache = Apriel2Cache(tiny_attention_config)
- cache.update(*sample_kv, layer_idx=0)
- cache.update(*sample_kv, layer_idx=1)
-
- cache.crop(3)
-
- assert cache.get_seq_length(0) == 3
- assert cache.get_seq_length(1) == 3
-
- def test_crop_ignores_ssm(self, ssm_config, sample_conv_single):
- """crop() only affects attention, not SSM."""
- cache = Apriel2Cache(ssm_config)
- cache.conv_states[0] = sample_conv_single
-
- cache.crop(5) # Should not crash
-
- # Conv state unchanged
- torch.testing.assert_close(cache.conv_states[0], sample_conv_single)
-
-
-# =============================================================================
-# SECTION 6: BEAM SEARCH OPERATIONS
-# =============================================================================
-
-
-class TestBatchRepeatInterleave:
- """Test batch_repeat_interleave for beam search expansion."""
-
- def test_repeat_attention(self, tiny_attention_config, sample_kv):
- """Repeat attention cache for beam search."""
- cache = Apriel2Cache(tiny_attention_config)
- cache.update(*sample_kv, layer_idx=0)
-
- cache.batch_repeat_interleave(3)
-
- assert cache.max_batch_size == 6 # 2 * 3
-
- def test_repeat_ssm(self, ssm_config, sample_conv_single, sample_recurrent):
- """Repeat SSM cache for beam search."""
- cache = Apriel2Cache(ssm_config)
- cache.conv_states[0] = sample_conv_single
- cache.recurrent_states[0] = sample_recurrent
-
- cache.batch_repeat_interleave(4)
-
- assert cache.conv_states[0].shape[0] == 8 # 2 * 4
- assert cache.recurrent_states[0].shape[0] == 8
-
- def test_repeat_kda_tuple(self, kda_config, sample_conv_tuple):
- """Repeat KDA tuple conv states."""
- cache = Apriel2Cache(kda_config)
- cache.conv_states[0] = sample_conv_tuple
-
- cache.batch_repeat_interleave(3)
-
- for c in cache.conv_states[0]:
- assert c.shape[0] == 6
-
- def test_repeat_stochastic_all_mixers(self, all_mixers_config):
- """Repeat all mixer caches in stochastic layer."""
- cache = Apriel2Cache(all_mixers_config)
- layer_idx = 1
-
- cache.set_active_mixer(layer_idx, "attention")
- cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=layer_idx)
-
- cache.set_active_mixer(layer_idx, "mamba")
- cache.conv_states[layer_idx] = torch.randn(2, 128, 4)
-
- cache.batch_repeat_interleave(2)
-
- cache.set_active_mixer(layer_idx, "attention")
- assert cache.layers[layer_idx]["attention"].key.shape[0] == 4
-
- cache.set_active_mixer(layer_idx, "mamba")
- assert cache.conv_states[layer_idx].shape[0] == 4
-
- def test_repeat_skips_none(self, tiny_attention_config):
- """Repeat gracefully skips None caches."""
- cache = Apriel2Cache(tiny_attention_config)
- # Don't fill anything
-
- cache.batch_repeat_interleave(3) # Should not crash
-
- assert cache.max_batch_size is None
-
-
-class TestReorderCache:
- """Test reorder_cache for beam search hypothesis selection."""
-
- def test_reorder_attention(self, tiny_attention_config, sample_kv):
- """Reorder attention cache."""
- cache = Apriel2Cache(tiny_attention_config)
- key, value = sample_kv
- # Make batches distinguishable
- key = torch.arange(2).float().view(2, 1, 1, 1).expand(2, 4, 10, 16)
- cache.update(key, key.clone(), layer_idx=0)
-
- beam_idx = torch.tensor([1, 0])
- cache.reorder_cache(beam_idx)
-
- assert cache.layers[0].key[0, 0, 0, 0].item() == 1.0
- assert cache.layers[0].key[1, 0, 0, 0].item() == 0.0
-
- def test_reorder_ssm(self, ssm_config):
- """Reorder SSM cache."""
- cache = Apriel2Cache(ssm_config)
- conv = torch.arange(2).float().view(2, 1, 1).expand(2, 128, 4)
- cache.conv_states[0] = conv.clone()
-
- beam_idx = torch.tensor([1, 0])
- cache.reorder_cache(beam_idx)
-
- assert cache.conv_states[0][0, 0, 0].item() == 1.0
-
- def test_reorder_kda_tuple(self, kda_config):
- """Reorder KDA tuple conv states."""
- cache = Apriel2Cache(kda_config)
- conv_q = torch.arange(2).float().view(2, 1, 1).expand(2, 64, 3)
- cache.conv_states[0] = (conv_q.clone(), conv_q.clone(), conv_q.clone())
-
- beam_idx = torch.tensor([1, 0])
- cache.reorder_cache(beam_idx)
-
- for c in cache.conv_states[0]:
- assert c[0, 0, 0].item() == 1.0
-
-
-class TestBatchSelectIndices:
- """Test batch_select_indices for beam selection."""
-
- def test_select_attention(self, tiny_attention_config, sample_kv):
- """Select subset of attention cache."""
- cache = Apriel2Cache(tiny_attention_config)
- key = torch.arange(4).float().view(4, 1, 1, 1).expand(4, 4, 10, 16)
- cache.update(key, key.clone(), layer_idx=0)
-
- indices = torch.tensor([0, 3])
- cache.batch_select_indices(indices)
-
- assert cache.max_batch_size == 2
- assert cache.layers[0].key[0, 0, 0, 0].item() == 0.0
- assert cache.layers[0].key[1, 0, 0, 0].item() == 3.0
-
- def test_select_kda_tuple(self, kda_config):
- """Select subset of KDA tuple conv states."""
- cache = Apriel2Cache(kda_config)
- conv = tuple(torch.arange(4).float().view(4, 1, 1).expand(4, 64, 3).clone() for _ in range(3))
- cache.conv_states[0] = conv
-
- indices = torch.tensor([1, 2])
- cache.batch_select_indices(indices)
-
- for c in cache.conv_states[0]:
- assert c.shape[0] == 2
- assert c[0, 0, 0].item() == 1.0
-
-
-# =============================================================================
-# SECTION 7: HUGGINGFACE INTEGRATION
-# =============================================================================
-
-
-class TestGetMaskSizes:
- """Test get_mask_sizes() for attention mask computation."""
-
- def test_empty_cache(self, tiny_attention_config):
- """Mask sizes with empty cache."""
- cache = Apriel2Cache(tiny_attention_config)
- cache_position = torch.arange(10)
-
- kv_length, kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0)
-
- assert kv_length == 10
- assert kv_offset == 0
-
- def test_with_cached_tokens(self, tiny_attention_config, sample_kv):
- """Mask sizes with cached tokens."""
- cache = Apriel2Cache(tiny_attention_config)
- cache.update(*sample_kv, layer_idx=0) # 10 tokens
-
- cache_position = torch.arange(5)
- kv_length, kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0)
-
- assert kv_length == 15 # 10 + 5
- assert kv_offset == 10
-
- def test_single_token_decode(self, tiny_attention_config, sample_kv):
- """Mask sizes for single token decode."""
- cache = Apriel2Cache(tiny_attention_config)
- cache.update(*sample_kv, layer_idx=0)
-
- cache_position = torch.arange(1)
- kv_length, kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0)
-
- assert kv_length == 11
- assert kv_offset == 10
-
- def test_ssm_returns_query_only(self, ssm_config, sample_conv_single):
- """SSM layers return query_length (no KV cache)."""
- cache = Apriel2Cache(ssm_config)
- cache.conv_states[0] = sample_conv_single
-
- cache_position = torch.arange(5)
- kv_length, kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0)
-
- assert kv_length == 5
- assert kv_offset == 0
-
-
-class TestCacheIndexing:
- """Test cache[idx] indexing."""
-
- def test_attention_returns_kv(self, tiny_attention_config, sample_kv):
- """Indexing attention layer returns (key, value)."""
- cache = Apriel2Cache(tiny_attention_config)
- cache.update(*sample_kv, layer_idx=0)
-
- result = cache[0]
-
- assert isinstance(result, tuple)
- torch.testing.assert_close(result[0], sample_kv[0])
-
- def test_empty_returns_empty_tensors(self, tiny_attention_config):
- """Indexing empty layer returns empty tensors."""
- cache = Apriel2Cache(tiny_attention_config)
-
- result = cache[0]
-
- assert result[0].numel() == 0
- assert result[1].numel() == 0
-
- def test_ssm_returns_empty(self, ssm_config, sample_conv_single):
- """Indexing SSM layer returns empty (no KV)."""
- cache = Apriel2Cache(ssm_config)
- cache.conv_states[0] = sample_conv_single
-
- result = cache[0]
-
- assert result[0].numel() == 0
-
- def test_stochastic_attention_returns_kv(self, stochastic_config, sample_kv):
- """Indexing stochastic with attention active returns KV."""
- cache = Apriel2Cache(stochastic_config)
- cache.set_active_mixer(1, "attention")
- cache.update(*sample_kv, layer_idx=1)
-
- result = cache[1]
-
- torch.testing.assert_close(result[0], sample_kv[0])
-
-
-# =============================================================================
-# SECTION 8: GENERATION PATTERNS
-# =============================================================================
-
-
-class TestGenerationPatterns:
- """Test real-world generation patterns."""
-
- def test_prefill_then_decode(self, tiny_attention_config, sample_kv):
- """Prefill with long prompt, then decode token-by-token."""
- cache = Apriel2Cache(tiny_attention_config)
- cache.update(*sample_kv, layer_idx=0) # Prefill 10 tokens
-
- for _ in range(5):
- new_kv = (torch.randn(2, 4, 1, 16), torch.randn(2, 4, 1, 16))
- cache.update(*new_kv, layer_idx=0)
-
- assert cache.get_seq_length(0) == 15
-
- def test_crop_then_continue(self, tiny_attention_config, sample_kv):
- """Crop old context, continue generation."""
- cache = Apriel2Cache(tiny_attention_config)
- cache.update(*sample_kv, layer_idx=0)
- cache.update(*sample_kv, layer_idx=0) # 20 tokens
-
- cache.crop(5) # Keep last 5
- cache.update(torch.randn(2, 4, 3, 16), torch.randn(2, 4, 3, 16), layer_idx=0)
-
- assert cache.get_seq_length(0) == 8
-
- def test_reset_between_generations(self, tiny_attention_config, sample_kv):
- """Reset between independent generations."""
- cache = Apriel2Cache(tiny_attention_config)
-
- # First generation
- cache.update(*sample_kv, layer_idx=0)
- assert cache.is_initialized == True
-
- # Reset
- cache.reset()
- assert cache.is_initialized == False
-
- # Second generation
- cache.update(*sample_kv, layer_idx=0)
- assert cache.get_seq_length(0) == 10
-
- def test_multi_layer_consistency(self, tiny_attention_config, sample_kv):
- """All layers updated consistently."""
- cache = Apriel2Cache(tiny_attention_config)
-
- for layer_idx in range(2):
- cache.update(*sample_kv, layer_idx=layer_idx)
- cache.update(torch.randn(2, 4, 1, 16), torch.randn(2, 4, 1, 16), layer_idx=layer_idx)
-
- for layer_idx in range(2):
- assert cache.get_seq_length(layer_idx) == 11
-
-
-# =============================================================================
-# SECTION 9: ERROR HANDLING
-# =============================================================================
-
-
-class TestErrorHandling:
- """Test error conditions and guards."""
-
- def test_stochastic_update_without_active_mixer(self, stochastic_config):
- """update() on stochastic without active_mixer raises."""
- cache = Apriel2Cache(stochastic_config)
-
- with pytest.raises(RuntimeError, match="needs active_mixer set"):
- cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=1)
-
- def test_stochastic_accessor_without_active_mixer(self, stochastic_config):
- """Accessing stochastic cache without active_mixer raises."""
- cache = Apriel2Cache(stochastic_config)
-
- with pytest.raises(RuntimeError, match="requires set_active_mixer"):
- _ = cache.conv_states[1]
-
- def test_accessor_error_lists_available_mixers(self, stochastic_config):
- """Error message lists available mixers."""
- cache = Apriel2Cache(stochastic_config)
-
- with pytest.raises(RuntimeError, match="Available mixers:"):
- _ = cache.key_cache[1]
-
- def test_invalid_mixer_name(self, stochastic_config):
- """Invalid mixer name raises KeyError on access."""
- cache = Apriel2Cache(stochastic_config)
- cache.set_active_mixer(1, "nonexistent")
-
- with pytest.raises(KeyError):
- cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=1)
-
- def test_layer_idx_out_of_bounds(self, tiny_attention_config):
- """Out-of-bounds layer_idx raises IndexError."""
- cache = Apriel2Cache(tiny_attention_config)
-
- with pytest.raises(IndexError):
- cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=999)
-
-
-# =============================================================================
-# SECTION 10: INTERNAL CLASSES
-# =============================================================================
-
-
-class TestAttentionCacheInternal:
- """Test internal _AttentionCache class directly."""
-
- def test_unbounded_growth(self):
- """No window allows unbounded growth."""
- cache = _AttentionCache(window=None)
-
- for _ in range(10):
- cache.update(torch.randn(2, 4, 100, 16), torch.randn(2, 4, 100, 16))
-
- assert cache.key.shape[-2] == 1000
-
- def test_window_enforced(self):
- """Window caps cache size."""
- cache = _AttentionCache(window=50)
-
- for _ in range(10):
- cache.update(torch.randn(2, 4, 100, 16), torch.randn(2, 4, 100, 16))
-
- assert cache.key.shape[-2] == 50
-
-
-class TestSSMCacheInternal:
- """Test internal _SSMCache class directly."""
-
- def test_initial_none(self):
- """Initial states are None."""
- cache = _SSMCache()
-
- assert cache.conv is None
- assert cache.recurrent is None
-
- def test_stores_tuple(self):
- """Can store tuple (for KDA)."""
- cache = _SSMCache()
- cache.conv = (torch.randn(2, 64, 3),) * 3
-
- assert isinstance(cache.conv, tuple)
diff --git a/fast_llm_external_models/tests/test_apriel2/test_cache_apriel2_specific.py b/fast_llm_external_models/tests/test_apriel2/test_cache_apriel2_specific.py
new file mode 100644
index 000000000..e0e4db2d3
--- /dev/null
+++ b/fast_llm_external_models/tests/test_apriel2/test_cache_apriel2_specific.py
@@ -0,0 +1,342 @@
+"""Tests for Apriel2-specific cache behaviors with no HuggingFace equivalent.
+
+This module tests features unique to Apriel2Cache that cannot be validated
+against upstream HF implementations:
+
+1. Stochastic mixer routing (switching between attention/SSM per layer)
+2. Multi-mixer layer support
+3. Error handling and guard rails
+4. Beam search operations (batch_repeat, reorder, select)
+5. Crop operation
+
+Fixtures used from conftest.py:
+ - stochastic_config: Stochastic mixer config with attention and mamba
+ - attention_config: Pure attention config
+ - ssm_config: Pure SSM config
+"""
+
+import pytest
+import torch
+
+from fast_llm_external_models.apriel2.cache import Apriel2Cache, _AttentionCache, _SSMCache
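+
+# Illustrative sketch (not a test) of the stochastic routing protocol exercised in this
+# module. Names and shapes follow the conftest fixtures (heads=4, head_groups=2,
+# head_size=16); `k`, `v`, and `conv` stand for tensors of the appropriate shapes.
+#
+#     cache = Apriel2Cache(stochastic_config)
+#     cache.set_active_mixer(0, "attention")   # route layer 0 to its attention sub-cache
+#     cache.update(k, v, layer_idx=0)          # lands in cache.layers[0]["attention"]
+#     cache.set_active_mixer(0, "mamba")       # switch routing; attention state is preserved
+#     cache.conv_states[0] = conv              # lands in cache.layers[0]["mamba"]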
+
+
+# =============================================================================
+# STOCHASTIC MIXER ROUTING
+# =============================================================================
+
+
+class TestStochasticMixerRouting:
+ """Test routing operations to correct sub-cache in stochastic layers."""
+
+ def test_set_active_mixer(self, stochastic_config):
+ """set_active_mixer updates routing for layer."""
+ cache = Apriel2Cache(stochastic_config)
+
+ cache.set_active_mixer(0, "attention")
+ assert cache.active_mixers[0] == "attention"
+
+ cache.set_active_mixer(0, "mamba")
+ assert cache.active_mixers[0] == "mamba"
+
+ def test_update_routes_to_active_mixer(self, stochastic_config):
+ """update() stores in correct sub-cache based on active_mixer."""
+ cache = Apriel2Cache(stochastic_config)
+
+ # Route to attention
+ cache.set_active_mixer(0, "attention")
+ cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0)
+
+ # Attention sub-cache should have data
+ assert cache.layers[0]["attention"].key is not None
+ # Mamba sub-cache should be empty
+ assert cache.layers[0]["mamba"].conv is None
+
+ def test_each_mixer_has_independent_cache(self, stochastic_config):
+ """Each mixer in a stochastic layer has its own independent state."""
+ cache = Apriel2Cache(stochastic_config)
+
+ # Store in attention
+ cache.set_active_mixer(0, "attention")
+ cache.update(torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16), layer_idx=0)
+
+ # Switch to mamba and store
+ cache.set_active_mixer(0, "mamba")
+ cache.layers[0]["mamba"].conv = torch.randn(2, 64, 4)
+
+ # Attention data should be unchanged
+ assert cache.layers[0]["attention"].cumulative_length == 5
+
+ def test_switching_preserves_all_states(self, stochastic_config):
+ """Switching active_mixer doesn't clear other mixer's state."""
+ cache = Apriel2Cache(stochastic_config)
+
+ # Build up attention state
+ cache.set_active_mixer(0, "attention")
+ cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0)
+ attn_key = cache.layers[0]["attention"].key.clone()
+
+ # Switch to mamba
+ cache.set_active_mixer(0, "mamba")
+
+ # Attention state preserved
+ torch.testing.assert_close(cache.layers[0]["attention"].key, attn_key)
+
+
+# =============================================================================
+# ERROR HANDLING
+# =============================================================================
+
+
+class TestErrorHandling:
+ """Test guard rails and error messages."""
+
+ def test_update_without_active_mixer_raises(self, stochastic_config):
+ """update() on stochastic layer without active_mixer raises RuntimeError."""
+ cache = Apriel2Cache(stochastic_config)
+
+ with pytest.raises(RuntimeError, match="needs active_mixer set"):
+ cache.update(torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16), layer_idx=0)
+
+ def test_accessor_without_active_mixer_raises(self, stochastic_config):
+ """Accessing key_cache/value_cache without active_mixer raises RuntimeError."""
+ cache = Apriel2Cache(stochastic_config)
+
+ with pytest.raises(RuntimeError, match="requires set_active_mixer"):
+ _ = cache.key_cache[0]
+
+ def test_error_message_lists_available_mixers(self, stochastic_config):
+ """Error message includes list of available mixers."""
+ cache = Apriel2Cache(stochastic_config)
+
+ with pytest.raises(RuntimeError, match="attention.*mamba|mamba.*attention"):
+ _ = cache.key_cache[0]
+
+
+# =============================================================================
+# BEAM SEARCH OPERATIONS
+# =============================================================================
+
+
+class TestBeamSearchOperations:
+ """Test batch manipulation for beam search."""
+
+ def test_batch_repeat_interleave_attention(self, attention_config):
+ """batch_repeat_interleave expands batch dimension."""
+ cache = Apriel2Cache(attention_config)
+ cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0)
+
+ cache.batch_repeat_interleave(3)
+
+ assert cache.layers[0].key.shape[0] == 6 # 2 * 3
+
+ def test_batch_repeat_interleave_ssm(self, ssm_config):
+ """batch_repeat_interleave works for SSM caches."""
+ cache = Apriel2Cache(ssm_config)
+ cache.layers[0].conv = torch.randn(2, 64, 4)
+
+ cache.batch_repeat_interleave(3)
+
+ assert cache.layers[0].conv.shape[0] == 6
+
+ def test_batch_repeat_interleave_kda_tuple(self, ssm_config):
+ """batch_repeat_interleave handles KDA tuple conv states."""
+ cache = Apriel2Cache(ssm_config)
+ cache.layers[0].conv = (torch.randn(2, 64, 4),) * 3
+
+ cache.batch_repeat_interleave(3)
+
+ assert cache.layers[0].conv[0].shape[0] == 6
+
+ def test_reorder_cache_attention(self, attention_config):
+ """reorder_cache reorders batch dimension."""
+ cache = Apriel2Cache(attention_config)
+ k = torch.arange(4).float().view(4, 1, 1, 1).expand(4, 4, 10, 16)
+ cache.update(k, k.clone(), layer_idx=0)
+
+ beam_idx = torch.tensor([3, 2, 1, 0])
+ cache.reorder_cache(beam_idx)
+
+ # Check reordering
+ assert cache.layers[0].key[0, 0, 0, 0].item() == 3.0
+ assert cache.layers[0].key[3, 0, 0, 0].item() == 0.0
+
+ def test_batch_select_indices(self, attention_config):
+ """batch_select_indices selects subset of batch."""
+ cache = Apriel2Cache(attention_config)
+ cache.update(torch.randn(4, 4, 10, 16), torch.randn(4, 4, 10, 16), layer_idx=0)
+
+ indices = torch.tensor([0, 2])
+ cache.batch_select_indices(indices)
+
+ assert cache.layers[0].key.shape[0] == 2
+
+ def test_reorder_cache_ssm_tuple(self, ssm_config):
+ """reorder_cache handles KDA tuple conv states."""
+ cache = Apriel2Cache(ssm_config)
+ # Create distinguishable tensors for each batch position
+ conv0 = torch.full((1, 64, 4), 0.0)
+ conv1 = torch.full((1, 64, 4), 1.0)
+ conv2 = torch.full((1, 64, 4), 2.0)
+ cache.layers[0].conv = (
+ torch.cat([conv0, conv1, conv2], dim=0),
+ torch.cat([conv0, conv1, conv2], dim=0),
+ torch.cat([conv0, conv1, conv2], dim=0),
+ )
+
+ beam_idx = torch.tensor([2, 1, 0])
+ cache.reorder_cache(beam_idx)
+
+ # Check reordering: batch[0] should now have value 2.0
+ assert cache.layers[0].conv[0][0, 0, 0].item() == 2.0
+ assert cache.layers[0].conv[0][2, 0, 0].item() == 0.0
+
+ def test_batch_select_indices_ssm_tuple(self, ssm_config):
+ """batch_select_indices handles KDA tuple conv states."""
+ cache = Apriel2Cache(ssm_config)
+ cache.layers[0].conv = (torch.randn(4, 64, 4),) * 3
+
+ indices = torch.tensor([0, 2])
+ cache.batch_select_indices(indices)
+
+ assert cache.layers[0].conv[0].shape[0] == 2
+ assert cache.layers[0].conv[1].shape[0] == 2
+ assert cache.layers[0].conv[2].shape[0] == 2
+
+
+# =============================================================================
+# CROP OPERATION
+# =============================================================================
+
+
+class TestCropOperation:
+ """Test cache truncation."""
+
+ def test_crop_truncates_attention(self, attention_config):
+ """crop() truncates attention cache."""
+ cache = Apriel2Cache(attention_config)
+ cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0)
+
+ cache.crop(5)
+
+ assert cache.layers[0].key.shape[-2] == 5
+ assert cache.get_seq_length(0) == 5
+
+ def test_crop_affects_all_layers(self, attention_config):
+ """crop() affects all layers."""
+ cache = Apriel2Cache(attention_config)
+ cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0)
+ cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=1)
+
+ cache.crop(3)
+
+ assert cache.layers[0].key.shape[-2] == 3
+ assert cache.layers[1].key.shape[-2] == 3
+
+ def test_crop_ignores_ssm(self, ssm_config):
+ """crop() doesn't affect SSM caches (they don't have seq dimension)."""
+ cache = Apriel2Cache(ssm_config)
+ cache.layers[0].conv = torch.randn(2, 64, 4)
+
+ # Should not raise
+ cache.crop(5)
+
+ # SSM state unchanged
+ assert cache.layers[0].conv.shape == (2, 64, 4)
+
+
+# =============================================================================
+# CACHE PROPERTIES
+# =============================================================================
+
+
+class TestCacheProperties:
+ """Test cache property methods."""
+
+ def test_is_initialized_attention(self, attention_config):
+ """is_initialized True after update."""
+ cache = Apriel2Cache(attention_config)
+ assert not cache.is_initialized
+
+ cache.update(torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16), layer_idx=0)
+ assert cache.is_initialized
+
+ def test_is_initialized_ssm(self, ssm_config):
+ """is_initialized True after setting conv state."""
+ cache = Apriel2Cache(ssm_config)
+ assert not cache.is_initialized
+
+ cache.layers[0].conv = torch.randn(2, 64, 4)
+ assert cache.is_initialized
+
+ def test_has_previous_state_ssm_only(self, ssm_config):
+ """has_previous_state checks SSM conv states."""
+ cache = Apriel2Cache(ssm_config)
+ assert not cache.has_previous_state
+
+ cache.layers[0].conv = torch.randn(2, 64, 4)
+ assert cache.has_previous_state
+
+ def test_has_previous_state_ignores_attention(self, attention_config):
+ """has_previous_state ignores attention caches."""
+ cache = Apriel2Cache(attention_config)
+ cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0)
+
+ # Attention-only cache returns False for has_previous_state
+ assert not cache.has_previous_state
+
+ def test_reset_clears_ssm_states(self, ssm_config):
+ """reset() clears SSM conv and recurrent states."""
+ cache = Apriel2Cache(ssm_config)
+ cache.layers[0].conv = torch.randn(2, 64, 4)
+ cache.layers[0].recurrent = torch.randn(2, 64, 16)
+
+ cache.reset()
+
+ assert cache.layers[0].conv is None
+ assert cache.layers[0].recurrent is None
+
+ def test_max_batch_size_from_ssm_tuple(self, ssm_config):
+ """max_batch_size works with KDA tuple conv states."""
+ cache = Apriel2Cache(ssm_config)
+ cache.layers[0].conv = (torch.randn(3, 64, 4),) * 3
+
+ assert cache.max_batch_size == 3
+
+ def test_max_batch_size(self, attention_config):
+ """max_batch_size returns batch dimension."""
+ cache = Apriel2Cache(attention_config)
+ cache.update(torch.randn(3, 4, 10, 16), torch.randn(3, 4, 10, 16), layer_idx=0)
+
+ assert cache.max_batch_size == 3
+
+ def test_len_returns_num_layers(self, attention_config):
+ """__len__ returns number of layers."""
+ cache = Apriel2Cache(attention_config)
+ assert len(cache) == 2
+
+
+# =============================================================================
+# INDEXING
+# =============================================================================
+
+
+class TestCacheIndexing:
+ """Test __getitem__ for HF compatibility."""
+
+ def test_getitem_returns_kv_tuple(self, attention_config):
+ """cache[idx] returns (key, value) tuple."""
+ cache = Apriel2Cache(attention_config)
+ cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0)
+
+ k, v = cache[0]
+ assert k.shape == (2, 4, 10, 16)
+ assert v.shape == (2, 4, 10, 16)
+
+ def test_getitem_empty_returns_empty_tensors(self, attention_config):
+ """cache[idx] on empty cache returns empty tensors."""
+ cache = Apriel2Cache(attention_config)
+
+ k, v = cache[0]
+ assert k.numel() == 0
+ assert v.numel() == 0
diff --git a/fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py b/fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py
new file mode 100644
index 000000000..7c38f75b7
--- /dev/null
+++ b/fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py
@@ -0,0 +1,592 @@
+"""Contract tests for Apriel2Cache against HuggingFace cache implementations.
+
+This module tests that Apriel2Cache components behave equivalently to their
+HuggingFace counterparts. This ensures compatibility with HF's generation
+infrastructure (mask creation, beam search, etc.).
+
+Mapping:
+    Apriel2 Component            HuggingFace Equivalent
+    ---------------------------  ----------------------
+    _AttentionCache (no window)  DynamicLayer
+    _AttentionCache (window)     DynamicSlidingWindowLayer
+    _SSMCache                    MambaCache (different interface, same concept)
+
+Apriel2-specific features (stochastic routing, multi-mixer layers) are tested
+separately in test_cache_apriel2_specific.py since they have no HF equivalent.
+
+Fixtures used from conftest.py:
+ - batch_size, num_heads, head_dim: Tensor dimensions
+ - hf_dynamic_layer: HuggingFace DynamicLayer
+ - hf_sliding_layer: HuggingFace DynamicSlidingWindowLayer (parameterized by window_size)
+ - apriel_attention_cache: Apriel2 _AttentionCache (no window)
+ - apriel_sliding_cache: Apriel2 _AttentionCache (with window, parameterized)
+ - window_size: Parameterized window sizes [4, 8, 16, 32]
+ - attention_config, swa_config: Apriel2 configs
+"""
+
+import pytest
+import torch
+
+from fast_llm_external_models.apriel2.cache import _AttentionCache, _SSMCache, Apriel2Cache
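+
+# Illustrative sketch (not a test) of the contract pattern used throughout this module:
+# feed identical tensors to the HF layer and to our cache, then compare what each
+# reports. `k`/`v` stand for [batch, heads, seq, head_dim] tensors from the fixtures.
+#
+#     from transformers.cache_utils import DynamicLayer
+#     hf, ours = DynamicLayer(), _AttentionCache(window=None)
+#     hf.update(k.clone(), v.clone())
+#     ours.update(k.clone(), v.clone())
+#     assert ours.cumulative_length == hf.get_seq_length()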
+
+
+# =============================================================================
+# SECTION 1: FULL ATTENTION - _AttentionCache vs DynamicLayer
+# =============================================================================
+
+
+class TestFullAttentionContract:
+ """Test _AttentionCache (no window) matches HuggingFace DynamicLayer.
+
+ DynamicLayer is the standard cache for full causal attention.
+ We test that our cache produces identical mask parameters.
+ """
+
+ # -------------------------------------------------------------------------
+ # get_seq_length: Must match exactly for generation to work
+ # -------------------------------------------------------------------------
+
+ @pytest.mark.parametrize("seq_len", [1, 5, 10, 50, 100])
+ def test_get_seq_length_after_prefill(
+ self, hf_dynamic_layer, apriel_attention_cache, batch_size, num_heads, head_dim, seq_len
+ ):
+ """After prefill, cumulative_length matches HF get_seq_length."""
+ key = torch.randn(batch_size, num_heads, seq_len, head_dim)
+ value = torch.randn(batch_size, num_heads, seq_len, head_dim)
+
+ hf_dynamic_layer.update(key.clone(), value.clone())
+ apriel_attention_cache.update(key.clone(), value.clone())
+
+ assert apriel_attention_cache.cumulative_length == hf_dynamic_layer.get_seq_length()
+
+ @pytest.mark.parametrize("prefill_len", [1, 5, 10])
+ @pytest.mark.parametrize("decode_steps", [1, 5, 10, 20])
+ def test_get_seq_length_during_decode(
+ self, hf_dynamic_layer, apriel_attention_cache, batch_size, num_heads, head_dim, prefill_len, decode_steps
+ ):
+ """During decode, cumulative_length tracks total tokens seen."""
+ # Prefill
+ key = torch.randn(batch_size, num_heads, prefill_len, head_dim)
+ value = torch.randn(batch_size, num_heads, prefill_len, head_dim)
+ hf_dynamic_layer.update(key.clone(), value.clone())
+ apriel_attention_cache.update(key.clone(), value.clone())
+
+ # Decode
+ for step in range(decode_steps):
+ key = torch.randn(batch_size, num_heads, 1, head_dim)
+ value = torch.randn(batch_size, num_heads, 1, head_dim)
+ hf_dynamic_layer.update(key.clone(), value.clone())
+ apriel_attention_cache.update(key.clone(), value.clone())
+
+ assert apriel_attention_cache.cumulative_length == hf_dynamic_layer.get_seq_length(), (
+ f"Mismatch at decode step {step}"
+ )
+
+ # -------------------------------------------------------------------------
+ # get_mask_sizes: Verify HF behavior for documentation
+ # -------------------------------------------------------------------------
+
+ @pytest.mark.parametrize("prefill_len", [1, 5, 10])
+ @pytest.mark.parametrize("decode_steps", [0, 1, 5, 10])
+ def test_hf_mask_sizes_kv_length(
+ self, hf_dynamic_layer, apriel_attention_cache, batch_size, num_heads, head_dim, prefill_len, decode_steps
+ ):
+ """Document HF's kv_length behavior and verify cumulative_length tracks correctly.
+
+ For full attention, kv_length = cumulative_length + query_length.
+ This test verifies our cache tracks tokens identically to HF.
+ """
+ # Prefill
+ key = torch.randn(batch_size, num_heads, prefill_len, head_dim)
+ value = torch.randn(batch_size, num_heads, prefill_len, head_dim)
+ hf_dynamic_layer.update(key.clone(), value.clone())
+ apriel_attention_cache.update(key.clone(), value.clone())
+
+ # Decode
+ for _ in range(decode_steps):
+ key = torch.randn(batch_size, num_heads, 1, head_dim)
+ value = torch.randn(batch_size, num_heads, 1, head_dim)
+ hf_dynamic_layer.update(key.clone(), value.clone())
+ apriel_attention_cache.update(key.clone(), value.clone())
+
+ # Verify cumulative_length matches HF
+ assert apriel_attention_cache.cumulative_length == hf_dynamic_layer.get_seq_length()
+
+ # Verify HF's kv_length follows the expected formula
+ cache_position = torch.arange(1) # Single token decode
+ hf_kv_len, hf_kv_offset = hf_dynamic_layer.get_mask_sizes(cache_position)
+ expected_kv_len = hf_dynamic_layer.get_seq_length() + cache_position.shape[0]
+ assert hf_kv_len == expected_kv_len
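+
+    def test_kv_length_formula_worked_example(self, hf_dynamic_layer):
+        """Worked example of kv_length = cumulative_length + query_length.
+
+        Minimal sketch with illustrative shapes (not tied to a real model): a 5-token
+        prefill plus a 3-token chunk gives 8 cached tokens, so a single-token query
+        sees kv_length = 8 + 1 = 9 and kv_offset = 0.
+        """
+        hf_dynamic_layer.update(torch.randn(1, 1, 5, 4), torch.randn(1, 1, 5, 4))
+        hf_dynamic_layer.update(torch.randn(1, 1, 3, 4), torch.randn(1, 1, 3, 4))
+
+        cache_position = torch.arange(1)  # single-token decode query
+        kv_length, kv_offset = hf_dynamic_layer.get_mask_sizes(cache_position)
+
+        assert kv_length == 8 + 1
+        assert kv_offset == 0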
+
+ def test_hf_kv_offset_always_zero(self, hf_dynamic_layer, batch_size, num_heads, head_dim):
+ """Document that HF DynamicLayer always returns kv_offset=0.
+
+ For full attention, all cached KV pairs map to absolute positions
+ starting from 0, so kv_offset is always 0.
+ """
+ # Add many tokens
+ for _ in range(20):
+ key = torch.randn(batch_size, num_heads, 5, head_dim)
+ value = torch.randn(batch_size, num_heads, 5, head_dim)
+ hf_dynamic_layer.update(key.clone(), value.clone())
+
+ cache_position = torch.arange(1)
+ _, hf_kv_offset = hf_dynamic_layer.get_mask_sizes(cache_position)
+
+ assert hf_kv_offset == 0, "DynamicLayer always returns kv_offset=0"
+
+ # -------------------------------------------------------------------------
+ # update: Output shape and values must match
+ # -------------------------------------------------------------------------
+
+ @pytest.mark.parametrize("seq_len", [1, 5, 10])
+ def test_update_returns_same_shape(
+ self, hf_dynamic_layer, apriel_attention_cache, batch_size, num_heads, head_dim, seq_len
+ ):
+ """update() returns tensors with matching shapes."""
+ key = torch.randn(batch_size, num_heads, seq_len, head_dim)
+ value = torch.randn(batch_size, num_heads, seq_len, head_dim)
+
+ hf_k, hf_v = hf_dynamic_layer.update(key.clone(), value.clone())
+ apr_k, apr_v = apriel_attention_cache.update(key.clone(), value.clone())
+
+ assert hf_k.shape == apr_k.shape
+ assert hf_v.shape == apr_v.shape
+
+ def test_update_concatenates_identically(
+ self, hf_dynamic_layer, apriel_attention_cache, batch_size, num_heads, head_dim
+ ):
+ """Multiple updates produce identical concatenated states."""
+ # Use deterministic values for comparison
+ k1 = torch.arange(10).float().view(1, 1, 10, 1).expand(batch_size, num_heads, 10, head_dim)
+ v1 = k1.clone()
+
+ hf_dynamic_layer.update(k1.clone(), v1.clone())
+ apriel_attention_cache.update(k1.clone(), v1.clone())
+
+ k2 = torch.arange(10, 15).float().view(1, 1, 5, 1).expand(batch_size, num_heads, 5, head_dim)
+ v2 = k2.clone()
+
+ hf_k, hf_v = hf_dynamic_layer.update(k2.clone(), v2.clone())
+ apr_k, apr_v = apriel_attention_cache.update(k2.clone(), v2.clone())
+
+ torch.testing.assert_close(hf_k, apr_k)
+ torch.testing.assert_close(hf_v, apr_v)
+
+
+# =============================================================================
+# SECTION 2: SLIDING WINDOW - _AttentionCache vs DynamicSlidingWindowLayer
+# =============================================================================
+
+
+class TestSlidingWindowContract:
+ """Test _AttentionCache (with window) matches HuggingFace DynamicSlidingWindowLayer.
+
+ DynamicSlidingWindowLayer is used for sliding window attention (e.g., Mistral).
+ Critical behaviors:
+ - cumulative_length tracks ALL tokens seen (not just cached)
+ - kv_offset increases once window is exceeded
+ - kv_length is capped at window size
+
+ Uses fixtures from conftest.py:
+ - window_size: parameterized [4, 8, 16, 32]
+ - hf_sliding_layer: DynamicSlidingWindowLayer
+ - apriel_sliding_cache: _AttentionCache with window
+ """
+
+ # -------------------------------------------------------------------------
+ # cumulative_length: Must track total tokens, not cached tokens
+ # -------------------------------------------------------------------------
+
+ @pytest.mark.parametrize("prefill_len", [1, 3, 5, 10, 20])
+ def test_cumulative_length_matches_after_prefill(
+ self, hf_sliding_layer, apriel_sliding_cache, batch_size, num_heads, head_dim, prefill_len
+ ):
+ """cumulative_length matches HF get_seq_length after prefill."""
+ key = torch.randn(batch_size, num_heads, prefill_len, head_dim)
+ value = torch.randn(batch_size, num_heads, prefill_len, head_dim)
+
+ hf_sliding_layer.update(key.clone(), value.clone())
+ apriel_sliding_cache.update(key.clone(), value.clone())
+
+ assert apriel_sliding_cache.cumulative_length == hf_sliding_layer.get_seq_length()
+
+ def test_cumulative_length_continues_past_window(
+ self, hf_sliding_layer, apriel_sliding_cache, window_size, batch_size, num_heads, head_dim
+ ):
+ """cumulative_length keeps growing even after window is full."""
+ total_tokens = window_size * 3 # Way past window
+
+ for i in range(total_tokens):
+ key = torch.randn(batch_size, num_heads, 1, head_dim)
+ value = torch.randn(batch_size, num_heads, 1, head_dim)
+
+ hf_sliding_layer.update(key.clone(), value.clone())
+ apriel_sliding_cache.update(key.clone(), value.clone())
+
+ expected = i + 1
+ assert apriel_sliding_cache.cumulative_length == expected
+ assert hf_sliding_layer.get_seq_length() == expected
+
+ # -------------------------------------------------------------------------
+ # get_mask_sizes: kv_offset must increase once window is exceeded
+ # -------------------------------------------------------------------------
+
+ def test_kv_offset_zero_before_window_full(
+ self, hf_sliding_layer, apriel_sliding_cache, window_size, batch_size, num_heads, head_dim
+ ):
+ """kv_offset is 0 while cumulative < window.
+
+ Before the window is full, kv_offset should be 0 because all cached tokens
+ correspond to absolute positions starting from 0.
+ """
+ # Add tokens up to window-1
+ for i in range(window_size - 1):
+ key = torch.randn(batch_size, num_heads, 1, head_dim)
+ value = torch.randn(batch_size, num_heads, 1, head_dim)
+
+ hf_sliding_layer.update(key.clone(), value.clone())
+ apriel_sliding_cache.update(key.clone(), value.clone())
+
+ cache_position = torch.arange(1)
+ hf_kv_len, hf_kv_offset = hf_sliding_layer.get_mask_sizes(cache_position)
+
+ # Verify HF returns 0 offset before window full
+ assert hf_kv_offset == 0, f"HF offset should be 0 at step {i}"
+ # Verify Apriel cache tracks cumulative correctly
+ assert apriel_sliding_cache.cumulative_length == i + 1
+
+ def test_kv_offset_increases_after_window_full(
+ self, hf_sliding_layer, apriel_sliding_cache, window_size, batch_size, num_heads, head_dim
+ ):
+ """kv_offset increases once cumulative >= window.
+
+        Once the window is full, the cache discards the oldest tokens. kv_offset tracks
+        which absolute position KV[0] corresponds to.
+ """
+ # Fill to exactly window
+ for _ in range(window_size):
+ key = torch.randn(batch_size, num_heads, 1, head_dim)
+ value = torch.randn(batch_size, num_heads, 1, head_dim)
+ hf_sliding_layer.update(key.clone(), value.clone())
+ apriel_sliding_cache.update(key.clone(), value.clone())
+
+ cache_position = torch.arange(1)
+ hf_kv_len, hf_kv_offset = hf_sliding_layer.get_mask_sizes(cache_position)
+
+ # At window boundary, offset should be 1
+ assert hf_kv_offset == 1, "HF offset should be 1 at window boundary"
+ assert apriel_sliding_cache.cumulative_length == window_size
+
+ # Add more tokens and verify offset keeps increasing with HF
+ for i in range(5):
+ key = torch.randn(batch_size, num_heads, 1, head_dim)
+ value = torch.randn(batch_size, num_heads, 1, head_dim)
+ hf_sliding_layer.update(key.clone(), value.clone())
+ apriel_sliding_cache.update(key.clone(), value.clone())
+
+ hf_kv_len, hf_kv_offset = hf_sliding_layer.get_mask_sizes(cache_position)
+
+ expected_offset = i + 2
+ assert hf_kv_offset == expected_offset
+ assert apriel_sliding_cache.cumulative_length == window_size + i + 1
+
+ def test_kv_length_capped_at_window(
+ self, hf_sliding_layer, apriel_sliding_cache, window_size, batch_size, num_heads, head_dim
+ ):
+ """kv_length is capped at window size once exceeded.
+
+ For a query of length 1 after the window is full, kv_length = window
+ (window-1 cached tokens + 1 query token).
+ """
+ # Way past window
+ for _ in range(window_size * 2):
+ key = torch.randn(batch_size, num_heads, 1, head_dim)
+ value = torch.randn(batch_size, num_heads, 1, head_dim)
+ hf_sliding_layer.update(key.clone(), value.clone())
+ apriel_sliding_cache.update(key.clone(), value.clone())
+
+ cache_position = torch.arange(1)
+ hf_kv_len, _ = hf_sliding_layer.get_mask_sizes(cache_position)
+
+ # HF returns window (window-1 cached + 1 query)
+ assert hf_kv_len == window_size
+ # Verify our cache tracked cumulative correctly
+ assert apriel_sliding_cache.cumulative_length == window_size * 2
+
+ # -------------------------------------------------------------------------
+ # Full sequence length tracking through generation
+ # -------------------------------------------------------------------------
+
+ @pytest.mark.parametrize("prefill_len", [1, 3, 5, 10, 20])
+ def test_cumulative_length_tracks_all_tokens(
+ self, hf_sliding_layer, apriel_sliding_cache, window_size, batch_size, num_heads, head_dim, prefill_len
+ ):
+ """cumulative_length tracks total tokens seen through prefill + decode.
+
+ This is the foundation for correct mask size computation. We verify that
+ our _AttentionCache tracks tokens identically to HuggingFace's DynamicSlidingWindowLayer.
+ The actual get_mask_sizes computation is tested in TestApriel2CacheIntegration.
+ """
+ # Prefill
+ key = torch.randn(batch_size, num_heads, prefill_len, head_dim)
+ value = torch.randn(batch_size, num_heads, prefill_len, head_dim)
+ hf_sliding_layer.update(key.clone(), value.clone())
+ apriel_sliding_cache.update(key.clone(), value.clone())
+
+ assert apriel_sliding_cache.cumulative_length == hf_sliding_layer.get_seq_length()
+
+ # Decode past window
+ for i in range(window_size + 10):
+ key = torch.randn(batch_size, num_heads, 1, head_dim)
+ value = torch.randn(batch_size, num_heads, 1, head_dim)
+ hf_sliding_layer.update(key.clone(), value.clone())
+ apriel_sliding_cache.update(key.clone(), value.clone())
+
+ assert apriel_sliding_cache.cumulative_length == hf_sliding_layer.get_seq_length(), (
+ f"cumulative_length mismatch at step {i}"
+ )
+
+
+# =============================================================================
+# SECTION 3: SSM CACHE - _SSMCache vs MambaCache concept
+# =============================================================================
+
+
+class TestSSMCacheContract:
+ """Document _SSMCache interface and verify basic contract.
+
+    Unlike the attention caches, which have HF equivalents (DynamicLayer, DynamicSlidingWindowLayer),
+    SSM caches have no direct HF counterpart with a matching interface. HF's MambaCache uses
+    different methods (update_conv_state, update_ssm_state), so a direct comparison is not possible.
+
+ These tests document the interface contract:
+ 1. `conv` and `recurrent` attributes for storing states
+ 2. Both support None (lazy initialization)
+    3. `conv` can be a tuple (for KDA, which has separate q/k/v conv states)
+
+ Higher-level operations (reorder, batch_repeat, reset) are tested in
+ TestBeamSearchOperations in test_cache_apriel2_specific.py.
+ """
+
+ def test_conv_state_storage(self, ssm_cache):
+ """conv attribute stores conv states (batch, intermediate, kernel_size)."""
+ conv = torch.randn(2, 64, 4)
+ ssm_cache.conv = conv
+ torch.testing.assert_close(ssm_cache.conv, conv)
+
+ def test_recurrent_state_storage(self, ssm_cache):
+ """recurrent attribute stores SSM states (batch, intermediate, state_size)."""
+ recurrent = torch.randn(2, 64, 16)
+ ssm_cache.recurrent = recurrent
+ torch.testing.assert_close(ssm_cache.recurrent, recurrent)
+
+ def test_conv_state_tuple_for_kda(self, ssm_cache):
+ """conv can be tuple for KDA's separate q/k/v convolutions."""
+ conv_tuple = (torch.randn(2, 64, 4), torch.randn(2, 64, 4), torch.randn(2, 64, 4))
+ ssm_cache.conv = conv_tuple
+ assert isinstance(ssm_cache.conv, tuple)
+ assert len(ssm_cache.conv) == 3
+
+ def test_initial_states_none(self, ssm_cache):
+ """States are None initially (lazy initialization pattern)."""
+ assert ssm_cache.conv is None
+ assert ssm_cache.recurrent is None
+
+ def test_states_independent(self, ssm_cache):
+ """conv and recurrent states are independent."""
+ ssm_cache.conv = torch.randn(2, 64, 4)
+ assert ssm_cache.recurrent is None # recurrent unchanged
+
+ ssm_cache.recurrent = torch.randn(2, 64, 16)
+ assert ssm_cache.conv is not None # conv unchanged
+
+
+# =============================================================================
+# SECTION 4: APRIEL2CACHE INTEGRATION
+# =============================================================================
+
+
+class TestApriel2CacheIntegration:
+ """Test Apriel2Cache correctly delegates to underlying caches.
+
+ Uses fixtures from conftest.py:
+ - attention_config: Pure attention config
+ - swa_config: Sliding window attention config (window=8)
+ """
+
+ def test_get_seq_length_matches_dynamic_layer(self, attention_config):
+ """Apriel2Cache.get_seq_length matches DynamicLayer for full attention."""
+ from transformers.cache_utils import DynamicLayer
+
+ cache = Apriel2Cache(attention_config)
+ hf_layer = DynamicLayer()
+
+ key = torch.randn(2, 4, 10, 16)
+ value = torch.randn(2, 4, 10, 16)
+
+ cache.update(key.clone(), value.clone(), layer_idx=0)
+ hf_layer.update(key.clone(), value.clone())
+
+ assert cache.get_seq_length(0) == hf_layer.get_seq_length()
+
+ def test_get_mask_sizes_matches_dynamic_layer(self, attention_config):
+ """Apriel2Cache.get_mask_sizes matches DynamicLayer."""
+ from transformers.cache_utils import DynamicLayer
+
+ cache = Apriel2Cache(attention_config)
+ hf_layer = DynamicLayer()
+
+ key = torch.randn(2, 4, 10, 16)
+ value = torch.randn(2, 4, 10, 16)
+
+ cache.update(key.clone(), value.clone(), layer_idx=0)
+ hf_layer.update(key.clone(), value.clone())
+
+ cache_position = torch.arange(1)
+ hf_kv_len, hf_kv_offset = hf_layer.get_mask_sizes(cache_position)
+ apr_kv_len, apr_kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0)
+
+ assert apr_kv_len == hf_kv_len
+ assert apr_kv_offset == hf_kv_offset
+
+ def test_get_mask_sizes_matches_sliding_layer(self, swa_config):
+ """Apriel2Cache.get_mask_sizes matches DynamicSlidingWindowLayer."""
+ from transformers.cache_utils import DynamicSlidingWindowLayer
+
+ cache = Apriel2Cache(swa_config)
+ hf_layer = DynamicSlidingWindowLayer(sliding_window=8)
+
+ # Fill past window
+ for _ in range(15):
+ key = torch.randn(2, 4, 1, 16)
+ value = torch.randn(2, 4, 1, 16)
+ cache.update(key.clone(), value.clone(), layer_idx=0)
+ hf_layer.update(key.clone(), value.clone())
+
+ cache_position = torch.arange(1)
+ hf_kv_len, hf_kv_offset = hf_layer.get_mask_sizes(cache_position)
+ apr_kv_len, apr_kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0)
+
+ assert apr_kv_len == hf_kv_len
+ assert apr_kv_offset == hf_kv_offset
+
+ def test_reset_clears_cumulative_length(self, attention_config):
+ """reset() clears cumulative_length (matches DynamicLayer.reset)."""
+ cache = Apriel2Cache(attention_config)
+
+ cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0)
+ assert cache.get_seq_length(0) == 10
+
+ cache.reset()
+ assert cache.get_seq_length(0) == 0
+
+
+# =============================================================================
+# SECTION 5: MASK CORRECTNESS (SEMANTIC TESTS)
+# =============================================================================
+
+
+class TestMaskCorrectness:
+ """Test that mask parameters produce semantically correct masks.
+
+ These tests verify the END RESULT: masks created with our parameters
+ allow the correct attention patterns.
+ """
+
+ def test_full_attention_decode_can_attend_to_all(self):
+ """During decode, query can attend to all cached positions."""
+ from transformers.masking_utils import sdpa_mask, causal_mask_function
+
+ cache = _AttentionCache(window=None)
+
+ # Prefill + decode
+ for _ in range(10):
+ cache.update(torch.randn(1, 1, 1, 16), torch.randn(1, 1, 1, 16))
+
+ # Mask for decode step
+ cache_position = torch.tensor([10]) # Position of new token
+ kv_length = cache.cumulative_length + 1
+ kv_offset = 0
+
+ mask = sdpa_mask(
+ batch_size=1,
+ cache_position=cache_position,
+ kv_length=kv_length,
+ kv_offset=kv_offset,
+ mask_function=causal_mask_function,
+ )
+
+ if mask is not None:
+ # Query at position 10 should attend to positions 0-10
+ query_mask = mask[0, 0, 0, :]
+ for kv_idx in range(kv_length):
+                assert query_mask[kv_idx].item(), f"Should attend to position {kv_idx}"
+
+ @pytest.mark.parametrize("window_size", [4, 8, 16])
+ def test_sliding_window_decode_respects_window(self, window_size):
+ """During decode, query only attends within sliding window."""
+ from transformers.masking_utils import sdpa_mask, sliding_window_causal_mask_function
+
+ cache = _AttentionCache(window=window_size)
+
+ # Fill way past window
+ total_tokens = window_size * 2
+ for _ in range(total_tokens):
+ cache.update(torch.randn(1, 1, 1, 16), torch.randn(1, 1, 1, 16))
+
+ # Mask for decode step
+ cache_position = torch.tensor([total_tokens])
+ cumulative = cache.cumulative_length
+ kv_offset = max(cumulative - window_size + 1, 0)
+ kv_length = window_size - 1 + 1 # cached + query
+
+ mask = sdpa_mask(
+ batch_size=1,
+ cache_position=cache_position,
+ kv_length=kv_length,
+ kv_offset=kv_offset,
+ mask_function=sliding_window_causal_mask_function(window_size),
+ )
+
+ if mask is not None:
+ query_mask = mask[0, 0, 0, :]
+ query_pos = cache_position[0].item()
+
+ for kv_idx in range(kv_length):
+ abs_pos = kv_offset + kv_idx
+ in_window = abs_pos > query_pos - window_size
+ causal = abs_pos <= query_pos
+ expected = in_window and causal
+
+ assert query_mask[kv_idx].item() == expected, (
+ f"Position {abs_pos}: expected {expected}, got {query_mask[kv_idx].item()}"
+ )
+
+ def test_prefill_has_causal_pattern(self):
+ """During prefill, mask has proper causal (lower triangular) pattern."""
+ from transformers.masking_utils import sdpa_mask, causal_mask_function
+
+ cache = _AttentionCache(window=None)
+ cache.update(torch.randn(1, 1, 5, 16), torch.randn(1, 1, 5, 16))
+
+ cache_position = torch.arange(5)
+ kv_length = cache.cumulative_length
+ kv_offset = 0
+
+ mask = sdpa_mask(
+ batch_size=1,
+ cache_position=cache_position,
+ kv_length=kv_length,
+ kv_offset=kv_offset,
+ mask_function=causal_mask_function,
+ allow_is_causal_skip=False, # Force mask creation
+ )
+
+ if mask is not None:
+ # Check causal pattern
+ for q_idx in range(5):
+ for kv_idx in range(5):
+ expected = kv_idx <= q_idx
+ actual = mask[0, 0, q_idx, kv_idx].item()
+ assert actual == expected, f"q={q_idx}, kv={kv_idx}: expected {expected}"
diff --git a/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py b/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py
index 0bd6ac88d..b1ee15d54 100644
--- a/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py
+++ b/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py
@@ -75,14 +75,10 @@ def source_config(self):
},
}
- def test_identity_empty_surgery(self, source_config):
- """Law 1: compose_configs(config, {}) == config"""
- result = compose_configs(source_config, {})
- assert result == source_config
-
- def test_identity_none_surgery(self, source_config):
- """Law 1: compose_configs(config, None) == config"""
- result = compose_configs(source_config, None)
+ @pytest.mark.parametrize("empty_surgery", [{}, None])
+ def test_identity(self, source_config, empty_surgery):
+ """Law 1: compose_configs(config, empty) == config for empty in [{}, None]"""
+ result = compose_configs(source_config, empty_surgery)
assert result == source_config
def test_override_explicit_values(self, source_config):
@@ -114,7 +110,7 @@ def test_same_type_inheritance(self, source_config):
assert mixer["head_size"] == 32 # Inherited
assert mixer["rope_theta"] == 10000.0 # Inherited
assert mixer["window_size"] == 512 # Added
- assert "init" not in mixer # Stripped by apply_surgery
+ # init is preserved for plan_surgery to see (stripped only at final output)
def test_cross_type_attention_to_gdn(self, source_config):
"""Law 5: attention→gdn derives GDN dims from attention geometry."""
@@ -239,8 +235,14 @@ def test_null_deletion(self, source_config):
assert "vision_encoder" not in result
- def test_init_stripped_from_result(self, source_config):
- """Verify `init` keys are stripped from final result."""
+ def test_init_preserved_for_plan_surgery(self, source_config):
+ """Verify `init` keys are preserved so plan_surgery can see them.
+
+ The `init` field controls weight initialization (transfer vs random).
+ It's preserved through composition and only stripped at final output.
+ """
+ from fast_llm_external_models.apriel2.conversion.config import strip_init_fields
+
surgery = {
"decoder": {
"block": {
@@ -252,20 +254,20 @@ def test_init_stripped_from_result(self, source_config):
"gdn": {"type": "gdn", "init": "random", "convolution_layer": {"kernel_size": 4}},
},
},
- "mlp": {"init": "transfer"},
- "normalization": {"init": "transfer"},
},
},
}
result = compose_configs(source_config, surgery)
- def check_no_init(d, path=""):
- assert "init" not in d, f"Found 'init' key at {path}"
- for k, v in d.items():
- if isinstance(v, dict):
- check_no_init(v, f"{path}.{k}")
+ # init is preserved in composed config
+ mixers = result["decoder"]["block"]["mixer"]["mixers"]
+ assert mixers["attention"].get("init") == "transfer"
+ assert mixers["gdn"].get("init") == "random"
- check_no_init(result)
+ # strip_init_fields removes them for final output
+ stripped = strip_init_fields(result)
+ assert "init" not in stripped["decoder"]["block"]["mixer"]["mixers"]["attention"]
+ assert "init" not in stripped["decoder"]["block"]["mixer"]["mixers"]["gdn"]
def test_init_random_still_inherits_config(self, source_config):
"""init: random is for weights only - config params still inherited."""
@@ -287,6 +289,206 @@ def test_init_random_still_inherits_config(self, source_config):
assert mixer["head_groups"] == 4
assert mixer["window_size"] == 512
+ # =========================================================================
+ # Monoid Laws: compose_configs forms a monoid action on configs
+ # =========================================================================
+
+ def test_surgery_monoid_associativity(self):
+ """MONOID: merge(merge(A, B), C) == merge(A, merge(B, C)) for partial configs."""
+ surgery_a = {"decoder": {"block": {"mixer": {"type": "stochastic", "main_mixer_name": "attention"}}}}
+ surgery_b = {"decoder": {"block": {"mixer": {"mixers": {"sliding_window": {"window_size": 512}}}}}}
+ surgery_c = {"decoder": {"block": {"mixer": {"mixers": {"gdn": {"type": "gdn"}}}}}}
+
+ # Left-associated: (A ∘ B) ∘ C
+ ab_c = compose_configs(compose_configs(surgery_a, surgery_b), surgery_c)
+ # Right-associated: A ∘ (B ∘ C)
+ a_bc = compose_configs(surgery_a, compose_configs(surgery_b, surgery_c))
+
+ assert ab_c == a_bc, "Surgery monoid should be associative"
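+
+        # Sketch of the deep merge A ∘ B ∘ C (overlay wins), under the merge semantics
+        # these tests assume: a single surgery equivalent to
+        #   {"decoder": {"block": {"mixer": {
+        #       "type": "stochastic", "main_mixer_name": "attention",
+        #       "mixers": {"sliding_window": {"window_size": 512}, "gdn": {"type": "gdn"}}}}}}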
+
+ @pytest.mark.parametrize("num_surgeries", [2, 3])
+ def test_monoid_action_compatibility(self, source_config, num_surgeries):
+ """MONOID ACTION: apply(apply(c, A), B) == apply(c, merge(A, B))
+
+ This is the key law: applying surgeries sequentially equals merging first.
+ Parameterized to test with 2 and 3 surgeries.
+ """
+ surgeries = [
+ {"decoder": {"block": {"mixer": {"type": "stochastic", "main_mixer_name": "attention", "mixers": {"attention": {}}}}}},
+ {"decoder": {"block": {"mixer": {"mixers": {"sliding_window": {"window_size": 512}}}}}},
+ {"decoder": {"block": {"mixer": {"mixers": {"gdn": {"type": "gdn"}}}}}},
+ ][:num_surgeries]
+
+ # Sequential: ((c ⊳ A) ⊳ B) ⊳ ...
+ result_sequential = source_config
+ for s in surgeries:
+ result_sequential = compose_configs(result_sequential, s)
+
+ # Merged: c ⊳ (A ∘ B ∘ ...)
+ merged = surgeries[0]
+ for s in surgeries[1:]:
+ merged = compose_configs(merged, s)
+ result_merged = compose_configs(source_config, merged)
+
+ assert result_sequential == result_merged, f"Monoid action compatibility failed for {num_surgeries} surgeries"
+
+
+class TestBiasConfigInheritance:
+ """Test per-layer bias inheritance through surgery composition.
+
+ These tests verify that the per-layer bias configuration (mirroring Fast-LLM's
+ AffineLinearConfig) is correctly inherited through surgery operations:
+ - query_layer.bias.enabled, key_layer.bias.enabled, etc. for attention
+ - layer_1.bias.enabled, layer_2.bias.enabled for MLP
+ """
+
+ @pytest.fixture
+ def source_config_with_bias(self):
+ """Source config with Qwen-style bias (QKV enabled, O disabled)."""
+ return {
+ "model_type": "apriel2",
+ "architectures": ["Apriel2ForCausalLM"],
+ "hidden_size": 256,
+ "vocab_size": 1000,
+ "decoder": {
+ "type": "fixed",
+ "num_blocks": 4,
+ "block": {
+ "mixer": {
+ "type": "attention",
+ "heads": 8,
+ "head_groups": 4,
+ "head_size": 32,
+ "rotary": {"type": "mistral_1d", "theta": 10000.0},
+ # Qwen-style per-layer bias
+ "query_layer": {"bias": {"enabled": True}},
+ "key_layer": {"bias": {"enabled": True}},
+ "value_layer": {"bias": {"enabled": True}},
+ "dense_layer": {"bias": {"enabled": False}},
+ },
+ "mlp": {
+ "type": "mlp",
+ "intermediate_size": 512,
+ "gated": False,
+ # Per-layer MLP bias
+ "layer_1": {"bias": {"enabled": True}},
+ "layer_2": {"bias": {"enabled": False}},
+ },
+ "normalization": {"type": "rms_norm", "epsilon": 1e-5},
+ },
+ },
+ }
+
+ def test_same_type_inherits_attention_bias(self, source_config_with_bias):
+ """Same-type surgery inherits per-layer attention bias settings."""
+ surgery = {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "window_size": 512, # Add sliding window behavior
+ },
+ },
+ },
+ }
+ result = compose_configs(source_config_with_bias, surgery)
+
+ mixer = result["decoder"]["block"]["mixer"]
+ assert mixer["query_layer"]["bias"]["enabled"] is True
+ assert mixer["key_layer"]["bias"]["enabled"] is True
+ assert mixer["value_layer"]["bias"]["enabled"] is True
+ assert mixer["dense_layer"]["bias"]["enabled"] is False
+
+ def test_same_type_inherits_mlp_bias(self, source_config_with_bias):
+ """Same-type surgery inherits per-layer MLP bias settings."""
+ surgery = {
+ "decoder": {
+ "block": {
+ "mlp": {
+ "intermediate_size": 1024, # Change size
+ },
+ },
+ },
+ }
+ result = compose_configs(source_config_with_bias, surgery)
+
+ mlp = result["decoder"]["block"]["mlp"]
+ assert mlp["layer_1"]["bias"]["enabled"] is True
+ assert mlp["layer_2"]["bias"]["enabled"] is False
+ assert mlp["intermediate_size"] == 1024
+
+ def test_cross_type_attention_to_sliding_window_preserves_bias(self, source_config_with_bias):
+ """attention→sliding_window cross-type preserves per-layer bias."""
+ surgery = {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "type": "sliding_window", # Cross-type derivation
+ "window_size": 512,
+ },
+ },
+ },
+ }
+ result = compose_configs(source_config_with_bias, surgery)
+
+ mixer = result["decoder"]["block"]["mixer"]
+ assert mixer["type"] == "sliding_window"
+ # Bias settings preserved through cross-type
+ assert mixer["query_layer"]["bias"]["enabled"] is True
+ assert mixer["key_layer"]["bias"]["enabled"] is True
+ assert mixer["value_layer"]["bias"]["enabled"] is True
+ assert mixer["dense_layer"]["bias"]["enabled"] is False
+
+ def test_stochastic_wrapper_inherits_bias(self, source_config_with_bias):
+ """Wrapping in stochastic inherits bias settings to all sub-mixers."""
+ surgery = {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "type": "stochastic",
+ "main_mixer_name": "attention",
+ "mixers": {
+ "attention": {"init": "transfer"},
+ "sliding_window": {
+ "type": "sliding_window",
+ "window_size": 512,
+ "init": "transfer",
+ },
+ },
+ },
+ },
+ },
+ }
+ result = compose_configs(source_config_with_bias, surgery)
+
+ mixers = result["decoder"]["block"]["mixer"]["mixers"]
+
+ # Attention sub-mixer inherits bias
+ assert mixers["attention"]["query_layer"]["bias"]["enabled"] is True
+ assert mixers["attention"]["dense_layer"]["bias"]["enabled"] is False
+
+ # Sliding window sub-mixer also inherits bias
+ assert mixers["sliding_window"]["query_layer"]["bias"]["enabled"] is True
+ assert mixers["sliding_window"]["dense_layer"]["bias"]["enabled"] is False
+
+ def test_surgery_can_override_bias(self, source_config_with_bias):
+ """Surgery can explicitly override inherited bias settings."""
+ surgery = {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "dense_layer": {"bias": {"enabled": True}}, # Override O bias
+ },
+ },
+ },
+ }
+ result = compose_configs(source_config_with_bias, surgery)
+
+ mixer = result["decoder"]["block"]["mixer"]
+ # Q/K/V unchanged
+ assert mixer["query_layer"]["bias"]["enabled"] is True
+ # O bias overridden
+ assert mixer["dense_layer"]["bias"]["enabled"] is True
+
class TestComposeConfigsRealYAML:
"""Test compose_configs with real YAML surgery files."""
@@ -398,160 +600,12 @@ def test_build_plan_returns_complete_config(self, llava_pixtral_checkpoint):
mixer = config.decoder["block"]["mixer"]
assert mixer["type"] == "stochastic"
- # Each sub-mixer should have complete config (no init keys)
+ # Each sub-mixer should have complete config
+ # (init is preserved for plan_surgery, stripped only at final output)
for name, sub_mixer in mixer["mixers"].items():
- assert "init" not in sub_mixer, f"Sub-mixer {name} still has 'init' key"
assert "type" in sub_mixer
-class TestMonoidLaws:
- """Test the algebraic laws of compose_configs.
-
- Surgery specs form a MONOID under deep-merge:
- - Identity: {}
- - Operation: deep merge (overlay wins)
- - Associativity: merge(merge(A, B), C) == merge(A, merge(B, C))
-
- compose_configs is a MONOID ACTION on configs:
- - Identity action: apply(config, {}) == config
- - Compatibility: apply(apply(c, A), B) == apply(c, merge(A, B))
- """
-
- @pytest.fixture
- def complete_config(self):
- """A complete Apriel2 config."""
- return {
- "model_type": "apriel2",
- "architectures": ["Apriel2ForConditionalGeneration"],
- "hidden_size": 256,
- "vocab_size": 1000,
- "bos_token_id": 1,
- "eos_token_id": 2,
- "tie_word_embeddings": False,
- "image_token_index": 100,
- "decoder": {
- "type": "fixed",
- "num_blocks": 4,
- "block": {
- "mixer": {
- "type": "attention",
- "heads": 8,
- "head_groups": 4,
- "head_size": 32,
- "rope_theta": 10000.0,
- },
- "mlp": {"type": "mlp", "intermediate_size": 512},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- },
- }
-
- @pytest.fixture
- def surgery_a(self):
- """First surgery: wrap in stochastic with attention."""
- return {
- "decoder": {
- "block": {
- "mixer": {
- "type": "stochastic",
- "main_mixer_name": "attention",
- "mixers": {
- "attention": {"init": "transfer"},
- },
- },
- },
- },
- }
-
- @pytest.fixture
- def surgery_b(self):
- """Second surgery: add sliding window mixer."""
- return {
- "decoder": {
- "block": {
- "mixer": {
- "mixers": {
- "sliding_window": {"init": "transfer", "window_size": 512},
- },
- },
- },
- },
- }
-
- def test_identity_action(self, complete_config):
- """apply(config, {}) == config"""
- result = compose_configs(complete_config, {})
- assert result == complete_config
-
- def test_surgery_monoid_associativity(self, surgery_a, surgery_b):
- """merge(merge(A, B), C) == merge(A, merge(B, C)) for partial configs."""
- surgery_c = {
- "decoder": {
- "block": {
- "mixer": {
- "mixers": {
- "gdn": {"type": "gdn", "init": "transfer", "convolution_layer": {"kernel_size": 4}},
- },
- },
- },
- },
- }
-
- # Left-associated: (A ∘ B) ∘ C
- ab = compose_configs(surgery_a, surgery_b)
- ab_c = compose_configs(ab, surgery_c)
-
- # Right-associated: A ∘ (B ∘ C)
- bc = compose_configs(surgery_b, surgery_c)
- a_bc = compose_configs(surgery_a, bc)
-
- assert ab_c == a_bc, "Surgery monoid should be associative"
-
- def test_monoid_action_compatibility(self, complete_config, surgery_a, surgery_b):
- """apply(apply(c, A), B) == apply(c, merge(A, B))
-
- This is the key law: applying surgeries sequentially should equal
- merging the surgeries first, then applying once.
- """
- # Sequential application: (c ⊳ A) ⊳ B
- result_sequential = compose_configs(compose_configs(complete_config, surgery_a), surgery_b)
-
- # Merged application: c ⊳ (A ∘ B)
- merged_surgery = compose_configs(surgery_a, surgery_b)
- result_merged = compose_configs(complete_config, merged_surgery)
-
- # These should be equivalent
- assert result_sequential == result_merged, "Monoid action should satisfy compatibility law"
-
- def test_three_way_compatibility(self, complete_config, surgery_a, surgery_b):
- """Test with three surgeries for stronger confidence."""
- surgery_c = {
- "decoder": {
- "block": {
- "mixer": {
- "mixers": {
- "gdn": {"type": "gdn", "init": "transfer", "convolution_layer": {"kernel_size": 4}},
- },
- },
- },
- },
- }
-
- # Sequential: ((c ⊳ A) ⊳ B) ⊳ C
- seq = compose_configs(
- compose_configs(compose_configs(complete_config, surgery_a), surgery_b),
- surgery_c
- )
-
- # Merged: c ⊳ ((A ∘ B) ∘ C)
- merged = compose_configs(
- complete_config,
- compose_configs(compose_configs(surgery_a, surgery_b), surgery_c)
- )
-
- assert seq == merged, "Three-way monoid action should satisfy compatibility"
-
-
class TestCompositionTortureTest:
"""Comprehensive stress test for config composition.
@@ -650,19 +704,29 @@ def test_final_config_structure(self, complete_config, additive_surgery_chain):
assert mixer["mixers"]["sliding_window"]["window_size"] == 512
assert mixer["mixers"]["gdn"]["value_heads"] == 16
- def test_no_init_keys_in_result(self, complete_config, additive_surgery_chain):
- """Verify no 'init' keys leak through."""
+ def test_init_keys_preserved_for_planning(self, complete_config, additive_surgery_chain):
+ """Verify 'init' keys are preserved for plan_surgery to see.
- def check_no_init(d, path=""):
- if isinstance(d, dict):
- assert "init" not in d, f"Found 'init' key at {path}"
- for k, v in d.items():
- check_no_init(v, f"{path}.{k}")
+ The `init` field is metadata for weight initialization. It's preserved
+ through composition and only stripped when saving final output.
+ """
+ from fast_llm_external_models.apriel2.conversion.config import strip_init_fields
result = complete_config
for i, surgery in enumerate(additive_surgery_chain):
result = compose_configs(result, surgery)
- check_no_init(result, f"step_{i+1}")
+
+ # init should be in the composed config
+ mixer = result["decoder"]["block"]["mixer"]
+ if "mixers" in mixer:
+ has_init = any("init" in m for m in mixer["mixers"].values())
+ assert has_init, "init should be preserved in composed config"
+
+ # strip_init_fields removes them
+ stripped = strip_init_fields(result)
+ mixer = stripped["decoder"]["block"]["mixer"]
+ if "mixers" in mixer:
+ assert all("init" not in m for m in mixer["mixers"].values())
def test_full_torture_chain(self, complete_config, torture_surgery_chain):
"""Test the full 10-step torture chain produces valid configs."""
diff --git a/fast_llm_external_models/tests/test_apriel2/test_plan_composition_torture.py b/fast_llm_external_models/tests/test_apriel2/test_conversion_e2e.py
similarity index 83%
rename from fast_llm_external_models/tests/test_apriel2/test_plan_composition_torture.py
rename to fast_llm_external_models/tests/test_apriel2/test_conversion_e2e.py
index 3b4adc7f5..09fb9fa13 100644
--- a/fast_llm_external_models/tests/test_apriel2/test_plan_composition_torture.py
+++ b/fast_llm_external_models/tests/test_apriel2/test_conversion_e2e.py
@@ -1,6 +1,6 @@
-"""End-to-end torture test for plan composition.
+"""test_conversion_e2e.py - End-to-end conversion integration tests.
-This tests the FULL pipeline at every step of a surgery chain:
+Tests the FULL pipeline at every step of a surgery chain:
1. Config composition produces valid configs
2. Plan building works for each surgery
3. Plan execution produces valid weights
@@ -1083,66 +1083,6 @@ def mamba_config(self):
},
}
- def test_config_composition_identical_regardless_of_init_mode(self, base_config):
- """Config composition produces same structure with init: transfer vs init: random."""
- # Surgery with init: transfer
- surgery_transfer = {
- "decoder": {
- "block": {
- "mixer": {
- "type": "stochastic",
- "main_mixer_name": "attention",
- "mixers": {
- "attention": {"type": "attention", "init": "transfer"},
- "swa": {
- "type": "attention",
- "init": "transfer",
- "sliding_window": 512,
- },
- },
- },
- },
- },
- }
-
- # Surgery with init: random
- surgery_random = {
- "decoder": {
- "block": {
- "mixer": {
- "type": "stochastic",
- "main_mixer_name": "attention",
- "mixers": {
- "attention": {"type": "attention", "init": "random"},
- "swa": {
- "type": "attention",
- "init": "random",
- "sliding_window": 512,
- },
- },
- },
- },
- },
- }
-
- # Compose configs
- result_transfer = compose_configs(base_config, surgery_transfer)
- result_random = compose_configs(base_config, surgery_random)
-
- # Both should produce identical structure (init is stripped)
- assert result_transfer == result_random, (
- "Config composition should produce identical structure regardless of init mode"
- )
-
- # Verify the structure is correct
- mixer = result_transfer["decoder"]["block"]["mixer"]
- assert mixer["type"] == "stochastic"
- assert "attention" in mixer["mixers"]
- assert "swa" in mixer["mixers"]
- # init should be stripped
- assert "init" not in mixer["mixers"]["attention"]
- assert "init" not in mixer["mixers"]["swa"]
-
def test_plan_surgery_random_succeeds_for_any_type_pair(self, mamba_config):
"""plan_surgery with init: random should succeed even for mamba -> attention."""
# This surgery changes mamba to attention with random init
@@ -1313,8 +1253,8 @@ class TestMarkovianProperty:
"""
@pytest.fixture
- def attention_config(self):
- """Base config with attention."""
+ def attention_config_dict(self):
+ """Base config dict with attention mixer for compose_configs tests."""
return {
"model_type": "apriel2",
"hidden_size": 256,
@@ -1335,43 +1275,7 @@ def attention_config(self):
},
}
- @pytest.fixture
- def stochastic_config(self):
- """Config with stochastic mixer."""
- return {
- "model_type": "apriel2",
- "hidden_size": 256,
- "vocab_size": 1000,
- "decoder": {
- "type": "fixed",
- "num_blocks": 2,
- "block": {
- "mixer": {
- "type": "stochastic",
- "main_mixer_name": "attention",
- "mixers": {
- "attention": {
- "type": "attention",
- "heads": 8,
- "head_groups": 4,
- "head_size": 32,
- },
- "swa": {
- "type": "sliding_window",
- "heads": 8,
- "head_groups": 4,
- "head_size": 32,
- "window_size": 512,
- },
- },
- },
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- },
- }
-
- def test_different_paths_same_config_same_plan(self, attention_config):
+ def test_different_paths_same_config_same_plan(self, attention_config_dict):
"""Two different paths to the same config produce identical plans.
Path A: attention -> stochastic{att, swa}
@@ -1398,7 +1302,7 @@ def test_different_paths_same_config_same_plan(self, attention_config):
},
},
}
- config_a = compose_configs(attention_config, surgery_a)
+ config_a = compose_configs(attention_config_dict, surgery_a)
# Path B: First add attention only, then add swa
surgery_b1 = {
@@ -1414,7 +1318,7 @@ def test_different_paths_same_config_same_plan(self, attention_config):
},
},
}
- intermediate_config = compose_configs(attention_config, surgery_b1)
+ intermediate_config = compose_configs(attention_config_dict, surgery_b1)
surgery_b2 = {
"decoder": {
@@ -1469,7 +1373,7 @@ def test_different_paths_same_config_same_plan(self, attention_config):
keys_b = set(str(k) for k in plan_from_b.mappings.keys())
assert keys_a == keys_b, "Plans from same config via different paths should be identical"
- def test_init_in_source_config_does_not_affect_plan(self, attention_config):
+ def test_init_in_source_config_does_not_affect_plan(self, attention_config_dict):
"""Manually injecting init into source config doesn't change the plan.
This tests that plan_surgery reads init from surgery, not source.
@@ -1479,8 +1383,8 @@ def test_init_in_source_config_does_not_affect_plan(self, attention_config):
import copy
# Create two copies of the config
- config_with_init = copy.deepcopy(attention_config)
- config_without_init = copy.deepcopy(attention_config)
+ config_with_init = copy.deepcopy(attention_config_dict)
+ config_without_init = copy.deepcopy(attention_config_dict)
# Manually inject init into one (bypassing compose_configs)
config_with_init["decoder"]["block"]["mixer"]["init"] = "random"
@@ -1510,232 +1414,6 @@ def test_init_in_source_config_does_not_affect_plan(self, attention_config):
# Plans should be identical - source's init field is ignored
assert keys_with == keys_without, "Plan should not depend on init in source config"
- def test_associativity_of_surgery_composition(self, attention_config):
- """Verify associativity: (A ∘ B) ∘ C == A ∘ (B ∘ C) for surgery specs.
-
- This tests that composing surgeries is associative, which is
- equivalent to Markovianity for plan creation.
- """
- surgery_a = {
- "decoder": {
- "block": {
- "mixer": {
- "type": "stochastic",
- "main_mixer_name": "attention",
- "mixers": {
- "attention": {"type": "attention", "init": "transfer"},
- },
- },
- },
- },
- }
-
- surgery_b = {
- "decoder": {
- "block": {
- "mixer": {
- "mixers": {
- "swa": {
- "type": "sliding_window",
- "init": "transfer",
- "window_size": 512,
- },
- },
- },
- },
- },
- }
-
- surgery_c = {
- "decoder": {
- "block": {
- "mixer": {
- "mixers": {
- "gdn": {
- "type": "gdn",
- "init": "random",
- "value_heads": 8,
- "key_heads": 4,
- "key_head_dim": 32,
- "value_head_dim": 32,
- "convolution_layer": {"kernel_size": 4},
- },
- },
- },
- },
- },
- }
-
- # Left association: ((attention_config ∘ A) ∘ B) ∘ C
- left_1 = compose_configs(attention_config, surgery_a)
- left_2 = compose_configs(left_1, surgery_b)
- left_result = compose_configs(left_2, surgery_c)
-
- # Right association: (attention_config ∘ A) ∘ (B ∘ C)
- # Note: B ∘ C is partial ∘ partial = deep merge of surgery specs
- bc_merged = compose_configs(surgery_b, surgery_c)
- right_1 = compose_configs(attention_config, surgery_a)
- right_result = compose_configs(right_1, bc_merged)
-
- assert left_result == right_result, "Surgery composition should be associative"
-
- def test_complete_configs_have_no_init_fields(self, attention_config):
- """Verify that compose_configs strips init from complete configs.
-
- This is the key invariant that enables Markovianity:
- - Complete configs (states) have no init fields
- - Surgery specs (transitions) have init fields
- - Plans read init from surgery, not state
- """
- surgery_with_init = {
- "decoder": {
- "block": {
- "mixer": {
- "type": "stochastic",
- "main_mixer_name": "attention",
- "mixers": {
- "attention": {"type": "attention", "init": "transfer"},
- "swa": {"type": "sliding_window", "init": "random", "window_size": 512},
- },
- },
- },
- },
- }
-
- result = compose_configs(attention_config, surgery_with_init)
-
- # Recursively check for init fields
- def has_init(obj):
- if isinstance(obj, dict):
- if "init" in obj:
- return True
- return any(has_init(v) for v in obj.values())
- if isinstance(obj, list):
- return any(has_init(v) for v in obj)
- return False
-
- assert not has_init(result), "Complete configs should have no init fields"
-
- def test_monoid_action_law_additive_surgeries(self):
- """Monoid action law HOLDS for additive surgeries.
-
- Additive surgeries (no type: declaration) support:
- apply(apply(s, t1), t2) == apply(s, t1 ∘ t2)
-
- This is because additive operations commute nicely:
- "add {a}" then "add {b}" == "add {a, b}"
- """
- # Start with stochastic (additive surgery target)
- s = {
- "model_type": "apriel2",
- "hidden_size": 256,
- "vocab_size": 1000,
- "decoder": {
- "type": "fixed",
- "num_blocks": 2,
- "block": {
- "mixer": {
- "type": "stochastic",
- "main_mixer_name": "attention",
- "mixers": {
- "attention": {"type": "attention", "heads": 8, "head_groups": 4, "head_size": 32},
- },
- },
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- },
- }
-
- # Additive surgeries (no type: declaration)
- t1 = {"decoder": {"block": {"mixer": {"mixers": {"swa": {"type": "sliding_window", "window_size": 512}}}}}}
- t2 = {"decoder": {"block": {"mixer": {"mixers": {"mamba": {"type": "mamba", "d_inner": 512}}}}}}
-
- # Path A: Sequential
- s_prime = compose_configs(s, t1)
- s_double_prime_A = compose_configs(s_prime, t2)
-
- # Path B: Composed
- t1_t2 = compose_configs(t1, t2)
- s_double_prime_B = compose_configs(s, t1_t2)
-
- assert s_double_prime_A == s_double_prime_B, "Monoid action law should hold for additive surgeries"
-
- def test_monoid_action_law_replacement_surgeries_fails(self):
- """Monoid action law FAILS for replacement surgeries (by design).
-
- Replacement surgeries (type: stochastic declared) have:
- apply(apply(s, t1), t2) != apply(s, t1 ∘ t2)
-
- This is FUNDAMENTAL, not a bug:
- - Sequential: "set to {a}" then "set to {b}" → {b} (second wins)
- - Composed: merge({a}, {b}) = {a,b}, then apply → {a,b}
-
- These are genuinely different semantics. The failure documents
- the distinction between declarative composition (merge) and
- operational composition (function application).
- """
- s = {
- "model_type": "apriel2",
- "hidden_size": 256,
- "vocab_size": 1000,
- "decoder": {
- "type": "fixed",
- "num_blocks": 2,
- "block": {
- "mixer": {"type": "attention", "heads": 8, "head_groups": 4, "head_size": 32},
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- },
- }
-
- # Replacement surgeries (both declare type: stochastic)
- t1 = {
- "decoder": {
- "block": {
- "mixer": {
- "type": "stochastic",
- "main_mixer_name": "attention",
- "mixers": {"attention": {"type": "attention"}},
- }
- }
- }
- }
- t2 = {
- "decoder": {
- "block": {
- "mixer": {
- "type": "stochastic",
- "main_mixer_name": "swa",
- "mixers": {"swa": {"type": "sliding_window", "window_size": 512}},
- }
- }
- }
- }
-
- # Path A: Sequential (second replacement wins)
- s_prime = compose_configs(s, t1)
- s_double_prime_A = compose_configs(s_prime, t2)
-
- # Path B: Composed (declarations merged)
- t1_t2 = compose_configs(t1, t2)
- s_double_prime_B = compose_configs(s, t1_t2)
-
- # They should be DIFFERENT (law fails)
- assert s_double_prime_A != s_double_prime_B, (
- "Monoid action law should FAIL for replacement surgeries"
- )
-
- # Verify the specific difference:
- # Sequential: only swa (second replacement wins)
- # Composed: both attention and swa (merged declarations)
- mixers_A = set(s_double_prime_A["decoder"]["block"]["mixer"]["mixers"].keys())
- mixers_B = set(s_double_prime_B["decoder"]["block"]["mixer"]["mixers"].keys())
-
- assert mixers_A == {"swa"}, "Sequential: second replacement wins"
- assert mixers_B == {"attention", "swa"}, "Composed: declarations merged"
-
class TestCyclingSurgeryGeneration:
"""Tests for the cycling surgery generation functions.
@@ -1980,3 +1658,151 @@ def test_expand_surgery_chain_preserves_invariant(self):
# After cycling and restore, we should be back to the same state
assert current_config == config_after_original
+
+
+class TestBiasSurgeryChain:
+ """Torture tests for per-layer bias inheritance through surgery operations.
+
+ Uses apriel2_config_with_bias + bias_surgery_chain to test that:
+ - Qwen-style per-layer attention bias (QKV enabled, O disabled) survives surgery
+ - Non-gated MLP per-layer bias (layer_1 enabled, layer_2 disabled) survives surgery
+ - Bias settings are correctly inherited by new sub-mixers
+ - Bias is correctly tracked in surgery plans
+ """
+
+ @pytest.fixture
+ def bias_source_config(self, apriel2_config_with_bias):
+ """Convert Apriel2Config to dict for surgery operations."""
+ return apriel2_config_with_bias.to_dict()
+
+ def test_bias_survives_stochastic_wrapper(self, bias_source_config, bias_surgery_chain):
+ """Test that bias settings survive wrapping in stochastic mixer."""
+ # Apply first surgery (wrap in stochastic)
+ result = compose_configs(bias_source_config, bias_surgery_chain[0])
+
+ # Check attention sub-mixer inherited bias settings
+ mixer = result["decoder"]["block"]["mixer"]
+ assert mixer["type"] == "stochastic"
+
+ attn_mixer = mixer["mixers"]["attention"]
+ assert attn_mixer["query_layer"]["bias"]["enabled"] is True
+ assert attn_mixer["key_layer"]["bias"]["enabled"] is True
+ assert attn_mixer["value_layer"]["bias"]["enabled"] is True
+ assert attn_mixer["dense_layer"]["bias"]["enabled"] is False
+
+ # Check MLP bias survived
+ mlp = result["decoder"]["block"]["mlp"]
+ assert mlp["layer_1"]["bias"]["enabled"] is True
+ assert mlp["layer_2"]["bias"]["enabled"] is False
+
+ def test_new_submixer_inherits_bias(self, bias_source_config, bias_surgery_chain):
+ """Test that new sub-mixers inherit bias from source attention."""
+ # Apply S1 + S2 (wrap in stochastic, add sliding_window)
+ config = bias_source_config
+ for surgery in bias_surgery_chain[:2]:
+ config = compose_configs(config, surgery)
+
+ # sliding_window should inherit bias from source attention
+ mixer = config["decoder"]["block"]["mixer"]
+ sw_mixer = mixer["mixers"]["sliding_window"]
+
+ assert sw_mixer["query_layer"]["bias"]["enabled"] is True
+ assert sw_mixer["key_layer"]["bias"]["enabled"] is True
+ assert sw_mixer["value_layer"]["bias"]["enabled"] is True
+ assert sw_mixer["dense_layer"]["bias"]["enabled"] is False
+
+ def test_full_bias_chain_produces_valid_config(self, bias_source_config, bias_surgery_chain):
+ """Test that full bias surgery chain produces valid config."""
+ config = bias_source_config
+ for surgery in bias_surgery_chain:
+ config = compose_configs(config, surgery)
+
+ # Verify final config structure
+ mixer = config["decoder"]["block"]["mixer"]
+ assert mixer["type"] == "stochastic"
+ assert "attention" in mixer["mixers"]
+ assert "sliding_window" in mixer["mixers"]
+ assert "full_bias_attn" in mixer["mixers"]
+
+ # attention and sliding_window inherit Qwen-style bias
+ for name in ["attention", "sliding_window"]:
+ sub = mixer["mixers"][name]
+ assert sub["query_layer"]["bias"]["enabled"] is True
+ assert sub["dense_layer"]["bias"]["enabled"] is False
+
+ # full_bias_attn has add_linear_biases=True but per-layer settings inherited from
+ # source take precedence, so O bias is still disabled
+ full_bias = mixer["mixers"]["full_bias_attn"]
+ assert full_bias.get("add_linear_biases") is True
+ # Per-layer dense_layer.bias.enabled=False inherited from source takes precedence
+ assert full_bias["dense_layer"]["bias"]["enabled"] is False
+
+ def test_bias_plan_has_correct_mappings(self, bias_source_config, bias_surgery_chain):
+ """Test that surgery plan correctly includes/excludes bias weight mappings."""
+ # Compose config first to get full target config with inherited bias settings
+ target_config = compose_configs(bias_source_config, bias_surgery_chain[0])
+ plan = plan_surgery(bias_source_config, target_config)
+ mapping_strs = [str(k) for k in plan.mappings.keys()]
+
+ # Should have q_proj.bias (enabled)
+ q_bias = [m for m in mapping_strs if "q_proj.bias" in m]
+ assert len(q_bias) > 0, "Should have q_proj.bias mappings"
+
+ # Should NOT have o_proj.bias (disabled)
+ o_bias = [m for m in mapping_strs if "o_proj.bias" in m]
+ assert len(o_bias) == 0, "Should not have o_proj.bias mappings"
+
+ # Should have up_proj.bias (layer_1 enabled)
+ up_bias = [m for m in mapping_strs if "up_proj.bias" in m]
+ assert len(up_bias) > 0, "Should have up_proj.bias mappings"
+
+ # Should NOT have down_proj.bias (layer_2 disabled)
+ down_bias = [m for m in mapping_strs if "down_proj.bias" in m]
+ assert len(down_bias) == 0, "Should not have down_proj.bias mappings"
+
+ def test_bias_chain_produces_working_model(self, bias_source_config, bias_surgery_chain):
+ """Test that bias surgery chain produces a working model."""
+ from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM
+
+ # Apply full chain
+ config = bias_source_config
+ for surgery in bias_surgery_chain:
+ config = compose_configs(config, surgery)
+
+ # Create model
+ apriel_config = Apriel2Config(**config)
+ model = Apriel2ForCausalLM(apriel_config)
+ model.eval()
+
+ # Verify model structure has correct biases
+ block = model.model.decoder.blocks[0]
+
+ # attention sub-mixer should have QKV bias, no O bias
+ attn = block.mixer.mixers["attention"]
+ assert attn.q_proj.bias is not None
+ assert attn.k_proj.bias is not None
+ assert attn.v_proj.bias is not None
+ assert attn.o_proj.bias is None
+
+ # sliding_window should also inherit bias settings
+ sw = block.mixer.mixers["sliding_window"]
+ assert sw.q_proj.bias is not None
+ assert sw.o_proj.bias is None
+
+ # full_bias_attn inherits per-layer bias from source (even with add_linear_biases=True,
+ # per-layer settings take precedence in same-type inheritance)
+ full_bias = block.mixer.mixers["full_bias_attn"]
+ assert full_bias.q_proj.bias is not None
+ # O bias is disabled because inherited per-layer dense_layer.bias.enabled=False
+ # takes precedence over add_linear_biases=True
+ assert full_bias.o_proj.bias is None
+
+ # MLP should have layer_1 bias, no layer_2 bias
+ assert block.mlp.up_proj.bias is not None
+ assert block.mlp.down_proj.bias is None
+
+ # Forward pass should work
+ input_ids = torch.randint(0, config["vocab_size"], (1, 10))
+ with torch.no_grad():
+ outputs = model(input_ids, use_cache=False)
+ assert outputs.logits.shape == (1, 10, config["vocab_size"])
diff --git a/fast_llm_external_models/tests/test_apriel2/test_expr_plan.py b/fast_llm_external_models/tests/test_apriel2/test_expr_plan.py
index c487ab3a3..569ed88fd 100644
--- a/fast_llm_external_models/tests/test_apriel2/test_expr_plan.py
+++ b/fast_llm_external_models/tests/test_apriel2/test_expr_plan.py
@@ -1711,3 +1711,205 @@ def test_conversion_plan_targets_match_model_state_dict(self, llava_pixtral_conf
assert not missing_from_plan, f"Plan missing keys that model expects: {sorted(missing_from_plan)}"
assert not extra_in_plan, f"Plan has extra keys model doesn't expect: {sorted(extra_in_plan)}"
+
+
+class TestBiasPlanGeneration:
+ """Test that surgery plans correctly handle per-layer bias configurations.
+
+ These tests verify that plan_surgery correctly includes/excludes bias
+ weight mappings based on the per-layer bias settings:
+ - query_layer.bias.enabled, key_layer.bias.enabled, etc. for attention
+ - layer_1.bias.enabled, layer_2.bias.enabled for MLP
+ """
+
+ @pytest.fixture
+ def source_config_with_bias(self):
+ """Source config with Qwen-style bias (QKV enabled, O disabled)."""
+ return {
+ "model_type": "apriel2",
+ "hidden_size": 256,
+ "vocab_size": 1000,
+ "decoder": {
+ "type": "fixed",
+ "num_blocks": 2,
+ "block": {
+ "mixer": {
+ "type": "attention",
+ "heads": 8,
+ "head_groups": 4,
+ "head_size": 32,
+ "rotary": {"type": "mistral_1d", "theta": 10000.0},
+ # Qwen-style: QKV bias enabled, O bias disabled
+ "query_layer": {"bias": {"enabled": True}},
+ "key_layer": {"bias": {"enabled": True}},
+ "value_layer": {"bias": {"enabled": True}},
+ "dense_layer": {"bias": {"enabled": False}},
+ },
+ "mlp": {
+ "type": "mlp",
+ "intermediate_size": 512,
+ "gated": False,
+ # Per-layer MLP bias: layer_1 enabled, layer_2 disabled
+ "layer_1": {"bias": {"enabled": True}},
+ "layer_2": {"bias": {"enabled": False}},
+ },
+ "normalization": {"type": "rms_norm", "epsilon": 1e-5},
+ },
+ },
+ }
+
+ def test_plan_includes_enabled_attention_biases(self, source_config_with_bias):
+ """Surgery plan includes bias mappings for enabled attention biases."""
+ from fast_llm_external_models.apriel2.conversion.config import compose_configs
+ from fast_llm_external_models.apriel2.conversion.converters import plan_surgery
+
+ target_config = compose_configs(source_config_with_bias, {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "type": "stochastic",
+ "main_mixer_name": "attention",
+ "mixers": {
+ "attention": {"init": "transfer"},
+ },
+ },
+ "mlp": {"init": "transfer"},
+ },
+ },
+ })
+
+ plan = plan_surgery(source_config_with_bias, target_config)
+ mapping_strs = [str(k) for k in plan.mappings.keys()]
+
+ # Should have q_proj.bias, k_proj.bias, v_proj.bias mappings
+ q_bias = [m for m in mapping_strs if "q_proj.bias" in m]
+ k_bias = [m for m in mapping_strs if "k_proj.bias" in m]
+ v_bias = [m for m in mapping_strs if "v_proj.bias" in m]
+
+ assert len(q_bias) > 0, "Should have q_proj.bias mappings"
+ assert len(k_bias) > 0, "Should have k_proj.bias mappings"
+ assert len(v_bias) > 0, "Should have v_proj.bias mappings"
+
+ def test_plan_excludes_disabled_attention_biases(self, source_config_with_bias):
+ """Surgery plan excludes bias mappings for disabled attention biases."""
+ from fast_llm_external_models.apriel2.conversion.config import compose_configs
+ from fast_llm_external_models.apriel2.conversion.converters import plan_surgery
+
+ target_config = compose_configs(source_config_with_bias, {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "type": "stochastic",
+ "main_mixer_name": "attention",
+ "mixers": {
+ "attention": {"init": "transfer"},
+ },
+ },
+ "mlp": {"init": "transfer"},
+ },
+ },
+ })
+
+ plan = plan_surgery(source_config_with_bias, target_config)
+ mapping_strs = [str(k) for k in plan.mappings.keys()]
+
+ # Should NOT have o_proj.bias mappings (disabled)
+ o_bias = [m for m in mapping_strs if "o_proj.bias" in m]
+ assert len(o_bias) == 0, f"Should not have o_proj.bias mappings, found: {o_bias}"
+
+ def test_plan_includes_enabled_mlp_biases(self, source_config_with_bias):
+ """Surgery plan includes bias mappings for enabled MLP biases."""
+ from fast_llm_external_models.apriel2.conversion.config import compose_configs
+ from fast_llm_external_models.apriel2.conversion.converters import plan_surgery
+
+ target_config = compose_configs(source_config_with_bias, {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "type": "stochastic",
+ "main_mixer_name": "attention",
+ "mixers": {
+ "attention": {"init": "transfer"},
+ },
+ },
+ "mlp": {"init": "transfer"},
+ },
+ },
+ })
+
+ plan = plan_surgery(source_config_with_bias, target_config)
+ mapping_strs = [str(k) for k in plan.mappings.keys()]
+
+ # Should have up_proj.bias (layer_1) mappings
+ up_bias = [m for m in mapping_strs if "up_proj.bias" in m]
+ assert len(up_bias) > 0, "Should have up_proj.bias mappings"
+
+ def test_plan_excludes_disabled_mlp_biases(self, source_config_with_bias):
+ """Surgery plan excludes bias mappings for disabled MLP biases."""
+ from fast_llm_external_models.apriel2.conversion.config import compose_configs
+ from fast_llm_external_models.apriel2.conversion.converters import plan_surgery
+
+ target_config = compose_configs(source_config_with_bias, {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "type": "stochastic",
+ "main_mixer_name": "attention",
+ "mixers": {
+ "attention": {"init": "transfer"},
+ },
+ },
+ "mlp": {"init": "transfer"},
+ },
+ },
+ })
+
+ plan = plan_surgery(source_config_with_bias, target_config)
+ mapping_strs = [str(k) for k in plan.mappings.keys()]
+
+ # Should NOT have down_proj.bias (layer_2) mappings
+ down_bias = [m for m in mapping_strs if "down_proj.bias" in m]
+ assert len(down_bias) == 0, f"Should not have down_proj.bias mappings, found: {down_bias}"
+
+ def test_plan_random_init_creates_init_expressions_for_bias(self, source_config_with_bias):
+ """Random init creates Init expressions for bias weights."""
+ from fast_llm_external_models.apriel2.conversion.converters import plan_surgery
+
+ # Surgery spec - pass directly to plan_surgery (NOT composed, to preserve init)
+ surgery = {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "type": "stochastic",
+ "main_mixer_name": "attention",
+ "mixers": {
+ "attention": {"init": "transfer"},
+ "new_attention": {
+ "type": "attention",
+ "init": "random", # This triggers random init
+ "heads": 8,
+ "head_groups": 4,
+ "head_size": 32,
+ "rotary": {"type": "mistral_1d", "theta": 10000.0},
+ "add_linear_biases": True, # All biases enabled
+ },
+ },
+ },
+ },
+ },
+ }
+
+ # Pass surgery spec directly - init fields are preserved
+ plan = plan_surgery(source_config_with_bias, surgery)
+
+ # Check that new_attention biases use Init expressions
+ new_mixer_bias_keys = [
+ k for k in plan.mappings.keys()
+ if "new_attention" in str(k) and "bias" in str(k)
+ ]
+
+ assert len(new_mixer_bias_keys) > 0, "Should have bias mappings for new_attention"
+
+ for key in new_mixer_bias_keys:
+ expr = plan.mappings[key]
+ assert isinstance(expr, Init), f"{key} should be Init, got {type(expr)}"
diff --git a/fast_llm_external_models/tests/test_apriel2/test_integration.py b/fast_llm_external_models/tests/test_apriel2/test_integration.py
new file mode 100644
index 000000000..b90f0774e
--- /dev/null
+++ b/fast_llm_external_models/tests/test_apriel2/test_integration.py
@@ -0,0 +1,342 @@
+"""Integration tests for Qwen2 -> Apriel2 -> Fast-LLM conversion pipeline.
+
+Tests verify the full conversion chain:
+1. Qwen2 -> Apriel2 (external module conversion)
+2. Apriel2 + Surgery -> Supernet (stochastic mixer creation)
+3. Supernet -> Fast-LLM -> Supernet (roundtrip through training format)
+
+Test Strategy:
+- Use real HuggingFace model (Qwen2.5-0.5B) for meaningful validation
+- Separate config preservation tests from numerical equivalence tests
+- Parameterize both conversion stages AND input variations
+- Single test implementation applied across all stages
+"""
+
+import json
+import tempfile
+from pathlib import Path
+
+import pytest
+import torch
+
+from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
+from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM
+from fast_llm_external_models.apriel2.conversion import (
+ compose,
+ compose_configs,
+ execute,
+ plan_surgery,
+)
+from fast_llm_external_models.apriel2.conversion.expr import W
+from fast_llm_external_models.apriel2.conversion.qwen2.config import convert_config as convert_qwen2_config
+from fast_llm_external_models.apriel2.conversion.qwen2.plan import plan_qwen2_to_apriel2
+
+from .conftest import requires_fastllm
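+
+
+def _pipeline_sketch(qwen2_config: dict, qwen2_state_dict: dict, surgery_spec: dict):
+    """Illustrative sketch of conversion stages 1 and 2 as exercised by the fixtures below.
+
+    Stage 3 (the Fast-LLM roundtrip) goes through `ConvertConfig` and is shown in
+    `roundtrip_converted`. Not used by any test; it relies only on the helpers
+    imported above, and the argument names are placeholders.
+    """
+    # Stage 1: Qwen2 -> Apriel2 (config mapping plus weight-mapping plan).
+    apriel2_config = convert_qwen2_config(qwen2_config)
+    base_plan = plan_qwen2_to_apriel2(qwen2_config)
+    # Stage 2: Apriel2 + surgery -> supernet config, composed into a single plan.
+    supernet_config = compose_configs(apriel2_config, surgery_spec)
+    full_plan = compose(base_plan, plan_surgery(apriel2_config, supernet_config))
+    # Execute the composed plan directly against the Qwen2 weights.
+    weights = execute(full_plan, {W(k): v for k, v in qwen2_state_dict.items()}, seed=42)
+    return supernet_config, weights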
+
+
+# =============================================================================
+# Test Input Variations
+# =============================================================================
+
+TEST_INPUTS = pytest.mark.parametrize(
+ "prompts,max_new_tokens",
+ [
+ pytest.param(["Hello world"], 10, id="single_short"),
+ pytest.param(["Hi", "The quick brown fox jumps over the lazy dog"], 20, id="batch_varied"),
+ pytest.param(["Once upon a time"], 50, id="long_generation"),
+ ],
+)
+
+
+# =============================================================================
+# Conversion Fixtures
+# =============================================================================
+
+
+@pytest.fixture(scope="module")
+def qwen2_source():
+ """Load Qwen2.5-0.5B as the source/reference model."""
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
+
+ model_name = "Qwen/Qwen2.5-0.5B"
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(
+ model_name, torch_dtype=torch.float32, trust_remote_code=True
+ )
+ config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+ model.eval()
+
+ if tokenizer.pad_token is None:
+ tokenizer.pad_token = tokenizer.eos_token
+ tokenizer.padding_side = "left"
+
+ return {
+ "model": model,
+ "tokenizer": tokenizer,
+ "config_dict": config.to_dict(),
+ "state_dict": model.state_dict(),
+ }
+
+
+@pytest.fixture(scope="module")
+def apriel2_converted(qwen2_source):
+ """Stage 1: Qwen2 -> Apriel2."""
+ config_dict = convert_qwen2_config(qwen2_source["config_dict"])
+ plan = plan_qwen2_to_apriel2(qwen2_source["config_dict"])
+ weights = execute(plan, {W(k): v for k, v in qwen2_source["state_dict"].items()}, seed=42)
+
+ config = Apriel2Config(**config_dict)
+ model = Apriel2ForCausalLM(config)
+ model.load_state_dict({str(k): v for k, v in weights.items()}, strict=False)
+ model.eval()
+
+ return {"model": model, "config_dict": config_dict, "plan": plan, "name": "Apriel2"}
+
+
+@pytest.fixture(scope="module")
+def supernet_converted(qwen2_source, apriel2_converted):
+ """Stage 2: Apriel2 + Surgery -> Supernet."""
+ surgery_spec = {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "type": "stochastic",
+ "main_mixer_name": "attention",
+ "mixers": {
+ "attention": {"type": "attention", "init": "transfer"},
+ "sliding_window": {
+ "type": "attention",
+ "init": "transfer",
+ "window_size": 4096,
+ },
+ },
+ },
+ },
+ },
+ }
+
+ apriel_config = apriel2_converted["config_dict"]
+ supernet_config = compose_configs(apriel_config, surgery_spec)
+
+ full_plan = compose(
+ apriel2_converted["plan"],
+ plan_surgery(apriel_config, supernet_config),
+ )
+
+ weights = execute(full_plan, {W(k): v for k, v in qwen2_source["state_dict"].items()}, seed=42)
+
+ config = Apriel2Config(**supernet_config)
+ model = Apriel2ForCausalLM(config)
+ model.load_state_dict({str(k): v for k, v in weights.items()}, strict=False)
+ model.eval()
+
+ return {"model": model, "config_dict": supernet_config, "name": "Supernet"}
+
+
+@pytest.fixture(scope="module")
+def roundtrip_converted(supernet_converted, qwen2_source):
+ """Stage 3: Supernet -> Fast-LLM -> Supernet."""
+ if not torch.cuda.is_available():
+ pytest.skip("Roundtrip conversion requires CUDA (integration tests need realistic hardware)")
+
+ from fast_llm.engine.checkpoint.config import (
+ CheckpointLoadConfig,
+ CheckpointSaveConfig,
+ FastLLMCheckpointFormat,
+ )
+ from fast_llm.engine.checkpoint.convert import ConvertConfig
+ from fast_llm.models.gpt.config import GPTModelConfig
+ from fast_llm.models.gpt.conversion.config import Apriel2TextCheckpointFormat
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ tmpdir = Path(tmpdir)
+ supernet_path = tmpdir / "supernet"
+ fastllm_path = tmpdir / "fastllm"
+ roundtrip_path = tmpdir / "roundtrip"
+
+ supernet_converted["model"].save_pretrained(supernet_path)
+ qwen2_source["tokenizer"].save_pretrained(supernet_path)
+
+ ConvertConfig(
+ model=GPTModelConfig,
+ input=CheckpointLoadConfig(path=supernet_path, format=Apriel2TextCheckpointFormat),
+ output=CheckpointSaveConfig(path=fastllm_path, format=FastLLMCheckpointFormat),
+ ).run()
+
+ ConvertConfig(
+ model=GPTModelConfig,
+ input=CheckpointLoadConfig(path=fastllm_path, format=FastLLMCheckpointFormat),
+ output=CheckpointSaveConfig(path=roundtrip_path, format=Apriel2TextCheckpointFormat),
+ ).run()
+
+ model = Apriel2ForCausalLM.from_pretrained(roundtrip_path)
+ model.eval()
+
+ with open(roundtrip_path / "config.json") as f:
+ config_dict = json.load(f)
+
+ yield {"model": model, "config_dict": config_dict, "name": "Roundtrip"}
+
+
+# =============================================================================
+# Parameterized Fixture: All Conversion Stages
+# =============================================================================
+
+
+@pytest.fixture(params=["apriel2", "supernet", "roundtrip"])
+def converted_model(request, apriel2_converted, supernet_converted):
+ """Parameterized fixture providing each conversion stage for testing.
+
+ This allows a single test to run against all stages automatically.
+ """
+ if request.param == "roundtrip":
+ pytest.importorskip("fast_llm")
+ if not torch.cuda.is_available():
+ pytest.skip("Roundtrip tests require CUDA (integration tests need realistic hardware)")
+ # Lazy-load to avoid fixture evaluation when CUDA unavailable
+ roundtrip_converted = request.getfixturevalue("roundtrip_converted")
+ return roundtrip_converted
+
+ return {
+ "apriel2": apriel2_converted,
+ "supernet": supernet_converted,
+ }[request.param]
+
+
+# =============================================================================
+# Config Preservation Tests
+# =============================================================================
+
+
+@pytest.mark.slow
+class TestConfigPreservation:
+ """Verify configs are correctly preserved through the conversion chain."""
+
+ def test_apriel2_structure(self, qwen2_source, apriel2_converted):
+ """Qwen2 -> Apriel2 preserves model dimensions."""
+ qwen = qwen2_source["config_dict"]
+ apriel = apriel2_converted["config_dict"]
+
+ assert apriel["hidden_size"] == qwen["hidden_size"]
+ assert apriel["vocab_size"] == qwen["vocab_size"]
+ assert apriel["decoder"]["num_blocks"] == qwen["num_hidden_layers"]
+
+ def test_apriel2_bias_pattern(self, apriel2_converted):
+ """Qwen2 -> Apriel2 preserves Qwen-style bias (QKV yes, O no)."""
+ mixer = apriel2_converted["config_dict"]["decoder"]["block"]["mixer"]
+
+ assert mixer["query_layer"]["bias"]["enabled"] is True
+ assert mixer["key_layer"]["bias"]["enabled"] is True
+ assert mixer["value_layer"]["bias"]["enabled"] is True
+ assert mixer["dense_layer"]["bias"]["enabled"] is False
+
+ def test_supernet_structure(self, supernet_converted):
+ """Surgery creates correct stochastic mixer structure."""
+ mixer = supernet_converted["config_dict"]["decoder"]["block"]["mixer"]
+
+ assert mixer["type"] == "stochastic"
+ assert mixer["main_mixer_name"] == "attention"
+ assert set(mixer["mixers"].keys()) == {"attention", "sliding_window"}
+
+ def test_supernet_bias_inheritance(self, supernet_converted):
+ """Submixers inherit bias settings from source."""
+ mixer = supernet_converted["config_dict"]["decoder"]["block"]["mixer"]
+
+ for name in ["attention", "sliding_window"]:
+ assert mixer["mixers"][name]["query_layer"]["bias"]["enabled"] is True
+ assert mixer["mixers"][name]["dense_layer"]["bias"]["enabled"] is False
+
+ @requires_fastllm
+ def test_roundtrip_structure(self, roundtrip_converted):
+ """Fast-LLM roundtrip preserves stochastic mixer structure."""
+ mixer = roundtrip_converted["config_dict"]["decoder"]["block"]["mixer"]
+
+ assert mixer["type"] == "stochastic"
+ assert mixer["main_mixer_name"] == "attention"
+ assert set(mixer["mixers"].keys()) == {"attention", "sliding_window"}
+
+ @requires_fastllm
+ def test_roundtrip_bias_preservation(self, roundtrip_converted):
+ """Fast-LLM roundtrip preserves per-layer bias settings."""
+ mixer = roundtrip_converted["config_dict"]["decoder"]["block"]["mixer"]
+
+ for name in ["attention", "sliding_window"]:
+ assert mixer["mixers"][name]["query_layer"]["bias"]["enabled"] is True
+ assert mixer["mixers"][name]["dense_layer"]["bias"]["enabled"] is False
+
+
+# =============================================================================
+# Numerical Equivalence Tests
+# =============================================================================
+
+
+@pytest.mark.slow
+class TestNumericalEquivalence:
+ """Verify all conversion stages produce numerically identical outputs.
+
+ Uses parameterized fixtures to test all stages with all input variations,
+ giving us 3 stages × 3 inputs = 9 test cases from a single test function.
+ """
+
+ @TEST_INPUTS
+ def test_logits_match(self, qwen2_source, converted_model, prompts, max_new_tokens):
+ """Converted model produces identical logits to source."""
+ tokenizer = qwen2_source["tokenizer"]
+ ref_model = qwen2_source["model"]
+ test_model = converted_model["model"]
+ stage = converted_model["name"]
+
+ inputs = tokenizer(prompts, return_tensors="pt", padding=True)
+ ref_device = next(ref_model.parameters()).device
+ test_device = next(test_model.parameters()).device
+
+ with torch.no_grad():
+ ref_logits = ref_model(
+ input_ids=inputs.input_ids.to(ref_device),
+ attention_mask=inputs.attention_mask.to(ref_device),
+ ).logits.cpu()
+
+ test_logits = test_model(
+ input_ids=inputs.input_ids.to(test_device),
+ attention_mask=inputs.attention_mask.to(test_device),
+ ).logits.cpu()
+
+ max_diff = (ref_logits - test_logits).abs().max().item()
+ assert torch.allclose(ref_logits, test_logits, rtol=1e-4, atol=1e-4), (
+ f"{stage} logits mismatch: max diff = {max_diff:.6f}"
+ )
+
+ @TEST_INPUTS
+ def test_generation_match(self, qwen2_source, converted_model, prompts, max_new_tokens):
+ """Converted model produces identical generation to source."""
+ tokenizer = qwen2_source["tokenizer"]
+ ref_model = qwen2_source["model"]
+ test_model = converted_model["model"]
+ stage = converted_model["name"]
+
+ inputs = tokenizer(prompts, return_tensors="pt", padding=True)
+ ref_device = next(ref_model.parameters()).device
+ test_device = next(test_model.parameters()).device
+
+ with torch.no_grad():
+ ref_gen = ref_model.generate(
+ input_ids=inputs.input_ids.to(ref_device),
+ attention_mask=inputs.attention_mask.to(ref_device),
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ pad_token_id=tokenizer.pad_token_id,
+ ).cpu()
+
+ test_gen = test_model.generate(
+ input_ids=inputs.input_ids.to(test_device),
+ attention_mask=inputs.attention_mask.to(test_device),
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ pad_token_id=tokenizer.pad_token_id,
+ ).cpu()
+
+ assert torch.equal(ref_gen, test_gen), (
+ f"{stage} generation mismatch:\n"
+ f" Reference: {tokenizer.batch_decode(ref_gen, skip_special_tokens=True)}\n"
+ f" Test: {tokenizer.batch_decode(test_gen, skip_special_tokens=True)}"
+ )
diff --git a/fast_llm_external_models/tests/test_apriel2/test_modeling.py b/fast_llm_external_models/tests/test_apriel2/test_modeling.py
index 5dbd36159..47c877d09 100644
--- a/fast_llm_external_models/tests/test_apriel2/test_modeling.py
+++ b/fast_llm_external_models/tests/test_apriel2/test_modeling.py
@@ -12,7 +12,8 @@ class TestApriel2Modeling:
"apriel2_config_tiny",
"apriel2_config_stochastic",
"apriel2_config_multi_mixer",
- "apriel2_config_all_mixers" # Tests all 4 mixer types
+ "apriel2_config_all_mixers", # Tests all 4 mixer types
+ "apriel2_config_with_bias", # Tests per-layer bias and non-gated MLP
])
def test_model_end_to_end(self, config_name, request):
"""Test instantiation, forward pass, cache correctness, and generation.
diff --git a/fast_llm_external_models/tests/test_apriel2/test_plan_execution.py b/fast_llm_external_models/tests/test_apriel2/test_plan_execution.py
new file mode 100644
index 000000000..9a98ec13b
--- /dev/null
+++ b/fast_llm_external_models/tests/test_apriel2/test_plan_execution.py
@@ -0,0 +1,597 @@
+"""test_plan_execution.py - Plan execution and algebraic composition laws.
+
+This module provides rigorous, parameterized tests for the mathematical properties
+that the conversion system must satisfy. Each test class corresponds to one
+algebraic structure, and each test method verifies one specific law.
+
+Conceptual Types
+================
+
+The conversion system operates on three conceptual types (all ``dict`` at runtime):
+
+- **S (State)**: Complete config without ``init`` fields
+- **P (Partial Surgery)**: Incomplete config, may have ``init`` fields
+- **T (Transition Spec)**: Complete config WITH ``init`` fields
+
+Algebraic Structures
+====================
+
+1. **Partial Surgeries (P)** form a **Monoid** under deep merge::
+
+ compose_configs : P × P → P
+ Identity: {}
+ Associativity: (p1 ∘ p2) ∘ p3 = p1 ∘ (p2 ∘ p3)
+
+2. **Surgeries act on States** to produce Transition Specs::
+
+ compose_configs : S × P → T
+ compose_configs : T × P → T
+
+ Action law (additive surgeries): (s · p1) · p2 = s · (p1 ∘ p2)
+
+3. **Plans** form a **Category** with composition::
+
+ compose : Plan(A→B) × Plan(B→C) → Plan(A→C)
+ Associativity: (P1 ∘ P2) ∘ P3 = P1 ∘ (P2 ∘ P3)
+
+4. **plan_surgery is a Functor** from config pairs to plans::
+
+ plan_surgery : S × T → Plan
+ Functoriality: compose(plan(S,T1), plan(T1,T2)) ≡ plan(S,T2)
+
+ This is semantic equivalence: both produce identical weights when executed.
+
+Important Behaviors Tested
+==========================
+
+- **init stripping**: Between surgery iterations, T → S conversion via
+ ``strip_init_fields()`` ensures ``init: random`` applies only to the surgery
+ that introduces a component.
+
+- **Bias inheritance**: Per-layer bias settings propagate through surgery chains.
+
+- **Plan composition**: Composed plans produce identical weights to direct plans.
+
+Design Principles
+=================
+
+- Each law gets ONE parameterized test, not multiple similar tests
+- Fixtures provide diverse configs (with/without biases)
+- Corner cases are covered via parameterization, not test proliferation
+- Tests document the laws they verify in their docstrings
+"""
+
+import pytest
+import torch
+from functools import reduce
+
+from fast_llm_external_models.apriel2.conversion import (
+ compose,
+ compose_configs,
+ execute,
+ plan_surgery,
+ ExprPlan,
+ W,
+ Ref,
+ Concat,
+ Slice,
+ Init,
+)
+
+# Import shared helper from conftest
+from fast_llm_external_models.tests.test_apriel2.conftest import make_weights_for_config
+
+
+# =============================================================================
+# Fixtures: Use shared fixtures from conftest.py where possible
+# =============================================================================
+# - base_config_dict: Complete config without biases (Llama-style)
+# - base_config_with_bias_dict: Complete config with QKV biases
+# - additive_surgery_chain: [wrap_stochastic, add_sliding_window, add_gdn]
+# =============================================================================
+
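+
+# Illustrative sketch of the composition pattern exercised throughout this module:
+# thread a base state S₀ through a chain of partial surgeries, turn each adjacent
+# config pair into a plan, then compose the plans. Not used by the tests; it
+# assumes at least one surgery and relies only on the helpers imported above.
+def _compose_surgery_chain(base_config: dict, surgeries: list[dict]) -> ExprPlan:
+    configs = [base_config]
+    for surgery in surgeries:
+        configs.append(compose_configs(configs[-1], surgery))  # S × P → T
+    plans = [plan_surgery(configs[i], configs[i + 1]) for i in range(len(surgeries))]
+    return reduce(compose, plans)  # Plan composition: (A→B) ∘ (B→C) = (A→C)
+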
+
+# =============================================================================
+# Test: Plan Composition Associativity
+# =============================================================================
+
+
+class TestPlanCompositionAssociativity:
+ """
+ LAW: Plan composition is associative.
+
+ (P₁ ∘ P₂) ∘ P₃ = P₁ ∘ (P₂ ∘ P₃)
+
+ where ∘ denotes compose(P1, P2).
+
+ This must hold for the AST structure, not just semantic equivalence.
+ """
+
+ @pytest.mark.parametrize("expr_type", ["ref_chain", "with_concat", "with_slice", "with_init"])
+ def test_associativity(self, expr_type):
+ """Plan composition is associative for various expression types."""
+ # Build three plans that can be composed
+ if expr_type == "ref_chain":
+ p1 = ExprPlan(mappings={W("b"): Ref(key=W("a"))})
+ p2 = ExprPlan(mappings={W("c"): Ref(key=W("b"))})
+ p3 = ExprPlan(mappings={W("d"): Ref(key=W("c"))})
+ elif expr_type == "with_concat":
+ p1 = ExprPlan(mappings={W("x"): Ref(key=W("a")), W("y"): Ref(key=W("b"))})
+ p2 = ExprPlan(mappings={W("xy"): Concat(exprs=(Ref(key=W("x")), Ref(key=W("y"))), dim=0)})
+ p3 = ExprPlan(mappings={W("final"): Ref(key=W("xy"))})
+ elif expr_type == "with_slice":
+ p1 = ExprPlan(mappings={W("full"): Ref(key=W("src"))})
+ p2 = ExprPlan(mappings={W("part"): Slice(expr=Ref(key=W("full")), slices=((0, 5, None),))})
+ p3 = ExprPlan(mappings={W("out"): Ref(key=W("part"))})
+ elif expr_type == "with_init":
+ p1 = ExprPlan(mappings={W("x"): Ref(key=W("a"))})
+ p2 = ExprPlan(mappings={W("y"): Concat(exprs=(Ref(key=W("x")), Init(shape=(5,), init_type="zeros")), dim=0)})
+ p3 = ExprPlan(mappings={W("z"): Ref(key=W("y"))})
+
+ left = compose(compose(p1, p2), p3)
+ right = compose(p1, compose(p2, p3))
+
+ assert left.mappings == right.mappings, f"Associativity failed for {expr_type}"
+
+
+# =============================================================================
+# Test: Functoriality of plan_surgery (THE CRITICAL PROPERTY)
+# =============================================================================
+
+
+class TestPlanSurgeryFunctoriality:
+ """
+ LAW: plan_surgery is functorial with respect to config composition.
+
+ For a surgery chain P₁, P₂, ..., Pₙ applied to base state S₀::
+
+ T₁ = compose_configs(S₀, P₁) # S × P → T
+ T₂ = compose_configs(T₁, P₂) # T × P → T (no stripping!)
+ ...
+ Tₙ = compose_configs(Tₙ₋₁, Pₙ)
+
+ Plan functoriality says::
+
+ compose(plan(S₀,T₁), plan(T₁,T₂), ...) ≡ plan(S₀, Tₙ)
+
+ where ≡ denotes semantic equivalence (identical weights when executed).
+
+ NOTE: This tests T × P composition WITHOUT stripping between steps.
+ This differs from build_plan which strips (T → S) between iterations.
+ Both patterns are valid:
+
+ - Without stripping: init fields accumulate, testing plan composition purity
+ - With stripping: init consumed per-step, testing real usage (see
+ test_build_plan_strips_init_between_iterations)
+
+ The functoriality law holds in both cases because plan composition
+ correctly substitutes Ref expressions with their definitions.
+ """
+
+ @pytest.mark.parametrize("chain_length", [1, 2, 3])
+ @pytest.mark.parametrize("use_bias", [True, False])
+ def test_functoriality(
+ self,
+ chain_length,
+ use_bias,
+ base_config_dict,
+ base_config_with_bias_dict,
+ additive_surgery_chain,
+ ):
+ """
+ Composed incremental plans produce same weights as direct plan.
+
+ Parameterized over:
+ - chain_length: Number of surgeries (1, 2, or 3)
+ - use_bias: Whether base config has biases
+ """
+ base_config = base_config_with_bias_dict if use_bias else base_config_dict
+ surgeries = additive_surgery_chain[:chain_length]
+
+ # Build config chain: C₀ → C₁ → ... → Cₙ
+ configs = [base_config]
+ for s in surgeries:
+ configs.append(compose_configs(configs[-1], s))
+
+ # Build incremental plans: Pₖ = plan_surgery(Cₖ₋₁, Cₖ)
+ plans = [plan_surgery(configs[i], configs[i+1]) for i in range(len(surgeries))]
+
+ # Compose all incremental plans
+ composed_plan = reduce(compose, plans)
+
+ # Build direct plan: plan_surgery(C₀, Cₙ)
+ direct_plan = plan_surgery(configs[0], configs[-1])
+
+ # Execute both on same weights
+ weights = make_weights_for_config(base_config)
+ composed_weights = execute(composed_plan, weights, seed=42)
+ direct_weights = execute(direct_plan, weights, seed=42)
+
+ # Verify semantic equivalence
+ assert set(composed_weights.keys()) == set(direct_weights.keys()), \
+ f"Key sets differ for chain_length={chain_length}, use_bias={use_bias}"
+
+ for key in composed_weights:
+ assert torch.allclose(composed_weights[key], direct_weights[key], atol=1e-6), \
+ f"Weight mismatch for {key} with chain_length={chain_length}, use_bias={use_bias}"
+
+ @pytest.mark.parametrize("split_point", [1, 2])
+ def test_arbitrary_grouping(
+ self,
+ split_point,
+ base_config_with_bias_dict,
+ additive_surgery_chain,
+ ):
+ """
+ Any grouping of surgery chain produces same result.
+
+ For surgeries [S₁, S₂, S₃], tests that:
+ - compose(P₁, compose(P₂, P₃))
+ - compose(compose(P₁, P₂), P₃)
+ - plan_surgery(C₀, C₃)
+
+ all produce identical weights.
+ """
+ surgeries = additive_surgery_chain
+
+ # Build config chain
+ configs = [base_config_with_bias_dict]
+ for s in surgeries:
+ configs.append(compose_configs(configs[-1], s))
+
+ # Build incremental plans
+ plans = [plan_surgery(configs[i], configs[i+1]) for i in range(3)]
+
+ # Different groupings
+ left_grouped = compose(compose(plans[0], plans[1]), plans[2])
+ right_grouped = compose(plans[0], compose(plans[1], plans[2]))
+ direct = plan_surgery(configs[0], configs[-1])
+
+ # Execute all
+ weights = make_weights_for_config(base_config_with_bias_dict)
+ results = {
+ "left": execute(left_grouped, weights, seed=42),
+ "right": execute(right_grouped, weights, seed=42),
+ "direct": execute(direct, weights, seed=42),
+ }
+
+ # All must match
+ keys = set(results["left"].keys())
+ assert keys == set(results["right"].keys()) == set(results["direct"].keys())
+
+ for key in keys:
+ assert torch.allclose(results["left"][key], results["right"][key], atol=1e-6)
+ assert torch.allclose(results["left"][key], results["direct"][key], atol=1e-6)
+
+
+# =============================================================================
+# Test: Bias Inheritance Preservation (Regression for the specific bug)
+# =============================================================================
+
+
+class TestBiasInheritancePreservation:
+ """
+ PROPERTY: Per-layer bias settings must be preserved through surgery chains.
+
+ When a surgery spec does not mention bias settings, they must be inherited
+ from the source config. This is the specific failure mode of the build_plan
+ bug: passing partial surgery specs to plan_surgery lost inherited fields.
+
+ This test verifies the SYMPTOM (missing biases) rather than the LAW
+ (functoriality). It's kept as a focused regression test.
+ """
+
+ @pytest.mark.parametrize("num_surgeries", [1, 2, 3])
+ def test_qkv_biases_preserved_through_chain(
+ self,
+ num_surgeries,
+ base_config_with_bias_dict,
+ additive_surgery_chain,
+ ):
+ """QKV biases (enabled in source) appear in plan after N surgeries."""
+ surgeries = additive_surgery_chain[:num_surgeries]
+
+ # Build config and plan chain
+ configs = [base_config_with_bias_dict]
+ for s in surgeries:
+ configs.append(compose_configs(configs[-1], s))
+
+ plans = [plan_surgery(configs[i], configs[i+1]) for i in range(num_surgeries)]
+ final_plan = reduce(compose, plans) if len(plans) > 1 else plans[0]
+
+ # Check bias keys present
+ target_keys = {str(k) for k in final_plan.target_keys()}
+
+ assert any("q_proj.bias" in k for k in target_keys), \
+ f"q_proj.bias missing after {num_surgeries} surgeries"
+ assert any("k_proj.bias" in k for k in target_keys), \
+ f"k_proj.bias missing after {num_surgeries} surgeries"
+ assert any("v_proj.bias" in k for k in target_keys), \
+ f"v_proj.bias missing after {num_surgeries} surgeries"
+ # O bias should NOT be present (disabled in source)
+ assert not any("o_proj.bias" in k for k in target_keys), \
+ f"o_proj.bias should not be present (disabled in source)"
+
+ def test_bias_values_preserved(
+ self,
+ base_config_with_bias_dict,
+ additive_surgery_chain,
+ ):
+ """Bias tensor values are correctly transferred, not just keys."""
+ surgery = additive_surgery_chain[0] # wrap_stochastic
+ c1 = compose_configs(base_config_with_bias_dict, surgery)
+ plan = plan_surgery(base_config_with_bias_dict, c1)
+
+ weights = make_weights_for_config(base_config_with_bias_dict)
+ result = execute(plan, weights, seed=42)
+
+ # Verify values match (not just that keys exist)
+ for i in range(base_config_with_bias_dict["decoder"]["num_blocks"]):
+ src_key = W(f"model.decoder.blocks.{i}.mixer.q_proj.bias")
+ dst_key = W(f"model.decoder.blocks.{i}.mixer.mixers.attention.q_proj.bias")
+
+ assert dst_key in result, f"Missing {dst_key}"
+ assert torch.allclose(weights[src_key], result[dst_key]), \
+ f"Bias values differ for block {i}"
+
+
+# =============================================================================
+# Test: build_plan Integration (Regression test for convert.py)
+# =============================================================================
+
+
+class TestBuildPlanIntegration:
+ """
+ REGRESSION: build_plan must compose configs before calling plan_surgery.
+
+ The bug was:
+ plan_surgery(current_config, surgery_config) # WRONG: partial
+
+ Should be:
+ target = compose_configs(current_config, surgery_config)
+ plan_surgery(current_config, target) # CORRECT: complete
+
+ This test verifies the fix in convert.py's build_plan function.
+ """
+
+ @pytest.mark.parametrize("num_surgeries", [1, 2])
+ def test_build_plan_preserves_inherited_fields(
+ self,
+ num_surgeries,
+ base_config_with_bias_dict,
+ additive_surgery_chain,
+ ):
+ """build_plan produces plans with inherited bias mappings."""
+ from fast_llm_external_models.apriel2.convert import build_plan
+
+ surgeries = additive_surgery_chain[:num_surgeries]
+
+ plan, final_config = build_plan(
+ base_config_with_bias_dict,
+ surgeries,
+ source_format="apriel2",
+ )
+
+ # Verify inherited biases in config
+ if num_surgeries >= 1:
+ attn = final_config["decoder"]["block"]["mixer"]["mixers"]["attention"]
+ assert attn.get("query_layer", {}).get("bias", {}).get("enabled") is True
+
+ # Verify bias mappings in plan
+ target_keys = {str(k) for k in plan.target_keys()}
+ assert any("q_proj.bias" in k for k in target_keys), \
+ f"build_plan with {num_surgeries} surgeries missing q_proj.bias"
+
+
+# =============================================================================
+# Test: init Field Preservation (Critical for random initialization)
+# =============================================================================
+
+
+class TestInitFieldPreservation:
+ """
+ PROPERTY: The `init` field must be visible to plan_surgery.
+
+ The `init` field controls weight initialization mode:
+ - `init: transfer` → use weight transfer/conversion
+ - `init: random` → use random initialization
+
+ compose_configs must preserve `init` so plan_surgery can see it.
+ Stripping happens only at final output (when saving to disk).
+ """
+
+ def test_init_random_produces_init_expression(self, base_config_with_bias_dict):
+ """Surgery with init: random produces Init expressions in plan."""
+ surgery = {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "type": "stochastic",
+ "main_mixer_name": "attention",
+ "mixers": {
+ "attention": {"init": "transfer"},
+ "gdn": {"type": "gdn", "init": "random", "convolution_layer": {"kernel_size": 4}},
+ },
+ },
+ },
+ },
+ }
+
+ target = compose_configs(base_config_with_bias_dict, surgery)
+ plan = plan_surgery(base_config_with_bias_dict, target)
+
+ # Check that GDN weights use Init expressions (random init)
+ target_keys = {str(k) for k in plan.target_keys()}
+ gdn_keys = [k for k in target_keys if "gdn" in k.lower()]
+
+ assert len(gdn_keys) > 0, "No GDN keys in plan"
+
+ # Verify at least one GDN weight uses Init (random initialization)
+ has_init_expr = False
+ for key in plan.target_keys():
+ if "gdn" in str(key).lower():
+ expr = plan.mappings[key]
+ if isinstance(expr, Init):
+ has_init_expr = True
+ break
+ # Also check inside Concat/other composite expressions
+ if hasattr(expr, 'exprs'):
+ for sub in expr.exprs:
+ if isinstance(sub, Init):
+ has_init_expr = True
+ break
+
+ assert has_init_expr, "init: random should produce Init expressions for GDN weights"
+
+ def test_init_transfer_produces_ref_expression(self, base_config_with_bias_dict):
+ """Surgery with init: transfer produces Ref expressions (weight transfer)."""
+ surgery = {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "type": "stochastic",
+ "main_mixer_name": "attention",
+ "mixers": {
+ "attention": {"init": "transfer"},
+ },
+ },
+ },
+ },
+ }
+
+ target = compose_configs(base_config_with_bias_dict, surgery)
+ plan = plan_surgery(base_config_with_bias_dict, target)
+
+ # Check that attention weights use Ref expressions (transfer)
+ has_ref = False
+ for key in plan.target_keys():
+ if "attention" in str(key) and "q_proj.weight" in str(key):
+ expr = plan.mappings[key]
+ if isinstance(expr, Ref):
+ has_ref = True
+ break
+
+ assert has_ref, "init: transfer should produce Ref expressions for attention weights"
+
+ def test_build_plan_respects_init_random(self, base_config_with_bias_dict):
+ """build_plan correctly uses init: random for weight initialization."""
+ from fast_llm_external_models.apriel2.convert import build_plan
+
+ # Mamba requires many config fields for random init
+ surgery = {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "type": "stochastic",
+ "main_mixer_name": "attention",
+ "mixers": {
+ "attention": {"init": "transfer"},
+ "mamba": {
+ "type": "mamba",
+ "init": "random",
+ "d_inner": 512,
+ "d_state": 16,
+ "dt_rank": 16,
+ "d_xb": 64,
+ "d_conv": 4,
+ "repeat_kv_before_conv": False,
+ "conv_bias": True,
+ "dt_proj_bias": True,
+ "dt_min": 0.001,
+ "dt_max": 0.1,
+ "dt_init_floor": 1e-4,
+ },
+ },
+ },
+ },
+ },
+ }
+
+ plan, final_config = build_plan(
+ base_config_with_bias_dict,
+ [surgery],
+ source_format="apriel2",
+ )
+
+ # Verify mamba weights use Init (random init)
+ has_mamba_init = False
+ for key in plan.target_keys():
+ key_str = str(key)
+ if "mamba" in key_str:
+ expr = plan.mappings[key]
+ if isinstance(expr, Init):
+ has_mamba_init = True
+ break
+
+ assert has_mamba_init, "build_plan should use Init for init: random components"
+
+ def test_build_plan_strips_init_between_iterations(self, base_config_with_bias_dict):
+ """build_plan strips init between iterations (T → S conversion).
+
+ This tests that the intermediate state between surgeries has no init fields.
+ The composed plan will show Init expressions because plan composition
+ substitutes Ref → Init, but the semantics are correct: GDN is initialized
+ once (in surgery 1), not re-randomized in surgery 2.
+ """
+ from fast_llm_external_models.apriel2.conversion import (
+ compose_configs, strip_init_fields, plan_surgery, compose
+ )
+
+ # Surgery 1: Add GDN with random init
+ surgery1 = {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "type": "stochastic",
+ "main_mixer_name": "attention",
+ "mixers": {
+ "attention": {"init": "transfer"},
+ "gdn": {
+ "type": "gdn",
+ "init": "random",
+ "convolution_layer": {"kernel_size": 4},
+ },
+ },
+ },
+ },
+ },
+ }
+
+ # Surgery 2: Add sliding window (doesn't mention GDN)
+ surgery2 = {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "mixers": {
+ "sliding_window": {"init": "transfer", "window_size": 512},
+ },
+ },
+ },
+ },
+ }
+
+ # Simulate build_plan's iteration loop
+ s0 = base_config_with_bias_dict
+
+ # Iteration 1
+ t1 = compose_configs(s0, surgery1)
+ assert t1["decoder"]["block"]["mixer"]["mixers"]["gdn"].get("init") == "random"
+ s1 = strip_init_fields(t1)
+ assert s1["decoder"]["block"]["mixer"]["mixers"]["gdn"].get("init") is None
+
+ # Iteration 2: s1 has no init for GDN
+ t2 = compose_configs(s1, surgery2)
+ assert t2["decoder"]["block"]["mixer"]["mixers"]["gdn"].get("init") is None, \
+ "GDN should have no init in T2 (wasn't in surgery2, stripped from s1)"
+
+ # plan_surgery(s1, t2) should use Ref for GDN (transfer, not random)
+ plan2 = plan_surgery(s1, t2)
+ gdn_uses_ref = False
+ for key in plan2.target_keys():
+ if "gdn" in str(key):
+ expr = plan2.mappings[key]
+ if isinstance(expr, Ref):
+ gdn_uses_ref = True
+ break
+
+ assert gdn_uses_ref, "plan_surgery(s1, t2) should use Ref for GDN (transfer from s1)"
diff --git a/tests/data/test_tokenizer.py b/tests/data/test_tokenizer.py
index c7fdef9ca..f8f07ef0f 100644
--- a/tests/data/test_tokenizer.py
+++ b/tests/data/test_tokenizer.py
@@ -40,3 +40,125 @@ def test_tokenize_with_spans(common_tokenizer, spans, expected_token_spans, expe
expected_token_spans = [(begin + 1, end + 1) for begin, end in expected_token_spans]
Assert.eq(tokens.tolist(), expected_tokens)
Assert.eq(token_spans, expected_token_spans)
+
+
+def test_validate_chat_template_no_template(common_tokenizer):
+ """Tokenizer without chat template raises."""
+ with pytest.raises(ValueError, match="does not have a chat template"):
+ common_tokenizer.validate_chat_template()
+
+
+def test_validate_chat_template_no_markers(common_tokenizer):
+ """Tokenizer with chat template but no markers raises."""
+ common_tokenizer.tokenizer.chat_template = "{{ messages }}"
+ with pytest.raises(ValueError, match="does not contain.*generation"):
+ common_tokenizer.validate_chat_template()
+
+
+def test_validate_chat_template_with_markers(common_tokenizer):
+ """Tokenizer with generation markers validates."""
+ common_tokenizer.tokenizer.chat_template = "{% generation %}{{ m }}{% endgeneration %}"
+ common_tokenizer.validate_chat_template()
+
+
+# Realistic chat template following HF conventions (e.g., SmolLM3):
+# The generation block includes the full assistant turn: opening tag, content, and closing tag.
+# This ensures the model learns to emit the closing tag.
+CHAT_TEMPLATE = (
+    "{% for message in messages %}"
+    "{% if message.role == 'assistant' %}"
+    "{% generation %}<{{ message.role }}>{{ message.content }}</{{ message.role }}>{% endgeneration %}"
+    "{% else %}"
+    "<{{ message.role }}>{{ message.content }}</{{ message.role }}>"
+    "{% endif %}"
+    "{% endfor %}"
+)
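+
+# For reference, the template renders a single user/assistant exchange as:
+#   <user>Hi</user><assistant>Hello</assistant>
+# with the whole <assistant>...</assistant> turn inside the generation region.
+# The template itself emits no special tokens; the 49152 at both ends of the
+# expected sequences below comes from tokenization.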
+
+
+@pytest.mark.parametrize(
+ ("messages", "expected_tokens", "expected_loss_masking_spans"),
+ (
+ # Single turn: full assistant turn (Hello) is trainable
+ # 15 tokens, trainable indices 7-13, loss mask spans cover 0-6 and 14
+ (
+ [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello"}],
+ [49152, 27, 789, 29, 16946, 750, 789, 2293, 17822, 29, 7371, 750, 17822, 29, 49152],
+ [(0, 7), (14, 15)],
+ ),
+ # Multi-turn: both assistant turns are fully trainable
+ # 27 tokens, trainable indices 7-13 and 19-25
+ (
+ [
+ {"role": "user", "content": "A"},
+ {"role": "assistant", "content": "B"},
+ {"role": "user", "content": "C"},
+ {"role": "assistant", "content": "D"},
+ ],
+ [49152, 27, 789, 29, 32, 750, 789, 2293, 17822, 29, 33, 750, 17822, 2293, 789, 29, 34, 750, 789, 2293, 17822, 29, 35, 750, 17822, 29, 49152],
+ [(0, 7), (14, 19), (26, 27)],
+ ),
+ # System + user + assistant: full assistant turn trainable
+ # 23 tokens, trainable indices 15-21
+ (
+ [
+ {"role": "system", "content": "You are helpful."},
+ {"role": "user", "content": "Hi"},
+ {"role": "assistant", "content": "Hello"},
+ ],
+ [49152, 27, 3144, 29, 5815, 1139, 44569, 6928, 3144, 2293, 789, 29, 16946, 750, 789, 2293, 17822, 29, 7371, 750, 17822, 29, 49152],
+ [(0, 15), (22, 23)],
+ ),
+ # User only: no trainable tokens
+ # 9 tokens, no trainable indices
+ (
+ [{"role": "user", "content": "Hi"}],
+ [49152, 27, 789, 29, 16946, 750, 789, 29, 49152],
+ [(0, 9)],
+ ),
+ # Long multi-turn (85 tokens, 3 assistant responses with tags, tests span machinery)
+ # Trainable: indices 27-40, 49-62, 70-83
+ (
+ [
+ {"role": "system", "content": "You are a helpful assistant that answers questions."},
+ {"role": "user", "content": "What is the capital of France?"},
+ {"role": "assistant", "content": "The capital of France is Paris."},
+ {"role": "user", "content": "What about Germany?"},
+ {"role": "assistant", "content": "The capital of Germany is Berlin."},
+ {"role": "user", "content": "And Italy?"},
+ {"role": "assistant", "content": "The capital of Italy is Rome."},
+ ],
+ [49152, 27, 3144, 29, 5815, 1139, 373, 44569, 2424, 11886, 954, 15737, 14516, 6928, 3144, 2293, 789, 29, 13938, 438, 331, 25016, 457, 12409, 562, 35838, 789, 2293, 17822, 29, 2111, 25016, 457, 12409, 562, 438, 4235, 280, 6928, 17822, 2293, 789, 29, 13938, 5028, 759, 42226, 35838, 789, 2293, 17822, 29, 2111, 25016, 457, 759, 42226, 438, 29784, 3556, 6928, 17822, 2293, 789, 29, 1996, 4413, 3326, 35838, 789, 2293, 17822, 29, 2111, 25016, 457, 4413, 3326, 438, 613, 1361, 6928, 17822, 29, 49152],
+ [(0, 27), (41, 49), (63, 70), (84, 85)],
+ ),
+ ),
+)
+def test_tokenize_chat(common_tokenizer, messages, expected_tokens, expected_loss_masking_spans):
+ common_tokenizer.tokenizer.chat_template = CHAT_TEMPLATE
+ tokens, loss_masking_spans = common_tokenizer.tokenize_chat(messages)
+ Assert.eq(tokens.tolist(), expected_tokens)
+ Assert.eq(loss_masking_spans, expected_loss_masking_spans)
+
+
+@pytest.mark.parametrize(
+ ("train_mask", "expected_loss_spans"),
+ (
+ # All masked (no trainable tokens)
+ ([False, False, False], [(0, 3)]),
+ # All trainable (no spans)
+ ([True, True, True], []),
+ # Single trainable at start
+ ([True, False, False], [(1, 3)]),
+ # Single trainable at end
+ ([False, False, True], [(0, 2)]),
+ # Single trainable in middle
+ ([False, True, False], [(0, 1), (2, 3)]),
+ # Multiple trainable regions (simulates multi-turn conversation)
+ ([False, False, True, True, False, False, True, True, True, False], [(0, 2), (4, 6), (9, 10)]),
+ # Alternating
+ ([False, True, False, True, False], [(0, 1), (2, 3), (4, 5)]),
+ ),
+)
+def test_train_mask_to_loss_spans(train_mask, expected_loss_spans):
+ from fast_llm.data.preprocessing.tokenizer import _train_mask_to_loss_spans
+
+ Assert.eq(_train_mask_to_loss_spans(train_mask), expected_loss_spans)
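+
+
+# Reference sketch of the mapping exercised above (illustrative only; the production
+# helper is the `_train_mask_to_loss_spans` imported in the test). A loss-masking
+# span covers every maximal run of non-trainable positions.
+def _reference_train_mask_to_loss_spans(train_mask: list[bool]) -> list[tuple[int, int]]:
+    spans: list[tuple[int, int]] = []
+    start: int | None = None
+    for index, trainable in enumerate(train_mask):
+        if not trainable and start is None:
+            start = index  # open a masked span at the first non-trainable position
+        elif trainable and start is not None:
+            spans.append((start, index))  # close the span at the next trainable position
+            start = None
+    if start is not None:
+        spans.append((start, len(train_mask)))  # trailing masked span
+    return spans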