diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md
index 50c5a2c1c..a09f78c6c 100644
--- a/.github/ISSUE_TEMPLATE/feature_request.md
+++ b/.github/ISSUE_TEMPLATE/feature_request.md
@@ -8,26 +8,26 @@ assignees: ''
---
# 🎯 **Goal (What & Why)**
-> **Clearly state the purpose of this feature.**
+> **Clearly state the purpose of this feature.**
> _(Example: Add FP8 support using torchao to improve training throughput by 1.5x.)_
# 🚀 **Execution Plan**
-> _(This section may start as an incomplete draft but must be defined before implementation begins.)_
+> _(This section may start as an incomplete draft but must be defined before implementation begins.)_
### **Step 1: What is the smallest working version?**
-> _(Describe the simplest way to implement this feature with minimal effort.)_
+> _(Describe the simplest way to implement this feature with minimal effort.)_
-### **Step 2: What additional optimizations are possible (but optional)?**
-> _(List potential refinements that can be added in later PRs if needed.)_
+### **Step 2: What additional optimizations are possible (but optional)?**
+> _(List potential refinements that can be added in later PRs if needed.)_
# 📌 **Acceptance Criteria** (Must-Haves for Completion)
-* The feature must be **functional and tested**.
-* The implementation must be **documented in practical terms**.
-* The PR must include a **performance/impact summary**.
-* **No refactors unless directly necessary** for feature completion.
+* The feature must be **functional and tested**.
+* The implementation must be **documented in practical terms**.
+* The PR must include a **performance/impact summary**.
+* **No refactors unless directly necessary** for feature completion.
# 🛠️ **Project Management**
- [ ] **Assign the project to the Fast-LLM project.**
- [ ] **Set the `Estimate` field (in days) in the GitHub project.**
- [ ] **Use the `Size` field to categorize the PR size (Small/Medium/Large).**
-- [ ] **Assign an owner when opening the issue.**
+- [ ] **Assign an owner when opening the issue.**
diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py
index 41a2fe7ff..0ed4696da 100644
--- a/fast_llm/data/dataset/gpt/config.py
+++ b/fast_llm/data/dataset/gpt/config.py
@@ -65,7 +65,7 @@ def build(self, preprocessing: PreprocessingConfig) -> SamplableDataset[SampleTy
def _load_config(self) -> SampledDatasetConfig[SampleType]:
assert self.path.is_file(), f"File {self.path} does not exist."
config = yaml.safe_load(self.path.open("r"))
- Assert.eq(config.keys(), {"config", "metadata"})
+ # TODO: Assert.eq(config.keys(), {"config", "metadata"}) # Disabled for backward compat
if config.keys() == {"config", "metadata"}:
# Newer format with metadata
config = config["config"]
diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py
index 503b400c3..a1aadf40a 100644
--- a/fast_llm/data/preparator/gpt_memmap/config.py
+++ b/fast_llm/data/preparator/gpt_memmap/config.py
@@ -15,30 +15,81 @@
from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator
-@config_class()
+@config_class(registry=True)
class LanguageModelSourceConfig(Config):
"""
- A schema holding the name of each relevant column in the dataset.
- Setting optional entries will enable the associated feature.
+ Abstract base class for data source schemas.
+
+ Use `type: document` (default) for documents with text, optional span annotations, and optional images.
+ Use `type: conversation` for structured chat/conversation datasets.
+ """
+
+ @classmethod
+ def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self:
+ if cls is LanguageModelSourceConfig and cls.get_subclass(default.get("type")) is None:
+ # Default to DocumentSourceConfig when type is not specified
+ return DocumentSourceConfig._from_dict(default, strict)
+ return super()._from_dict(default, strict=strict)
+
+ @functools.cached_property
+ def columns(self) -> list[str]:
+ """Columns to read from the dataset."""
+ raise NotImplementedError
+
+ @functools.cached_property
+ def has_loss_masking_span(self) -> bool:
+ return False
+
+ @functools.cached_property
+ def has_preference_spans(self) -> bool:
+ return False
+
+ @functools.cached_property
+ def has_images(self) -> bool:
+ return False
+
+
+@config_class(dynamic_type={LanguageModelSourceConfig: "document"})
+class DocumentSourceConfig(LanguageModelSourceConfig):
+ """
+ Source schema for document datasets with text, optional span annotations, and optional images.
+
+ The dataset should have a text column containing the document text.
+ Optionally, it can have additional columns for:
+ - Loss masking spans: character ranges to mask from loss computation
+ - Preference spans: chosen/rejected text for DPO training
+ - Images: image data with character positions for multimodal training
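+
+    Example schema (illustrative; column names depend on the dataset):
+        {"type": "document", "text": "text", "loss_masking_spans": "spans"}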
"""
text: str = Field(
default="text",
- desc="Field of the dataset to use.",
+ desc="Field containing the document text.",
+ hint=FieldHint.optional,
+ )
+ loss_masking_spans: str | None = Field(
+ default=None,
+ desc="Field containing character spans to mask for loss computation.",
hint=FieldHint.optional,
)
- loss_masking_spans: None | str = Field(
- default=None, desc="Field containing character spans to mask for loss computation", hint=FieldHint.optional
+ chosen_span: str | None = Field(
+ default=None,
+ desc="Field containing chosen text for preference optimization.",
+ hint=FieldHint.optional,
)
- chosen_span: None | str = Field(
- default=None, desc="Field containing chosen text for preference optimization", hint=FieldHint.optional
+ rejected_span: str | None = Field(
+ default=None,
+ desc="Field containing rejected text for preference optimization.",
+ hint=FieldHint.optional,
)
- rejected_span: None | str = Field(
- default=None, desc="Field containing rejected text for preference optimization", hint=FieldHint.optional
+ images: str | None = Field(
+ default=None,
+ desc="Field containing images.",
+ hint=FieldHint.optional,
)
- images: None | str = Field(default=None, desc="Field containing images", hint=FieldHint.optional)
- image_positions: None | str = Field(
- default=None, desc="Field containing image positions in the text.", hint=FieldHint.optional
+ image_positions: str | None = Field(
+ default=None,
+ desc="Field containing image positions in the text.",
+ hint=FieldHint.optional,
)
@functools.cached_property
@@ -48,6 +99,8 @@ def columns(self) -> list[str]:
columns.append(self.loss_masking_spans)
if self.has_preference_spans:
columns.extend([self.chosen_span, self.rejected_span])
+ if self.has_images:
+ columns.extend([self.images, self.image_positions])
return columns
@functools.cached_property
@@ -67,7 +120,50 @@ def has_images(self) -> bool:
def _validate(self):
super()._validate()
if self.has_preference_spans and self.has_loss_masking_span:
- raise ValueError(f"Can not enable both loss masking and preference spans.")
+ raise ValueError("Cannot enable both loss masking and preference spans.")
+
+
+@config_class(dynamic_type={LanguageModelSourceConfig: "conversation"})
+class ConversationSourceConfig(LanguageModelSourceConfig):
+ """
+ Source schema for chat/conversation datasets (e.g., Tulu 3, ShareGPT, OpenAI format).
+
+ The dataset should have a messages column containing a list of message dicts,
+ where each message has 'role' and 'content' keys. Common roles include:
+ - 'system': System prompt
+ - 'user': User input
+ - 'assistant': Model response (trained on by default)
+ - 'tool': Tool/function results
+ - 'ipython': Code execution results
+
+ The conversation is formatted using the tokenizer's chat template, which must
+ contain {% generation %}...{% endgeneration %} markers to define which content
+ to train on. Loss masking spans are automatically computed from these markers.
+
+ Example dataset format:
+ {
+ "messages": [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "Hello!"},
+ {"role": "assistant", "content": "Hi there!"},
+ ]
+ }
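+
+    Illustrative chat template fragment with generation markers (actual templates are tokenizer-specific):
+
+        {% if message['role'] == 'assistant' %}{% generation %}{{ message['content'] }}{% endgeneration %}{% else %}{{ message['content'] }}{% endif %}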
+ """
+
+ messages: str = Field(
+ default="messages",
+ desc="Field containing the conversation messages list. Each message should have 'role' and 'content' keys.",
+ hint=FieldHint.core,
+ )
+
+ @functools.cached_property
+ def columns(self) -> list[str]:
+ return [self.messages]
+
+ @functools.cached_property
+ def has_loss_masking_span(self) -> bool:
+ # Conversation format always generates loss masking spans from chat template markers
+ return True
@config_class()
diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py
index e0f5f02fc..325d33c43 100644
--- a/fast_llm/data/preparator/gpt_memmap/prepare.py
+++ b/fast_llm/data/preparator/gpt_memmap/prepare.py
@@ -28,7 +28,12 @@
)
from fast_llm.data.dataset.memmap import MemmapDataset
from fast_llm.data.preparator.config import DatasetPreparator
-from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig, LanguageModelSourceConfig
+from fast_llm.data.preparator.gpt_memmap.config import (
+ ConversationSourceConfig,
+ DocumentSourceConfig,
+ GPTMemmapDatasetPreparatorConfig,
+ LanguageModelSourceConfig,
+)
from fast_llm.data.preprocessing.abstract import NullPreprocessingConfig
from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig
from fast_llm.data.preprocessing.tokenizer import Tokenizer
@@ -132,6 +137,10 @@ def run(self) -> None:
# Load tokenizer
self._tokenizer = self._config.tokenizer.get_tokenizer()
+ # Validate chat template for conversation format
+ if isinstance(self._source_schema, ConversationSourceConfig):
+ self._tokenizer.validate_chat_template()
+
# Decide the datatype based on the tokenizer vocabulary size
self._data_type = (
get_unsigned_integer_type(self._tokenizer.vocab_size)
@@ -216,92 +225,112 @@ def _preprocessing_config(self) -> LanguageModelPreprocessingConfig:
)
def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample:
- text = sample[self._source_schema.text]
- all_spans = []
- if self._source_schema.has_loss_masking_span:
- # Spans are typically stored in the (begin, last) format. We convert to (begin, end) range format.
- loss_masking_spans = _sort_spans(
- (SpanType.loss_masking, (begin, last + 1))
- for begin, last in np.array(sample[self._source_schema.loss_masking_spans], dtype=np.int32)
- .reshape(-1, 2)
- .tolist()
+ token_spans_by_type = collections.defaultdict(list)
+ image_patches = image_token_maps = image_position_ids = patch_counts = None
+
+ if isinstance(self._source_schema, ConversationSourceConfig):
+ # Conversation format: tokenize messages and get loss masking spans from chat template
+ tokens, loss_masking_spans = self._tokenizer.tokenize_chat(
+ sample[self._source_schema.messages],
+ True,
+ True,
+ data_type=self._data_type,
)
- all_spans.extend(loss_masking_spans)
-
- if self._source_schema.has_preference_spans:
- full_chosen_text = text + sample[self._source_schema.chosen_span] + self._tokenizer.tokenizer.eos_token
- full_rejected_text = self._tokenizer.tokenizer.bos_token + text + sample[self._source_schema.rejected_span]
- # compute chosen span
- chosen_spans = [(SpanType.chosen, (len(text), len(full_chosen_text)))]
-
- # compute rejected span
- rejected_span = [
- (
- SpanType.rejected,
- (
- len(full_chosen_text) + len(self._tokenizer.tokenizer.bos_token) + len(text),
- len(full_chosen_text) + len(full_rejected_text),
- ),
+ token_spans_by_type[SpanType.loss_masking] = loss_masking_spans
+ elif isinstance(self._source_schema, DocumentSourceConfig):
+ # Document format: use the text-spans pipeline
+ text = sample[self._source_schema.text]
+ all_spans = []
+
+ if self._source_schema.has_loss_masking_span:
+ # Spans are typically stored in the (begin, last) format. We convert to (begin, end) range format.
+ loss_masking_spans = _sort_spans(
+ (SpanType.loss_masking, (begin, last + 1))
+ for begin, last in np.array(sample[self._source_schema.loss_masking_spans], dtype=np.int32)
+ .reshape(-1, 2)
+ .tolist()
+ )
+ all_spans.extend(loss_masking_spans)
+
+ if self._source_schema.has_preference_spans:
+ full_chosen_text = text + sample[self._source_schema.chosen_span] + self._tokenizer.tokenizer.eos_token
+ full_rejected_text = (
+ self._tokenizer.tokenizer.bos_token + text + sample[self._source_schema.rejected_span]
)
- ]
- # pack texts
- text = full_chosen_text + full_rejected_text
- all_spans.extend(chosen_spans + rejected_span)
-
- if self._source_schema.has_images:
- # Get the images and positions, sorted by position.
- images, image_positions = (
- zip(
- *sorted(
- zip(
- sample[self._source_schema.images],
- sample[self._source_schema.image_positions],
- strict=True,
+ # compute chosen span
+ chosen_spans = [(SpanType.chosen, (len(text), len(full_chosen_text)))]
+
+ # compute rejected span
+ rejected_span = [
+ (
+ SpanType.rejected,
+ (
+ len(full_chosen_text) + len(self._tokenizer.tokenizer.bos_token) + len(text),
+ len(full_chosen_text) + len(full_rejected_text),
),
- key=lambda x: x[1],
)
+ ]
+ # pack texts
+ text = full_chosen_text + full_rejected_text
+ all_spans.extend(chosen_spans + rejected_span)
+
+ if self._source_schema.has_images:
+ # Get the images and positions, sorted by position.
+ images, image_positions = (
+ zip(
+ *sorted(
+ zip(
+ sample[self._source_schema.images],
+ sample[self._source_schema.image_positions],
+ strict=True,
+ ),
+ key=lambda x: x[1],
+ )
+ )
+ if len(sample[self._source_schema.images]) > 0
+ else ([], [])
)
- if len(sample[self._source_schema.images]) > 0
- else ([], [])
- )
- # Get the image patches and associated data.
- image_patches, image_position_ids, image_token_maps, image_token_ids, patch_counts = (
- self._config.image_patches.get_patches_from_images(images, self._data_type)
+ # Get the image patches and associated data.
+ image_patches, image_position_ids, image_token_maps, image_token_ids, patch_counts = (
+ self._config.image_patches.get_patches_from_images(images, self._data_type)
+ )
+ patch_count_cumsum = padded_cumsum(patch_counts).tolist()
+ # Add an empty "span" at each image position so we know where to insert them in the tokenized sequence.
+ all_spans.extend([(SpanType.image, (position, position)) for position in image_positions])
+
+ # Sort the spans by location (begin), keeping track of their type.
+ # Note: overlapping spans are not supported (explicit assertion in the tokenizer).
+ span_types, spans = zip(*_sort_spans(all_spans)) if all_spans else ([], [])
+ # Tokenize the text, and determine the span locations in the tokenized text.
+ tokens, token_spans = self._tokenizer.tokenize_with_spans(
+ text, True, True, text_spans=spans, data_type=self._data_type
)
- patch_count_cumsum = padded_cumsum(patch_counts).tolist()
- # Add an empty "span" at each image position so we know where to insert them in the tokenized sequence.
- all_spans.extend([(SpanType.image, (position, position)) for position in image_positions])
-
- # Sort the spans by location (begin), keeping track of their type.
- # Note: overlapping spans are not supported (explicit assertion in the tokenizer).
- span_types, spans = zip(*_sort_spans(all_spans)) if all_spans else ([], [])
- # Tokenize the text, and determine the span locations in the tokenized text.
- tokens, token_spans = self._tokenizer.tokenize_with_spans(
- text, True, True, text_spans=spans, data_type=self._data_type
- )
- # Gather token spans by type.
- token_spans_by_type = collections.defaultdict(list)
- if self._source_schema.has_images:
- # Insert the image token ids in the token sequence and shift the spans accordingly.
- tokens_shift = 0
- image_index = 0
- for span_type, (begin, end) in zip(span_types, token_spans, strict=True):
- # Account for the tokens already inserted.
- begin = begin + tokens_shift
- end = end + tokens_shift
- if span_type == SpanType.image:
- # Shift the token map to the image location.
- image_token_maps[patch_count_cumsum[image_index] : patch_count_cumsum[image_index + 1]] += begin
- # Insert the placeholder and image break tokens.
- tokens = torch.cat([tokens[:begin], image_token_ids[image_index], tokens[begin:]])
- tokens_shift += len(image_token_ids[image_index])
- image_index += 1
- else:
- token_spans_by_type[span_type].append((begin, end))
+ # Gather token spans by type.
+ if self._source_schema.has_images:
+ # Insert the image token ids in the token sequence and shift the spans accordingly.
+ tokens_shift = 0
+ image_index = 0
+ for span_type, (begin, end) in zip(span_types, token_spans, strict=True):
+ # Account for the tokens already inserted.
+ begin = begin + tokens_shift
+ end = end + tokens_shift
+ if span_type == SpanType.image:
+ # Shift the token map to the image location.
+ image_token_maps[
+ patch_count_cumsum[image_index] : patch_count_cumsum[image_index + 1]
+ ] += begin
+ # Insert the placeholder and image break tokens.
+ tokens = torch.cat([tokens[:begin], image_token_ids[image_index], tokens[begin:]])
+ tokens_shift += len(image_token_ids[image_index])
+ image_index += 1
+ else:
+ token_spans_by_type[span_type].append((begin, end))
+ else:
+ for span_type, token_span in zip(span_types, token_spans, strict=True):
+ token_spans_by_type[span_type].append(token_span)
else:
- for span_type, token_span in zip(span_types, token_spans, strict=True):
- token_spans_by_type[span_type].append(token_span)
+ raise NotImplementedError(f"Unsupported source schema type: {type(self._source_schema)}")
sample_size = len(tokens)
diff --git a/fast_llm/data/preprocessing/tokenizer.py b/fast_llm/data/preprocessing/tokenizer.py
index abfb5b3d2..157744f51 100644
--- a/fast_llm/data/preprocessing/tokenizer.py
+++ b/fast_llm/data/preprocessing/tokenizer.py
@@ -213,3 +213,105 @@ def _remove_delimiters(
@property
def eod(self):
return self.eod_id
+
+ @staticmethod
+ def _has_generation_markers(template: str | None) -> bool:
+ """Check if a template has generation markers."""
+ return template is not None and "{% generation %}" in template
+
+ def validate_chat_template(self) -> None:
+ """
+ Validate the tokenizer's chat template has generation markers.
+
+ Raises:
+ ValueError: If the tokenizer lacks a chat template or generation markers.
+ """
+ template = self.tokenizer.chat_template
+
+ if template is None:
+ raise ValueError(
+ "Tokenizer does not have a chat template. "
+ "Conversation format requires a tokenizer with a built-in chat template "
+ "containing {% generation %}...{% endgeneration %} markers."
+ )
+
+ if not self._has_generation_markers(template):
+ raise ValueError(
+ "Tokenizer's chat template does not contain {% generation %}...{% endgeneration %} markers. "
+ "These markers are required to determine which tokens to train on. "
+ "Please use a tokenizer with generation markers in its chat template."
+ )
+
+ def tokenize_chat(
+ self,
+ messages: list[dict[str, str]],
+ begin: bool = True,
+ end: bool = True,
+ data_type: DataType = DataType.int64,
+ ) -> tuple["torch.Tensor", list[tuple[int, int]]]:
+ """
+ Apply chat template and return (tokens, loss_masking_spans).
+
+ The loss_masking_spans mark token ranges to EXCLUDE from training (where the model
+ should not learn). These are derived from the chat template's generation markers -
+ tokens outside {% generation %}...{% endgeneration %} blocks are masked.
+ """
+ import torch
+
+ result = self.tokenizer.apply_chat_template(
+ messages,
+ tokenize=True,
+ return_assistant_tokens_mask=True,
+ return_dict=True,
+ add_generation_prompt=False,
+ )
+ tokens = result["input_ids"]
+ train_mask = result["assistant_masks"]
+
+ # Prepend BOS / append EOS if not already present anywhere in the sequence.
+ # We check anywhere (not just first/last) because some chat templates add trailing
+ # whitespace after the final EOS token, e.g. "<|im_end|>\n".
+ prepend_bos = begin and self.bod_id not in tokens
+ append_eos = end and self.eod_id not in tokens
+ tokens = [self.bod_id] * prepend_bos + list(tokens) + [self.eod_id] * append_eos
+ train_mask = [False] * prepend_bos + [bool(m) for m in train_mask] + [False] * append_eos
+
+ # Convert boolean train mask to loss masking spans (spans where train_mask[i] == False)
+ loss_masking_spans = _train_mask_to_loss_spans(train_mask)
+
+ if self._config.max_vocab_size is not None:
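+            # Build the token tensor in a wide dtype when raw ids may exceed the target dtype's
+            # range, then fold them into [0, max_vocab_size) and cast down to the target dtype.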
+ tokens = (
+ torch.tensor(
+ tokens,
+ dtype=torch.int64 if len(self.tokenizer) > torch.iinfo(data_type.torch).max else data_type.torch,
+ )
+ % self._config.max_vocab_size
+ ).to(data_type.torch)
+ else:
+ tokens = torch.tensor(tokens, dtype=data_type.torch)
+ return tokens, loss_masking_spans
+
+
+def _train_mask_to_loss_spans(train_mask: list[bool]) -> list[tuple[int, int]]:
+ """
+ Convert a boolean train mask to loss masking spans.
+
+ Args:
+ train_mask: Boolean list where True = train on this token, False = don't train
+
+ Returns:
+ List of (begin, end) spans marking token ranges to EXCLUDE from training
+ (i.e., where train_mask[i] == False).
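+
+    Example:
+        >>> _train_mask_to_loss_spans([False, False, True, True, False])
+        [(0, 2), (4, 5)]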
+ """
+ spans = []
+ start = None
+ for i, should_train in enumerate(train_mask):
+ if not should_train:
+ if start is None:
+ start = i
+ elif start is not None:
+ spans.append((start, i))
+ start = None
+ if start is not None:
+ spans.append((start, len(train_mask)))
+ return spans
diff --git a/fast_llm/engine/evaluation/config.py b/fast_llm/engine/evaluation/config.py
index df7ab0f51..90881cdc1 100644
--- a/fast_llm/engine/evaluation/config.py
+++ b/fast_llm/engine/evaluation/config.py
@@ -8,6 +8,7 @@
if typing.TYPE_CHECKING:
from fast_llm.engine.evaluation.evaluator import Evaluator, EvaluatorLmEval, LossEvaluator
+ from fast_llm.engine.evaluation.forward_kl.evaluator import ForwardKLEvaluator
@config_class()
@@ -119,3 +120,58 @@ def get_evaluator(
from fast_llm.engine.evaluation.lm_eval.evaluator import LmEvalEvaluator
return LmEvalEvaluator(name, self, batch_config, data_load_num_proc, train_iters)
+
+
+@config_class(dynamic_type={EvaluatorConfig: "forward_kl"})
+class ForwardKLEvaluatorConfig(EvaluatorConfig):
+ _abstract: typing.ClassVar[bool] = False
+
+ dataset_path: str | None = Field(
+ default=None,
+ desc="HuggingFace dataset path containing teacher traces.",
+ hint=FieldHint.core,
+ )
+ split: str = Field(
+ default="validation",
+ desc="Dataset split to evaluate on. Use 'train+validation' syntax to combine multiple splits.",
+ hint=FieldHint.optional,
+ )
+ seed: int = Field(
+ default=42,
+ desc="Random seed for shuffling traces. Ensures reproducible evaluation across runs.",
+ hint=FieldHint.optional,
+ )
+ num_samples: int | None = Field(
+ default=None,
+ desc="Maximum number of traces to evaluate (after shuffling). None for all.",
+ hint=FieldHint.optional,
+ valid=skip_valid_if_none(check_field(Assert.gt, 0)),
+ )
+ batch_size: int = Field(
+ default=8,
+ desc="Batch size for forward passes.",
+ hint=FieldHint.performance,
+ valid=check_field(Assert.gt, 0),
+ )
+ trust_remote_code: bool = Field(
+ default=False,
+ desc="Trust remote code when loading dataset.",
+ hint=FieldHint.optional,
+ )
+ inference_mixer: str | None = Field(
+ default=None,
+ desc="Name of the mixer to use during evaluation (for StochasticMixer models). "
+ "If None, uses the model's default main_mixer_name.",
+ hint=FieldHint.optional,
+ )
+
+ def get_evaluator(
+ self,
+ name: str,
+ batch_config: BatchConfig,
+ data_load_num_proc: int,
+ train_iters: int | None = None,
+ ) -> "ForwardKLEvaluator":
+ from fast_llm.engine.evaluation.forward_kl.evaluator import ForwardKLEvaluator
+
+ return ForwardKLEvaluator(name, self, batch_config, data_load_num_proc, train_iters)
diff --git a/fast_llm/engine/evaluation/forward_kl/__init__.py b/fast_llm/engine/evaluation/forward_kl/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/fast_llm/engine/evaluation/forward_kl/evaluator.py b/fast_llm/engine/evaluation/forward_kl/evaluator.py
new file mode 100644
index 000000000..5e69862d2
--- /dev/null
+++ b/fast_llm/engine/evaluation/forward_kl/evaluator.py
@@ -0,0 +1,451 @@
+import dataclasses
+import gc
+import hashlib
+import logging
+import math
+
+import torch
+import torch.nn.functional as F
+import tqdm
+
+from fast_llm.config import NoAutoValidate
+from fast_llm.core.distributed import ReduceOp, allreduce_scalar, safe_barrier
+from fast_llm.data.data.abstract import Data
+from fast_llm.data.sample.language_model import LanguageModelBatch, LanguageModelSample
+from fast_llm.data.sample.token import TokenSample
+from fast_llm.engine.config_utils.run import Run
+from fast_llm.engine.distributed.config import PhaseType
+from fast_llm.engine.distributed.distributed import Distributed
+from fast_llm.engine.evaluation.config import ForwardKLEvaluatorConfig
+from fast_llm.engine.evaluation.evaluator import (
+ EvaluationMetrics,
+ Evaluator,
+ EvaluatorSamplingParameters,
+ TrainingProgress,
+)
+from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel
+from fast_llm.engine.schedule.runner import ScheduleRunner
+from fast_llm.layers.attention.config import AttentionKwargs
+from fast_llm.models.gpt.config import GPTBatchConfig
+from fast_llm.models.gpt.model import GPTInferenceRunner
+
+logger = logging.getLogger(__name__)
+
+
+@dataclasses.dataclass
+class TraceTensors:
+ tokens: torch.Tensor # (num_traces, sequence_length)
+ prompt_lens: torch.Tensor # (num_traces,)
+ completion_lens: torch.Tensor # (num_traces,)
+ problem_indices: torch.Tensor # (num_traces,)
+ teacher_log_probs: torch.Tensor # (num_traces,)
+ corrects: torch.Tensor # (num_traces,)
+ num_problems: int
+ num_skipped: int
+
+ def __len__(self) -> int:
+ return self.tokens.shape[0]
+
+ @classmethod
+ def empty(cls, sequence_length: int, device: torch.device, num_skipped: int = 0) -> "TraceTensors":
+ return cls(
+ tokens=torch.empty((0, sequence_length), dtype=torch.int64, device=device),
+ prompt_lens=torch.empty(0, dtype=torch.int64, device=device),
+ completion_lens=torch.empty(0, dtype=torch.int64, device=device),
+ problem_indices=torch.empty(0, dtype=torch.int64, device=device),
+ teacher_log_probs=torch.empty(0, dtype=torch.float64, device=device),
+ corrects=torch.empty(0, dtype=torch.bool, device=device),
+ num_problems=0,
+ num_skipped=num_skipped,
+ )
+
+ @classmethod
+ def from_traces(
+ cls,
+ traces: list[dict],
+ sequence_length: int,
+ device: torch.device,
+ ) -> "TraceTensors":
+ pid_to_idx: dict[str, int] = {}
+ valid_traces: list[tuple[list[int], list[int], str, float, bool]] = []
+ num_skipped = 0
+
+ for t in traces:
+ prompt, completion = t["prompt_tokens"], t["completion_tokens"]
+ if len(prompt) + len(completion) > sequence_length:
+ num_skipped += 1
+ continue
+ valid_traces.append((prompt, completion, t["problem_id"], t["teacher_log_prob"], t["correct"]))
+
+ if not valid_traces:
+ return cls.empty(sequence_length, device, num_skipped)
+
+ n = len(valid_traces)
+ tokens = torch.zeros((n, sequence_length), dtype=torch.int64, device=device)
+ prompt_lens = torch.empty(n, dtype=torch.int64, device=device)
+ completion_lens = torch.empty(n, dtype=torch.int64, device=device)
+ problem_indices = torch.empty(n, dtype=torch.int64, device=device)
+ teacher_log_probs = torch.empty(n, dtype=torch.float64, device=device)
+ corrects = torch.empty(n, dtype=torch.bool, device=device)
+
+ for i, (prompt, completion, pid, teacher_lp, correct) in enumerate(valid_traces):
+ seq = prompt + completion
+ tokens[i, : len(seq)] = torch.tensor(seq, dtype=torch.int64, device=device)
+ prompt_lens[i] = len(prompt)
+ completion_lens[i] = len(completion)
+
+ if pid not in pid_to_idx:
+ pid_to_idx[pid] = len(pid_to_idx)
+ problem_indices[i] = pid_to_idx[pid]
+ teacher_log_probs[i] = teacher_lp
+ corrects[i] = correct
+
+ return cls(
+ tokens=tokens,
+ prompt_lens=prompt_lens,
+ completion_lens=completion_lens,
+ problem_indices=problem_indices,
+ teacher_log_probs=teacher_log_probs,
+ corrects=corrects,
+ num_problems=len(pid_to_idx),
+ num_skipped=num_skipped,
+ )
+
+
+class ForwardKLEvaluator[ConfigType: ForwardKLEvaluatorConfig](Evaluator[ConfigType]):
+ """Shard by PROBLEM (not trace) so each rank gets complete problems.
+
+ This allows computing per-problem IS metrics locally, then reducing scalars.
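+
+    Per problem, the importance-sampled accuracy is the sum of w_i over correct traces divided
+    by the sum of w_i over all traces, with w_i = exp(student_log_prob - teacher_log_prob);
+    `_evaluate` computes this in log space for numerical stability.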
+ """
+
+ _inference_runner: GPTInferenceRunner
+ _sequence_length: int
+ _micro_sequence_length: int
+
+ def setup(
+ self,
+ distributed: Distributed,
+ run: Run,
+ multi_stage: FastLLMModel,
+ runner: ScheduleRunner,
+ data: Data,
+ phase: PhaseType,
+ ) -> None:
+ super().setup(distributed, run, multi_stage, runner, data, phase)
+ self._inference_runner = GPTInferenceRunner(self._multi_stage, runner=self._runner)
+ self._inference_runner.setup()
+ self._sequence_length = self._batch_config.sequence_length
+ self._micro_sequence_length = self._batch_config.micro_sequence_length
+ self._is_setup = True
+
+ def get_sampling_parameters(self) -> EvaluatorSamplingParameters | None:
+ return None
+
+ def run(
+ self,
+ training_progress: TrainingProgress | None = None,
+ run_index: int | None = None,
+ ) -> EvaluationMetrics:
+ assert self._is_setup
+ if self._config.dataset_path is None:
+ return EvaluationMetrics()
+
+ safe_barrier(self._distributed.world_group, f"forward_kl_{self._name} begin")
+ metrics = self._evaluate()
+ safe_barrier(self._distributed.world_group, f"forward_kl_{self._name} end")
+
+ if metrics["num_traces"] == 0:
+ return EvaluationMetrics()
+
+ formatted = (
+ f"IS Eval ({self._name}): "
+ f"acc={metrics['is_accuracy']:.4f}, "
+ f"ESS={metrics['mean_ess']:.2f}/{metrics['samples_per_problem']:.1f}, "
+ f"({metrics['num_problems']} problems, {metrics['num_traces']} traces)"
+ )
+ if metrics["num_skipped"] > 0:
+ formatted += f" [{metrics['num_skipped']} skipped]"
+
+ return EvaluationMetrics(
+ {f"validation.{self._name}": {k: v for k, v in metrics.items() if k != "num_skipped"}},
+ formatted,
+ )
+
+ @torch.inference_mode()
+ def _evaluate(self) -> dict[str, float]:
+ device = self._distributed.device
+ data = self._load_traces(device)
+
+ # Switch to eval mode so StochasticMixer uses the main mixer
+ # instead of randomly sampling.
+ was_training = self._multi_stage._training
+ self._multi_stage.train(False)
+
+ # Optionally override the inference mixer for StochasticMixer layers
+ stochastic_mixers: list = []
+ if self._config.inference_mixer is not None:
+ from fast_llm.layers.decoder.stochastic_mixer import StochasticMixer
+
+ for name, module in self._multi_stage.base_model.named_modules():
+ if isinstance(module, StochasticMixer):
+ stochastic_mixers.append(module)
+ module._inference_mixer_override = self._config.inference_mixer
+ logger.info(f"ForwardKL: Set {name} inference mixer to '{self._config.inference_mixer}'")
+
+ try:
+ batch_size = self._config.batch_size
+ student_log_probs_batches: list[torch.Tensor] = []
+ local_num_batches = math.ceil(len(data) / batch_size) if len(data) > 0 else 0
+
+ # Synchronize batch count across all world ranks.
+ # All ranks must execute the same number of forward passes because the forward
+ # pass involves collective operations (e.g., ZeRO all-gather) that require
+ # participation from all ranks in the process group.
+ max_num_batches = int(
+ allreduce_scalar(local_num_batches, torch.int64, self._distributed.world_group, ReduceOp.MAX)
+ )
+
+ if max_num_batches == 0:
+ return self._reduce_metrics(0.0, 0.0, 0, 0, data.num_skipped)
+
+ # Create dummy data for ranks that have no data or finish early.
+ # These ranks still need to participate in collective operations.
+ dummy_tokens = torch.zeros((batch_size, self._sequence_length), dtype=torch.int64, device=device)
+ dummy_prompt_lens = torch.ones(batch_size, dtype=torch.int64, device=device)
+ dummy_completion_lens = torch.ones(batch_size, dtype=torch.int64, device=device)
+
+ # Only show progress bar on rank 0
+ batch_iter = range(max_num_batches)
+ if self._distributed.config.rank == 0:
+ batch_iter = tqdm.tqdm(
+ batch_iter,
+ total=max_num_batches,
+ desc=f"ForwardKL ({self._name})",
+ unit="batch",
+ )
+
+ for batch_idx in batch_iter:
+ i = batch_idx * batch_size
+ if i < len(data):
+ # This rank has real data for this batch
+ batch_log_probs = self._compute_batch_log_probs(
+ data.tokens[i : i + batch_size],
+ data.prompt_lens[i : i + batch_size],
+ data.completion_lens[i : i + batch_size],
+ )
+ if batch_log_probs is not None:
+ student_log_probs_batches.append(batch_log_probs)
+ else:
+ # This rank has no more data but must still participate in collectives.
+ # Run a dummy forward pass and discard the result.
+ self._compute_batch_log_probs(dummy_tokens, dummy_prompt_lens, dummy_completion_lens)
+
+ if not student_log_probs_batches: # non-last PP rank or no local data
+ return self._reduce_metrics(0.0, 0.0, 0, 0, data.num_skipped)
+ finally:
+ # Clear inference mixer override for StochasticMixer layers
+ for module in stochastic_mixers:
+ module._inference_mixer_override = None
+
+ # Restore original training mode
+ if was_training:
+ self._multi_stage.train(True)
+
+ student_log_probs = torch.cat(student_log_probs_batches)
+ log_w = student_log_probs - data.teacher_log_probs
+
+ # Diagnostic logging with percentiles
+ pcts = torch.tensor([0.01, 0.05, 0.10, 0.25, 0.50, 0.75, 0.90, 0.95, 0.99], device=log_w.device)
+ pct_labels = ["1%", "5%", "10%", "25%", "50%", "75%", "90%", "95%", "99%"]
+
+ def fmt_percentiles(t: torch.Tensor) -> str:
+ q = torch.quantile(t.float(), pcts)
+ return ", ".join(f"{l}={v:.1f}" for l, v in zip(pct_labels, q.tolist()))
+
+ logger.info(f"student_log_probs: [{fmt_percentiles(student_log_probs)}]")
+ logger.info(f"teacher_log_probs: [{fmt_percentiles(data.teacher_log_probs)}]")
+ logger.info(f"log_w: [{fmt_percentiles(log_w)}]")
+
+ log_sum_all = self._scatter_logsumexp(log_w, data.problem_indices, data.num_problems)
+ log_w_correct = log_w.masked_fill(~data.corrects, float("-inf"))
+ log_sum_correct = self._scatter_logsumexp(log_w_correct, data.problem_indices, data.num_problems)
+
+ # IS accuracy; nan_to_num handles -inf - -inf
+ accuracy = (log_sum_correct - log_sum_all).exp().nan_to_num(0.0)
+
+ # ESS = exp(2*logsumexp(log_w) - logsumexp(2*log_w))
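+        # (equivalent to ESS = (sum_i w_i)^2 / sum_i w_i^2, evaluated in log space for stability)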
+ log_sum_sq = self._scatter_logsumexp(2 * log_w, data.problem_indices, data.num_problems)
+ ess = (2 * log_sum_all - log_sum_sq).exp().clamp(min=0.0)
+
+ # ESS diagnostics with percentiles
+ traces_per_problem = torch.bincount(data.problem_indices, minlength=data.num_problems)
+ multi_trace_mask = traces_per_problem > 1
+ if multi_trace_mask.any():
+ multi_ess = ess[multi_trace_mask]
+ multi_traces = traces_per_problem[multi_trace_mask]
+ logger.info(f"ESS ({multi_trace_mask.sum().item()} multi-trace problems): [{fmt_percentiles(multi_ess)}]")
+ logger.info(f"traces/problem: [{fmt_percentiles(multi_traces.float())}]")
+
+ return self._reduce_metrics(
+ accuracy.sum().item(),
+ ess.sum().item(),
+ data.num_problems,
+ len(data),
+ data.num_skipped,
+ )
+
+ def _load_traces(self, device: torch.device) -> TraceTensors:
+ import datasets
+
+ ds = datasets.load_dataset(
+ self._config.dataset_path,
+ split=self._config.split,
+ trust_remote_code=self._config.trust_remote_code,
+ )
+
+ # Shuffle needed because traces are sorted by problem
+ if self._config.num_samples and len(ds) > self._config.num_samples:
+ ds = ds.shuffle(seed=self._config.seed).select(range(self._config.num_samples))
+
+ dp_rank = self._distributed.config.data_rank
+ dp_size = self._distributed.config.data_parallel
+
+ def belongs_to_shard(example: dict) -> bool:
+ h = hashlib.md5(example["problem_id"].encode(), usedforsecurity=False).digest()
+ return int.from_bytes(h[:4], "little") % dp_size == dp_rank
+
+ ds = ds.filter(belongs_to_shard)
+ traces = list(ds)
+
+ del ds
+ gc.collect()
+
+ return TraceTensors.from_traces(traces, self._sequence_length, device)
+
+ def _compute_batch_log_probs(
+ self,
+ tokens: torch.Tensor,
+ prompt_lens: torch.Tensor,
+ completion_lens: torch.Tensor,
+ ) -> torch.Tensor | None:
+ batch_size = tokens.shape[0]
+ lm_batch = self._prepare_batch(tokens, prompt_lens, completion_lens)
+
+ with NoAutoValidate():
+ batch_config = GPTBatchConfig(
+ micro_batch_size=batch_size,
+ sequence_length=self._sequence_length,
+ micro_sequence_length=self._micro_sequence_length,
+ truncate_documents=False,
+ )
+ batch_config.setup(self._distributed.config)
+ batch_config.validate()
+
+ preprocessed_meta = self._multi_stage.base_model.preprocess_meta(batch_config, PhaseType.inference)
+ preprocessed = self._multi_stage.base_model.preprocess_batch(
+ lm_batch, preprocessed_meta, phase=PhaseType.inference, iteration=0
+ )
+
+ # Loop runs through micro-sequences; final kwargs has the logits
+ for input_, kwargs in preprocessed:
+ kwargs["global_logits"] = True
+ self._inference_runner.forward(input_, kwargs)
+
+ if "logits" not in kwargs: # non-last PP stage
+ return None
+
+ logits = kwargs["logits"]
+ if kwargs.get(AttentionKwargs.sequence_first, False):
+ logits = logits.transpose(0, 1)
+
+ device = logits.device
+ seq_len = logits.shape[1]
+
+ pred_logits = logits[:, :-1, :].contiguous()
+ targets = tokens[:, 1:].contiguous().to(device)
+
+ # Mask: completion predictions are at [prompt_len-1, prompt_len+completion_len-1)
+ mask = self._create_completion_mask(prompt_lens, completion_lens, seq_len - 1)
+
+ ce_loss = F.cross_entropy(
+ pred_logits.view(-1, pred_logits.size(-1)),
+ targets.view(-1),
+ reduction="none",
+ ).view(batch_size, seq_len - 1)
+
+ results = -(ce_loss * mask).sum(dim=1)
+
+ del logits, kwargs, preprocessed, lm_batch
+
+ return results.to(torch.float64)
+
+ def _prepare_batch(
+ self,
+ tokens: torch.Tensor,
+ prompt_lens: torch.Tensor,
+ completion_lens: torch.Tensor,
+ ) -> LanguageModelBatch:
+ samples = []
+ for i in range(tokens.shape[0]):
+ seq_len = int(prompt_lens[i].item()) + int(completion_lens[i].item())
+ sample = LanguageModelSample(TokenSample(tokens[i, :seq_len].cpu()))
+
+ pad_len = self._sequence_length - seq_len
+ if pad_len > 0:
+ sample = LanguageModelSample.from_documents([sample, sample.get_padding(pad_len)])
+
+ samples.append(sample)
+
+ return LanguageModelBatch.from_samples(samples)
+
+ def _create_completion_mask(
+ self,
+ prompt_lens: torch.Tensor,
+ completion_lens: torch.Tensor,
+ seq_len: int,
+ ) -> torch.Tensor:
+ device = prompt_lens.device
+ positions = torch.arange(seq_len, device=device)
+ start = (prompt_lens - 1).unsqueeze(1)
+ end = (prompt_lens + completion_lens - 1).unsqueeze(1)
+ return (positions >= start) & (positions < end)
+
+ def _reduce_metrics(
+ self, sum_accuracy: float, sum_ess: float, num_problems: int, num_traces: int, num_skipped: int
+ ) -> dict[str, float]:
+ group = self._distributed.world_group
+ sum_accuracy = allreduce_scalar(sum_accuracy, group=group)
+ sum_ess = allreduce_scalar(sum_ess, group=group)
+ num_problems = int(allreduce_scalar(num_problems, torch.int64, group=group))
+ num_traces = int(allreduce_scalar(num_traces, torch.int64, group=group))
+ num_skipped = int(allreduce_scalar(num_skipped, torch.int64, group=group))
+
+ if num_problems == 0:
+ return {
+ "is_accuracy": 0.0,
+ "mean_ess": 0.0,
+ "samples_per_problem": 0.0,
+ "num_traces": 0,
+ "num_problems": 0,
+ "num_skipped": num_skipped,
+ }
+
+ return {
+ "is_accuracy": sum_accuracy / num_problems,
+ "mean_ess": sum_ess / num_problems,
+ "samples_per_problem": num_traces / num_problems,
+ "num_traces": num_traces,
+ "num_problems": num_problems,
+ "num_skipped": num_skipped,
+ }
+
+ def _scatter_logsumexp(self, src: torch.Tensor, index: torch.Tensor, num_groups: int) -> torch.Tensor:
+ # Max per group for numerical stability
+ max_vals = torch.full((num_groups,), float("-inf"), device=src.device, dtype=src.dtype)
+ max_vals.scatter_reduce_(0, index, src, reduce="amax")
+
+ src_shifted = (src - max_vals[index]).exp()
+ sum_exp = torch.zeros(num_groups, device=src.device, dtype=src.dtype)
+ sum_exp.scatter_add_(0, index, src_shifted)
+
+ return max_vals + sum_exp.log()
diff --git a/fast_llm/layers/decoder/stochastic_mixer.py b/fast_llm/layers/decoder/stochastic_mixer.py
index 984f34b80..76b261a4e 100644
--- a/fast_llm/layers/decoder/stochastic_mixer.py
+++ b/fast_llm/layers/decoder/stochastic_mixer.py
@@ -106,7 +106,8 @@ def setup(self, distributed: Distributed) -> None:
def _sample_mixer_name(self, kwargs: dict[str, typing.Any]) -> str:
if not self.training:
- return self._config.main_mixer_name
+ # Allow runtime override of the inference mixer (e.g., for evaluation)
+ return getattr(self, "_inference_mixer_override", None) or self._config.main_mixer_name
generator = kwargs[StochasticMixerKwargs.generator]
mixer_idx = torch.multinomial(self._sampling_probs, num_samples=1, generator=generator).item()
diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py
index b1d0c2acd..94ddbded9 100644
--- a/fast_llm/layers/language_model/head.py
+++ b/fast_llm/layers/language_model/head.py
@@ -250,16 +250,10 @@ def _logits_cross_entropy_forward_backward_split(
input_, targets, weight, grad_output, kwargs, losses
)
if targets is None:
- # TODO: Make a proper way of returning the model output.
- loss = loss.detach()
- if kwargs.get("global_logits"):
- if self._vocab_parallel:
- loss = gather_op(loss, self._parallel_dim.group, 2)
- elif self._sequence_parallel_logits:
- loss = gather_op(
- loss, self._parallel_dim.group, 0 if kwargs[LanguageModelKwargs.sequence_first] else 1
- )
- kwargs["logits" if self._prediction_distance == 0 else f"logits_{self._prediction_distance}"] = loss
+            # With global_logits, the raw logits were already stored and gathered in the inner function.
+            # Otherwise, store the scaled logits for backward compatibility with distillation.
+            if not kwargs.get("global_logits"):
+                kwargs["logits" if self._prediction_distance == 0 else f"logits_{self._prediction_distance}"] = loss.detach()
return None, None
else:
loss = None
@@ -342,6 +336,17 @@ def _logits_cross_entropy_forward_backward(
dims = None
self._debug(logits, "logits", dims, kwargs, scale=self._config.logits_scale_factor)
+ if kwargs.get("global_logits"):
+ logits_for_storage = logits.detach()
+ if self._vocab_parallel:
+ logits_for_storage = gather_op(logits_for_storage, self._parallel_dim.group, 2)
+ elif self._sequence_parallel_logits:
+ logits_for_storage = gather_op(
+ logits_for_storage, self._parallel_dim.group, 0 if kwargs[LanguageModelKwargs.sequence_first] else 1
+ )
+ logits_key = "logits" if self._prediction_distance == 0 else f"logits_{self._prediction_distance}"
+ kwargs[logits_key] = logits_for_storage
+
if targets is None:
return logits * self._config.logits_scale_factor, None
dpo_target, lm_target, distillation_target, loss_mask = targets
diff --git a/fast_llm/models/gpt/conversion/apriel2.py b/fast_llm/models/gpt/conversion/apriel2.py
index 4ed588ed5..91e3be508 100644
--- a/fast_llm/models/gpt/conversion/apriel2.py
+++ b/fast_llm/models/gpt/conversion/apriel2.py
@@ -39,8 +39,20 @@ def import_config(cls, config: dict) -> dict:
"head_groups": config["head_groups"],
"head_size": config["head_size"],
"rotary": rotary,
- "add_linear_biases": config["add_linear_biases"],
}
+ # Per-layer bias configuration mirroring Fast-LLM structure
+ # If per-layer configs exist, use them; otherwise fall back to add_linear_biases
+ if "query_layer" in config:
+ result["query_layer"] = config["query_layer"]
+ if "key_layer" in config:
+ result["key_layer"] = config["key_layer"]
+ if "value_layer" in config:
+ result["value_layer"] = config["value_layer"]
+ if "dense_layer" in config:
+ result["dense_layer"] = config["dense_layer"]
+ # add_linear_biases serves as default for layers without explicit config
+ if "add_linear_biases" in config:
+ result["add_linear_biases"] = config["add_linear_biases"]
if "window_size" in config:
result["window_size"] = config["window_size"]
return result
@@ -58,18 +70,37 @@ def export_config(cls, config: AttentionConfig) -> dict:
else:
raise NotImplementedError(f"Unsupported rotary type: {type(config.rotary).__name__}")
- return {
+ result = {
"type": "attention",
"heads": config.heads,
"head_groups": config.head_groups,
"head_size": config.head_size,
- "add_linear_biases": config.add_linear_biases,
"rotary": {
"type": rotary_type,
"theta": config.rotary.theta,
},
"window_size": config.window_size,
}
+ # Export per-layer bias configuration
+ # Only include if explicitly set (not None)
+ if config.query_layer.bias.enabled is not None:
+ result["query_layer"] = {"bias": {"enabled": config.query_layer.bias.enabled}}
+ if config.key_layer.bias.enabled is not None:
+ result["key_layer"] = {"bias": {"enabled": config.key_layer.bias.enabled}}
+ if config.value_layer.bias.enabled is not None:
+ result["value_layer"] = {"bias": {"enabled": config.value_layer.bias.enabled}}
+ if config.dense_layer.bias.enabled is not None:
+ result["dense_layer"] = {"bias": {"enabled": config.dense_layer.bias.enabled}}
+ # add_linear_biases as fallback default
+ result["add_linear_biases"] = config.add_linear_biases
+ return result
+
+ @classmethod
+ def _get_effective_bias(cls, layer_config, default: bool) -> bool:
+ """Get effective bias setting: use layer-specific if set, else default."""
+ if layer_config.bias.enabled is not None:
+ return layer_config.bias.enabled
+ return default
@classmethod
def get_converters(
@@ -79,11 +110,20 @@ def get_converters(
hf_prefix: str,
drop_on_export: bool = False,
) -> list[WeightConverter]:
+ # Determine effective bias for each projection
+ q_bias = cls._get_effective_bias(config.query_layer, config.add_linear_biases)
+ k_bias = cls._get_effective_bias(config.key_layer, config.add_linear_biases)
+ v_bias = cls._get_effective_bias(config.value_layer, config.add_linear_biases)
+ o_bias = cls._get_effective_bias(config.dense_layer, config.add_linear_biases)
+        # k and v are combined into a single key_value layer in Fast-LLM, which holds one bias
+        # setting, so a bias is used only when both k_proj and v_proj enable one.
+        kv_bias = k_bias and v_bias
+
return [
*get_weight_and_bias_converters(
f"{fast_llm_prefix}.query",
f"{hf_prefix}.q_proj",
- config.add_linear_biases,
+ q_bias,
QueryWeightConverter,
config,
drop_on_export=drop_on_export,
@@ -91,7 +131,7 @@ def get_converters(
*get_weight_and_bias_converters(
f"{fast_llm_prefix}.key_value",
(f"{hf_prefix}.k_proj", f"{hf_prefix}.v_proj"),
- config.add_linear_biases,
+ kv_bias,
KeyValueWeightConverter,
config,
drop_on_export=drop_on_export,
@@ -99,7 +139,7 @@ def get_converters(
*get_weight_and_bias_converters(
f"{fast_llm_prefix}.dense",
f"{hf_prefix}.o_proj",
- config.add_linear_biases,
+ o_bias,
drop_on_export=drop_on_export,
),
]
@@ -524,6 +564,12 @@ def import_config(cls, config: dict, block_config: dict) -> dict:
"gated": mlp_config["gated"],
"add_linear_biases": mlp_config["add_linear_biases"],
}
+ # Import per-layer MLP bias settings (layer_1, layer_2)
+ for layer_name in ("layer_1", "layer_2"):
+ if layer_name in mlp_config:
+ layer_cfg = mlp_config[layer_name]
+ if "bias" in layer_cfg:
+ mlp[layer_name] = {"bias": layer_cfg["bias"]}
normalization = block_config["normalization"]
@@ -578,6 +624,11 @@ def export_config(cls, config: DecoderBlockConfig) -> dict:
"gated": config.mlp.gated,
"add_linear_biases": config.mlp.add_linear_biases,
}
+ # Export per-layer MLP bias settings (layer_1, layer_2)
+ if config.mlp.layer_1.bias.enabled is not None:
+ mlp["layer_1"] = {"bias": {"enabled": config.mlp.layer_1.bias.enabled}}
+ if config.mlp.layer_2.bias.enabled is not None:
+ mlp["layer_2"] = {"bias": {"enabled": config.mlp.layer_2.bias.enabled}}
normalization = {"type": norm_type_str, "epsilon": config.normalization.epsilon}
@@ -624,24 +675,56 @@ def get_converters(
)
)
- converters.extend(
- [
- *get_weight_and_bias_converters(
- f"{fast_llm_prefix}.mlp.layer_1",
- (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"),
- config.mlp.add_linear_biases,
- SplitWeightConverter,
- drop_on_export=drop_on_export,
- ),
- *get_weight_and_bias_converters(
- f"{fast_llm_prefix}.mlp.layer_2",
- f"{hf_prefix}.mlp.down_proj",
- config.mlp.add_linear_biases,
- MLPLayer2Converter,
- drop_on_export=drop_on_export,
- ),
- ]
- )
+ # Per-layer MLP bias: use layer-specific setting if set, else default
+ def get_mlp_layer_bias(layer_config, default: bool) -> bool:
+ if layer_config.bias.enabled is not None:
+ return layer_config.bias.enabled
+ return default
+
+ layer_1_bias = get_mlp_layer_bias(config.mlp.layer_1, config.mlp.add_linear_biases)
+ layer_2_bias = get_mlp_layer_bias(config.mlp.layer_2, config.mlp.add_linear_biases)
+
+ if config.mlp.gated:
+ # Gated MLP: gate_proj + up_proj -> layer_1 (split), down_proj -> layer_2
+ converters.extend(
+ [
+ *get_weight_and_bias_converters(
+ f"{fast_llm_prefix}.mlp.layer_1",
+ (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"),
+ layer_1_bias,
+ SplitWeightConverter,
+ drop_on_export=drop_on_export,
+ ),
+ *get_weight_and_bias_converters(
+ f"{fast_llm_prefix}.mlp.layer_2",
+ f"{hf_prefix}.mlp.down_proj",
+ layer_2_bias,
+ MLPLayer2Converter,
+ drop_on_export=drop_on_export,
+ ),
+ ]
+ )
+ else:
+ # Non-gated MLP: up_proj -> layer_1, down_proj -> layer_2
+ # Note: layer_2 still needs MLPLayer2Converter for the transpose
+ converters.extend(
+ [
+ *get_weight_and_bias_converters(
+ f"{fast_llm_prefix}.mlp.layer_1",
+ f"{hf_prefix}.mlp.up_proj",
+ layer_1_bias,
+ WeightConverter,
+ drop_on_export=drop_on_export,
+ ),
+ *get_weight_and_bias_converters(
+ f"{fast_llm_prefix}.mlp.layer_2",
+ f"{hf_prefix}.mlp.down_proj",
+ layer_2_bias,
+ MLPLayer2Converter,
+ drop_on_export=drop_on_export,
+ ),
+ ]
+ )
converters.extend(
[
diff --git a/fast_llm/models/gpt/conversion/qwen2.py b/fast_llm/models/gpt/conversion/qwen2.py
index a8bc33454..4ebf18c3a 100644
--- a/fast_llm/models/gpt/conversion/qwen2.py
+++ b/fast_llm/models/gpt/conversion/qwen2.py
@@ -1,15 +1,21 @@
import typing
from fast_llm.engine.checkpoint.config import CheckpointFormat
+from fast_llm.engine.checkpoint.external import WeightConverter
from fast_llm.layers.attention.config import AttentionConfig
+from fast_llm.layers.decoder.mlp.config import MLPConfig
from fast_llm.models.gpt.conversion.config import Qwen2CheckpointFormat
from fast_llm.models.gpt.conversion.llama import (
+ KeyValueWeightConverter,
LlamaAttentionConverter,
LlamaBaseModelConverter,
LlamaBlockConverter,
LlamaDecoderConverter,
LlamaHeadConverter,
LlamaHuggingfaceCheckpointHandler,
+ LlamaMLPConverter,
+ QueryWeightConverter,
+ get_weight_and_bias_converters,
)
from fast_llm.utils import Assert
@@ -17,6 +23,22 @@
class Qwen2AttentionConverter(LlamaAttentionConverter):
# TODO: Support sliding window with max_window_layers (need 2 kinds of block?)
+ @classmethod
+ def import_config(cls, config: dict) -> dict:
+ config["attention_bias"] = True
+ out = super().import_config(config)
+ out["query_layer"] = {"bias": {"enabled": True}}
+ out["key_layer"] = {"bias": {"enabled": True}}
+ out["value_layer"] = {"bias": {"enabled": True}}
+ out["dense_layer"] = {"bias": {"enabled": False}}
+ return out
+
+ @classmethod
+ def export_config(cls, config: AttentionConfig) -> dict:
+ out = super().export_config(config)
+ del out["attention_bias"]
+ return out
+
@classmethod
def _check_config(cls, config: AttentionConfig) -> None:
Assert.is_(type(config), AttentionConfig)
@@ -32,9 +54,56 @@ def _check_config(cls, config: AttentionConfig) -> None:
Assert.is_(config.value_layer.bias.enabled, True)
Assert.incl(config.dense_layer.bias.enabled, (None, False))
+ @classmethod
+ def get_converters(
+ cls,
+ config: AttentionConfig,
+ fast_llm_prefix: str,
+ hf_prefix: str,
+ drop_on_export: bool = False,
+ ) -> list[WeightConverter]:
+ return [
+ *get_weight_and_bias_converters(
+ f"{fast_llm_prefix}.query",
+ f"{hf_prefix}.q_proj",
+ True,
+ QueryWeightConverter,
+ config,
+ drop_on_export=drop_on_export,
+ ),
+ *get_weight_and_bias_converters(
+ f"{fast_llm_prefix}.key_value",
+ (f"{hf_prefix}.k_proj", f"{hf_prefix}.v_proj"),
+ True,
+ KeyValueWeightConverter,
+ config,
+ drop_on_export=drop_on_export,
+ ),
+ *get_weight_and_bias_converters(
+ f"{fast_llm_prefix}.dense",
+ f"{hf_prefix}.o_proj",
+ False,
+ drop_on_export=drop_on_export,
+ ),
+ ]
+
+
+class Qwen2MLPConverter(LlamaMLPConverter):
+ @classmethod
+ def import_config(cls, config: dict) -> dict:
+ config["mlp_bias"] = False
+ return super().import_config(config)
+
+ @classmethod
+ def export_config(cls, config: MLPConfig) -> dict:
+ out = super().export_config(config)
+ del out["mlp_bias"]
+ return out
+
class Qwen2BlockConverter(LlamaBlockConverter):
mixer_converter_class: typing.ClassVar[type[Qwen2AttentionConverter]] = Qwen2AttentionConverter
+ mlp_converter_class: typing.ClassVar[type[Qwen2MLPConverter]] = Qwen2MLPConverter
class Qwen2DecoderConverter(LlamaDecoderConverter):
diff --git a/fast_llm/models/multimodal/conversion/apriel2.py b/fast_llm/models/multimodal/conversion/apriel2.py
index b4147a8bf..307a67c63 100644
--- a/fast_llm/models/multimodal/conversion/apriel2.py
+++ b/fast_llm/models/multimodal/conversion/apriel2.py
@@ -326,9 +326,7 @@ class Apriel2MultimodalBaseModelConverter:
@classmethod
def import_config(cls, config: dict) -> dict:
text_config = Apriel2BaseModelConverter.import_config(config)
- vision_config = (
- cls.vision_model_converter_class.import_config(config) if "vision_encoder" in config else None
- )
+ vision_config = cls.vision_model_converter_class.import_config(config) if "vision_encoder" in config else None
result = safe_merge_dicts(
text_config,
@@ -388,10 +386,7 @@ def get_transformers_configuration_class(cls):
@classmethod
def get_model_files(cls) -> tuple[str, str, str | None]:
- from fast_llm_external_models.apriel2 import (
- configuration_apriel2,
- modeling_apriel2,
- )
+ from fast_llm_external_models.apriel2 import configuration_apriel2, modeling_apriel2
return configuration_apriel2.__file__, modeling_apriel2.__file__, None
diff --git a/fast_llm_external_models/apriel2/cache.py b/fast_llm_external_models/apriel2/cache.py
index 86c67a085..f83ae87d6 100644
--- a/fast_llm_external_models/apriel2/cache.py
+++ b/fast_llm_external_models/apriel2/cache.py
@@ -1,17 +1,22 @@
from __future__ import annotations
+
import torch
from transformers.cache_utils import Cache
class _AttentionCache:
- __slots__ = ["key", "value", "window"]
+ __slots__ = ["key", "value", "window", "cumulative_length"]
def __init__(self, window=None):
self.key = None
self.value = None
self.window = window
+ self.cumulative_length = 0
def update(self, key, value):
+ new_tokens = key.shape[-2]
+ self.cumulative_length += new_tokens
+
if self.key is None:
if self.window and key.shape[-2] > self.window:
self.key = key[..., -self.window :, :].contiguous()
@@ -35,6 +40,40 @@ def _window(self, cache, new):
return cache
return torch.cat([cache, new], -2)[..., -self.window :, :].contiguous()
+ def reset(self):
+ self.key = None
+ self.value = None
+ self.cumulative_length = 0
+
+ def reorder(self, beam_idx):
+ if self.key is not None:
+ self.key = self.key.index_select(0, beam_idx.to(self.key.device))
+ self.value = self.value.index_select(0, beam_idx.to(self.value.device))
+
+ def crop(self, max_length):
+ if self.key is not None:
+ self.key = self.key[..., :max_length, :]
+ self.value = self.value[..., :max_length, :]
+ self.cumulative_length = self.key.shape[-2]
+
+ def batch_repeat(self, repeats):
+ if self.key is not None:
+ self.key = self.key.repeat_interleave(repeats, dim=0)
+ self.value = self.value.repeat_interleave(repeats, dim=0)
+
+ def batch_select(self, indices):
+ if self.key is not None:
+ self.key = self.key.index_select(0, indices.to(self.key.device))
+ self.value = self.value.index_select(0, indices.to(self.value.device))
+
+ @property
+ def is_initialized(self):
+ return self.key is not None
+
+ @property
+ def batch_size(self):
+ return self.key.shape[0] if self.key is not None else None
+
class _SSMCache:
__slots__ = ["conv", "recurrent"]
@@ -43,6 +82,52 @@ def __init__(self):
self.conv = None
self.recurrent = None
+ def reset(self):
+ self.conv = None
+ self.recurrent = None
+
+ def reorder(self, beam_idx):
+ if self.conv is not None:
+ if isinstance(self.conv, tuple):
+ self.conv = tuple(c.index_select(0, beam_idx.to(c.device)) for c in self.conv)
+ else:
+ self.conv = self.conv.index_select(0, beam_idx.to(self.conv.device))
+ if self.recurrent is not None:
+ self.recurrent = self.recurrent.index_select(0, beam_idx.to(self.recurrent.device))
+
+ def crop(self, max_length):
+ pass # SSM caches don't have sequence dimension to crop
+
+ def batch_repeat(self, repeats):
+ if self.conv is not None:
+ if isinstance(self.conv, tuple):
+ self.conv = tuple(c.repeat_interleave(repeats, dim=0) for c in self.conv)
+ else:
+ self.conv = self.conv.repeat_interleave(repeats, dim=0)
+ if self.recurrent is not None:
+ self.recurrent = self.recurrent.repeat_interleave(repeats, dim=0)
+
+ def batch_select(self, indices):
+ if self.conv is not None:
+ if isinstance(self.conv, tuple):
+ self.conv = tuple(c.index_select(0, indices.to(c.device)) for c in self.conv)
+ else:
+ self.conv = self.conv.index_select(0, indices.to(self.conv.device))
+ if self.recurrent is not None:
+ self.recurrent = self.recurrent.index_select(0, indices.to(self.recurrent.device))
+
+ @property
+ def is_initialized(self):
+ return self.conv is not None
+
+ @property
+ def batch_size(self):
+ if self.conv is None:
+ return None
+ if isinstance(self.conv, tuple):
+ return self.conv[0].shape[0]
+ return self.conv.shape[0]
+
class _DummyCacheLayer:
pass
@@ -93,14 +178,19 @@ def set_active_mixer(self, layer_idx, mixer_name):
self.active_mixers[layer_idx] = mixer_name
def get_seq_length(self, layer_idx=0):
+ """Returns the cumulative sequence length of tokens seen by the cache.
+
+ For sliding window caches, this returns the total tokens seen (not just cached).
+ This matches HuggingFace's DynamicSlidingWindowLayer behavior.
+ """
layer = self.layers[layer_idx]
if isinstance(layer, dict):
mixer = self.active_mixers[layer_idx]
if mixer and isinstance(layer[mixer], _AttentionCache):
- return layer[mixer].key.shape[-2] if layer[mixer].key is not None else 0
+ return layer[mixer].cumulative_length
return 0
if isinstance(layer, _AttentionCache):
- return layer.key.shape[-2] if layer.key is not None else 0
+ return layer.cumulative_length
return 0
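For reference, a minimal sketch of the new bookkeeping (illustrative only; it assumes `_AttentionCache` is used standalone and the tensor shapes are arbitrary)::

    import torch

    cache = _AttentionCache(window=4)
    key = torch.randn(1, 8, 10, 64)    # (batch, heads, 10 new tokens, head_dim)
    value = torch.randn(1, 8, 10, 64)
    cache.update(key, value)

    # Only the last `window` tokens are stored, but all 10 count as "seen":
    assert cache.key.shape[-2] == 4
    assert cache.cumulative_length == 10
    # get_seq_length() therefore reports 10 (total seen), not 4 (currently cached).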
def get_max_cache_shape(self, layer_idx=0):
@@ -114,22 +204,61 @@ def get_max_cache_shape(self, layer_idx=0):
return None
def get_mask_sizes(self, cache_position, layer_idx):
+ """Return the length and offset of the cache, used to generate the attention mask.
+
+ For standard (non-sliding) attention:
+ kv_offset = 0 (KV[0] corresponds to sequence position 0)
+ kv_length = cumulative_length + query_length
+
+ For sliding window attention:
+ kv_offset = max(cumulative_length - window + 1, 0)
+ kv_length = min(cumulative_length, window - 1) + query_length
+
+ For SSM/linear layers:
+ kv_offset = 0, kv_length = query_length (no KV cache to attend to)
+ """
query_length = cache_position.shape[0]
- past_seen_tokens = self.get_seq_length(layer_idx)
- kv_length = query_length + past_seen_tokens
- kv_offset = past_seen_tokens
- return kv_length, kv_offset
+ layer = self.layers[layer_idx]
+
+ # Handle stochastic layers by getting the active mixer's cache
+ if isinstance(layer, dict):
+ mixer = self.active_mixers[layer_idx]
+ if mixer is None:
+ # No active mixer set, return defaults
+ return query_length, 0
+ cache = layer[mixer]
+ else:
+ cache = layer
+
+ # SSM layers don't have KV cache for attention mask purposes
+ if isinstance(cache, _SSMCache):
+ return query_length, 0
+
+ # Attention cache - check if sliding window
+ if isinstance(cache, _AttentionCache):
+ cumulative = cache.cumulative_length
+ window = cache.window
+
+ if window is not None:
+ # Sliding window attention
+ kv_offset = max(cumulative - window + 1, 0)
+ if cumulative >= window:
+ kv_length = window - 1 + query_length
+ else:
+ kv_length = cumulative + query_length
+ else:
+ # Full attention
+ kv_offset = 0
+ kv_length = cumulative + query_length
+
+ return kv_length, kv_offset
+
+ # Fallback
+ return query_length, 0
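A quick numeric sanity check of the sliding-window formulas above (the values are made up)::

    window, cumulative, query_length = 4, 10, 1   # decoding 1 token after 10 seen
    kv_offset = max(cumulative - window + 1, 0)                                       # 7
    kv_length = (window - 1 if cumulative >= window else cumulative) + query_length   # 4
    assert (kv_length, kv_offset) == (4, 7)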
@property
def has_previous_state(self):
- for layer in self.layers:
- if isinstance(layer, dict):
- for cache in layer.values():
- if isinstance(cache, _SSMCache) and cache.conv is not None:
- return True
- elif isinstance(layer, _SSMCache) and layer.conv is not None:
- return True
- return False
+ return any(isinstance(cache, _SSMCache) and cache.conv is not None for cache in self._iter_caches())
@property
def key_cache(self):
@@ -147,101 +276,33 @@ def conv_states(self):
def recurrent_states(self):
return _LayerListAccessor(self, "recurrent")
- def reorder_cache(self, beam_idx):
- for i, layer in enumerate(self.layers):
+ def _iter_caches(self):
+ """Iterate over all leaf cache objects (flattening stochastic layer dicts)."""
+ for layer in self.layers:
if isinstance(layer, dict):
- for cache in layer.values():
- self._reorder_cache_obj(cache, beam_idx)
+ yield from layer.values()
else:
- self._reorder_cache_obj(layer, beam_idx)
+ yield layer
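The helper is what lets the bulk operations below collapse to one-line loops; a rough sketch of the flattening, with a hypothetical layer layout::

    # self.layers = [
    #     _AttentionCache(),                                      # plain layer
    #     {"attention": _AttentionCache(), "gdn": _SSMCache()},   # stochastic layer
    # ]
    # _iter_caches() yields three leaf caches: the plain attention cache plus both
    # sub-caches of the stochastic layer, so reset/reorder/crop/... simply delegate
    # to each leaf's own method instead of re-implementing the dict dispatch.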
- def _reorder_cache_obj(self, cache, beam_idx):
- if isinstance(cache, _AttentionCache):
- if cache.key is not None:
- cache.key = cache.key.index_select(0, beam_idx.to(cache.key.device))
- cache.value = cache.value.index_select(0, beam_idx.to(cache.value.device))
- elif isinstance(cache, _SSMCache):
- if cache.conv is not None:
- # Handle both single tensor (GDN/Mamba) and tuple (KDA) conv states
- if isinstance(cache.conv, tuple):
- cache.conv = tuple(c.index_select(0, beam_idx.to(c.device)) for c in cache.conv)
- else:
- cache.conv = cache.conv.index_select(0, beam_idx.to(cache.conv.device))
- if cache.recurrent is not None:
- cache.recurrent = cache.recurrent.index_select(0, beam_idx.to(cache.recurrent.device))
+ def reorder_cache(self, beam_idx):
+ for cache in self._iter_caches():
+ cache.reorder(beam_idx)
def reset(self):
- for layer in self.layers:
- if isinstance(layer, dict):
- for cache in layer.values():
- self._reset_cache_obj(cache)
- else:
- self._reset_cache_obj(layer)
-
- def _reset_cache_obj(self, cache):
- if isinstance(cache, _AttentionCache):
- cache.key = None
- cache.value = None
- elif isinstance(cache, _SSMCache):
- cache.conv = None
- cache.recurrent = None
+ for cache in self._iter_caches():
+ cache.reset()
def crop(self, max_length):
- for layer in self.layers:
- if isinstance(layer, dict):
- for cache in layer.values():
- if isinstance(cache, _AttentionCache) and cache.key is not None:
- cache.key = cache.key[..., :max_length, :]
- cache.value = cache.value[..., :max_length, :]
- elif isinstance(layer, _AttentionCache) and layer.key is not None:
- layer.key = layer.key[..., :max_length, :]
- layer.value = layer.value[..., :max_length, :]
+ for cache in self._iter_caches():
+ cache.crop(max_length)
def batch_repeat_interleave(self, repeats):
- for layer in self.layers:
- if isinstance(layer, dict):
- for cache in layer.values():
- self._batch_repeat_cache_obj(cache, repeats)
- else:
- self._batch_repeat_cache_obj(layer, repeats)
-
- def _batch_repeat_cache_obj(self, cache, repeats):
- if isinstance(cache, _AttentionCache):
- if cache.key is not None:
- cache.key = cache.key.repeat_interleave(repeats, dim=0)
- cache.value = cache.value.repeat_interleave(repeats, dim=0)
- elif isinstance(cache, _SSMCache):
- if cache.conv is not None:
- # Handle both single tensor (GDN/Mamba) and tuple (KDA) conv states
- if isinstance(cache.conv, tuple):
- cache.conv = tuple(c.repeat_interleave(repeats, dim=0) for c in cache.conv)
- else:
- cache.conv = cache.conv.repeat_interleave(repeats, dim=0)
- if cache.recurrent is not None:
- cache.recurrent = cache.recurrent.repeat_interleave(repeats, dim=0)
+ for cache in self._iter_caches():
+ cache.batch_repeat(repeats)
def batch_select_indices(self, indices):
- for layer in self.layers:
- if isinstance(layer, dict):
- for cache in layer.values():
- self._batch_select_cache_obj(cache, indices)
- else:
- self._batch_select_cache_obj(layer, indices)
-
- def _batch_select_cache_obj(self, cache, indices):
- if isinstance(cache, _AttentionCache):
- if cache.key is not None:
- cache.key = cache.key.index_select(0, indices.to(cache.key.device))
- cache.value = cache.value.index_select(0, indices.to(cache.value.device))
- elif isinstance(cache, _SSMCache):
- if cache.conv is not None:
- # Handle both single tensor (GDN/Mamba) and tuple (KDA) conv states
- if isinstance(cache.conv, tuple):
- cache.conv = tuple(c.index_select(0, indices.to(c.device)) for c in cache.conv)
- else:
- cache.conv = cache.conv.index_select(0, indices.to(cache.conv.device))
- if cache.recurrent is not None:
- cache.recurrent = cache.recurrent.index_select(0, indices.to(cache.recurrent.device))
+ for cache in self._iter_caches():
+ cache.batch_select(indices)
@property
def is_compileable(self):
@@ -249,19 +310,7 @@ def is_compileable(self):
@property
def is_initialized(self):
- for layer in self.layers:
- if isinstance(layer, dict):
- for cache in layer.values():
- if isinstance(cache, _AttentionCache) and cache.key is not None:
- return True
- if isinstance(cache, _SSMCache) and cache.conv is not None:
- return True
- else:
- if isinstance(layer, _AttentionCache) and layer.key is not None:
- return True
- if isinstance(layer, _SSMCache) and layer.conv is not None:
- return True
- return False
+ return any(cache.is_initialized for cache in self._iter_caches())
@property
def is_sliding(self):
@@ -280,39 +329,20 @@ def is_sliding(self):
@property
def max_batch_size(self):
- for layer in self.layers:
- if isinstance(layer, dict):
- for cache in layer.values():
- if isinstance(cache, _AttentionCache) and cache.key is not None:
- return cache.key.shape[0]
- if isinstance(cache, _SSMCache) and cache.conv is not None:
- # Handle both single tensor and tuple conv states
- if isinstance(cache.conv, tuple):
- return cache.conv[0].shape[0]
- return cache.conv.shape[0]
- else:
- if isinstance(layer, _AttentionCache) and layer.key is not None:
- return layer.key.shape[0]
- if isinstance(layer, _SSMCache) and layer.conv is not None:
- # Handle both single tensor and tuple conv states
- if isinstance(layer.conv, tuple):
- return layer.conv[0].shape[0]
- return layer.conv.shape[0]
+ for cache in self._iter_caches():
+ bs = cache.batch_size
+ if bs is not None:
+ return bs
return None
@property
def max_cache_len(self):
- max_len = None
- for layer in self.layers:
- if isinstance(layer, dict):
- for cache in layer.values():
- if isinstance(cache, _AttentionCache):
- if cache.window is not None:
- max_len = cache.window if max_len is None else min(max_len, cache.window)
- elif isinstance(layer, _AttentionCache):
- if layer.window is not None:
- max_len = layer.window if max_len is None else min(max_len, layer.window)
- return max_len
+ windows = [
+ cache.window
+ for cache in self._iter_caches()
+ if isinstance(cache, _AttentionCache) and cache.window is not None
+ ]
+ return min(windows) if windows else None
def __len__(self):
return len(self.layers)
diff --git a/fast_llm_external_models/apriel2/conversion/__init__.py b/fast_llm_external_models/apriel2/conversion/__init__.py
index 983a632e0..2c28d1e87 100644
--- a/fast_llm_external_models/apriel2/conversion/__init__.py
+++ b/fast_llm_external_models/apriel2/conversion/__init__.py
@@ -1,88 +1,138 @@
"""Weight conversion system for Apriel2 models.
-Architecture Overview
-=====================
+Overview
+========
-This package implements a declarative weight transformation system with two
-orthogonal concerns:
+This package implements a declarative weight transformation system. The core
+abstraction separates config composition (structural) from plan execution (weights).
-1. **Config Composition** - Structural transformations of model configs
-2. **Plan Building & Execution** - Weight transformations between configs
+Conceptual Types
+================
-These concerns are intentionally separated:
-- Config composition determines WHAT the target architecture looks like
-- Plan building determines HOW weights are transformed to match
-- The `init` field bridges them: it's config metadata consumed by the plan builder
+All configs are ``dict``, but we distinguish three conceptual types:
-Key Design Decisions
-====================
+**State (S)** - A complete model config without ``init`` fields.
+ What you load from disk or save after conversion.
-**Declarative Plans**
- Plans are DATA (JSON-serializable expressions), not functions. This enables:
- - Inspection and debugging of transformations
- - Serialization for distributed execution
- - Composition via substitution rather than function composition
-
-**Separation of Config and Weights**
- The `init` field in surgery specs controls weight handling (transfer vs random)
- but does NOT affect config composition. Config composition is purely structural.
- After composition, `init` fields are stripped from complete configs.
-
-**Composition Semantics**
- Surgery specs use declarative (merge) composition, not operational (function)
- composition. For "additive" surgeries (modifying existing structure), the
- monoid action law holds. For "replacement" surgeries (defining complete new
- structure), sequential application differs from composed application by design.
-
-**Cross-Type Derivation**
- When converting between mixer types (e.g., attention → mamba), geometric
- parameters are derived where possible:
- - attention.heads → mamba dimensions (MIL conversion)
- - attention.heads → gdn heads (DIL conversion)
+**Partial Surgery (P)** - An incomplete config specifying changes.
+ May contain ``init`` fields (``transfer`` or ``random``).
-Module Structure
-================
+**Transition Spec (T)** - A complete config WITH ``init`` fields.
+ The result of applying surgery to a state. Describes both target
+ structure and weight initialization mode.
+
+Algebraic Structure
+===================
+
+**Monoid**: Partial surgeries compose via deep merge::
+
+ compose_configs : P × P → P
+
+**Action**: Surgeries act on states to produce transition specs::
+
+ compose_configs : S × P → T
+ compose_configs : T × P → T
+
+**Extraction**: Strip init to get a state::
+
+ strip_init_fields : T → S
+
+**Planning**: Build weight transformation from source state + transition spec::
+
+ plan_surgery : S × T → Plan
-- `config.py` - Config composition (compose_configs, apply_surgery)
-- `converters.py` - Plan builders (plan_surgery, plan_mil_attention_to_mamba, etc.)
-- `expr.py` - Expression types and plan class (Ref, Slice, Concat, Init, ExprPlan)
-- `executor.py` - Plan execution (StreamingExecutor, execute)
-- `io.py` - Streaming I/O (SafetensorLoader, ShardedSafetensorWriter)
-- `llava/` - Source-specific converter for Llava → Apriel2
+The ``init`` Field
+==================
-Example Usage
+The ``init`` field in surgeries specifies weight initialization:
+
+- ``init: transfer`` → transfer/convert weights from source
+- ``init: random`` → randomly initialize weights
+
+This field is preserved through ``compose_configs`` so ``plan_surgery`` can read it.
+Use ``strip_init_fields`` before saving configs to disk.
+
+Typical Usage
=============
+::
+
from fast_llm_external_models.apriel2.conversion import (
compose_configs,
plan_surgery,
+ strip_init_fields,
execute,
)
- # 1. Compose configs to get target architecture
- target_config = compose_configs(source_config, surgery_spec)
+ # Load source state
+ source_state = load_config(...) # S
- # 2. Build plan for weight transformation
- plan = plan_surgery(source_config, surgery_spec)
+ # Apply surgery
+ surgery = {"decoder": {"block": {"mixer": {"type": "gdn", "init": "random"}}}} # P
+ transition = compose_configs(source_state, surgery) # T
- # 3. Execute plan to transform weights
- target_weights = execute(plan, source_weights, seed=42)
+ # Build and execute plan
+ plan = plan_surgery(source_state, transition)
+ weights = execute(plan, source_weights, seed=42)
-For streaming I/O with large models:
+ # Save (strip init first)
+ target_state = strip_init_fields(transition) # S
+ save_config(target_state)
- from fast_llm_external_models.apriel2.conversion import (
- StreamingExecutor,
- SafetensorLoader,
- ShardedSafetensorWriter,
- )
+For chained surgeries::
+
+ current_state = source_state # S
+ current_plan = identity_plan
+
+ for surgery in surgery_chain: # each P
+ transition = compose_configs(current_state, surgery) # T
+ plan = plan_surgery(current_state, transition)
+ current_plan = compose(current_plan, plan)
+ current_state = strip_init_fields(transition) # S <- IMPORTANT!
+
+**Note**: The ``strip_init_fields`` call is critical. It ensures that ``init: random``
+applies only to the surgery that introduces a component. Without stripping, subsequent
+surgeries would re-randomize existing components. See ``config.py`` docstring for details.
+
+Key Design Decisions
+====================
+
+**Declarative Plans**
+ Plans are data (expressions), not functions. Enables inspection,
+ serialization, and composition via substitution.
+
+**Inheritance Semantics**
+ When S × P → T, unspecified fields inherit from source.
+ Cross-type derivation maps geometry (attention.heads → gdn.value_heads).
- with SafetensorLoader(source_files) as loader:
- executor = StreamingExecutor(plan, loader)
- with ShardedSafetensorWriter(output_dir) as writer:
- for key, tensor in executor.execute(seed=42):
- writer.add(key, tensor)
+**Additive vs Replacement Surgeries**
+ Additive surgeries (no ``type:`` declaration) satisfy the action law.
+ Replacement surgeries (explicit ``type:``) use last-write-wins.
+
+Module Structure
+================
+
+- ``config.py`` - Config composition (compose_configs, strip_init_fields)
+- ``converters.py`` - Plan builders (plan_surgery, plan_mil_attention_to_mamba)
+- ``expr.py`` - Expression types (Ref, Slice, Concat, Init, ExprPlan)
+- ``executor.py`` - Plan execution (StreamingExecutor, execute)
+- ``io.py`` - Streaming I/O (SafetensorLoader, ShardedSafetensorWriter)
"""
+# Config composition
+from fast_llm_external_models.apriel2.conversion.config import compose_configs, strip_init_fields
+
+# Plan builders (generic)
+from fast_llm_external_models.apriel2.conversion.converters import (
+ plan_dil_attention_to_gdn,
+ plan_kil_attention_to_kda,
+ plan_mil_attention_to_mamba,
+ plan_surgery,
+)
+
+# Execution
+from fast_llm_external_models.apriel2.conversion.executor import MAX_SEED, StreamingExecutor, execute
+
# Core types and plan operations
from fast_llm_external_models.apriel2.conversion.expr import (
Concat,
@@ -104,13 +154,6 @@
substitute,
)
-# Execution
-from fast_llm_external_models.apriel2.conversion.executor import (
- MAX_SEED,
- StreamingExecutor,
- execute,
-)
-
# I/O utilities
from fast_llm_external_models.apriel2.conversion.io import (
DEFAULT_MAX_SHARD_SIZE,
@@ -118,22 +161,9 @@
ShardedSafetensorWriter,
)
-# Plan builders (generic)
-from fast_llm_external_models.apriel2.conversion.converters import (
- plan_mil_attention_to_mamba,
- plan_dil_attention_to_gdn,
- plan_kil_attention_to_kda,
- plan_surgery,
-)
-
-# Config composition
-from fast_llm_external_models.apriel2.conversion.config import compose_configs
-
# Source-specific converters
-from fast_llm_external_models.apriel2.conversion.llava import (
- convert_config as convert_llava_config,
- plan_llava_to_apriel2,
-)
+from fast_llm_external_models.apriel2.conversion.llava import convert_config as convert_llava_config
+from fast_llm_external_models.apriel2.conversion.llava import plan_llava_to_apriel2
# Rendering (optional, imported lazily by ExprPlan.render_tree)
# from fast_llm_external_models.apriel2.conversion.render import render_tree
@@ -175,6 +205,7 @@
"plan_kil_attention_to_kda",
# Config composition
"compose_configs",
+ "strip_init_fields",
# Source-specific converters
"convert_llava_config",
"plan_llava_to_apriel2",
diff --git a/fast_llm_external_models/apriel2/conversion/config.py b/fast_llm_external_models/apriel2/conversion/config.py
index 48f8ff44b..3752688c1 100644
--- a/fast_llm_external_models/apriel2/conversion/config.py
+++ b/fast_llm_external_models/apriel2/conversion/config.py
@@ -1,56 +1,136 @@
"""Config composition for Apriel2 architecture transformations.
-This module handles STRUCTURAL composition of configs, independent of weight handling.
-The `init` field in surgery specs is preserved as metadata for the plan builder but
-does not affect how configs are composed.
+Conceptual Types
+================
-Composition Cases
-=================
+The system operates on three conceptual types, all represented as ``dict``:
-compose_configs(base, overlay) handles four cases based on completeness:
+**State (S)**
+ A complete structural description of a model. Has ``hidden_size`` and ``decoder``.
+ Does NOT contain ``init`` fields. Represents WHAT a model looks like.
-1. **Complete + Partial** โ Apply surgery semantics (inheritance, cross-type derivation)
-2. **Partial + Partial** โ Deep merge (monoid operation on surgery specs)
-3. **Partial + Complete** โ Overlay wins (complete config replaces partial)
-4. **Complete + Complete** โ Deep merge, then strip `init` fields
+ Example: A saved config.json, or a model you're about to transform.
-A config is "complete" if it has `hidden_size` and `decoder` (i.e., it's a full model
-config, not a surgery spec).
+**Partial Surgery (P)**
+ An incomplete config specifying fields to change. Missing ``hidden_size`` or
+ ``decoder``. May contain ``init`` fields specifying weight initialization mode.
-Surgery Semantics
-=================
+ Example: ``{"decoder": {"block": {"mixer": {"type": "gdn", "init": "random"}}}}``
-When applying a surgery spec to a complete config:
+**Transition Spec (T)**
+ A complete config WITH ``init`` fields. Describes both the target structure
+ AND how to initialize weights. This is the output of applying a surgery to
+ a state - it's a complete specification of the transformation.
-**Inheritance**
- Unspecified parameters inherit from the source config. New blocks inherit
- from the "default" block (first block in pattern, or the single fixed block).
+ Example: The result of ``compose_configs(state, surgery)`` before stripping.
-**Cross-Type Derivation**
- When changing mixer types, geometric parameters are derived where possible:
- - attention → sliding_window: preserve heads, head_groups, head_size
- - attention → gdn: heads → value_heads, head_groups → key_heads
- - attention → mamba: derive d_inner, d_xb, dt_rank from hidden_size
- - attention → kda: preserve heads, head_size → head_dim
+The distinction between S and T is semantic (presence of ``init``), not structural.
+Both are "complete" in the sense of having ``hidden_size`` and ``decoder``.
-**Stochastic Mixer Composition**
- Two semantics based on whether surgery declares `type: stochastic`:
- - Replacement: surgery declares type → only surgery's sub-mixers included
- - Additive: surgery omits type → source sub-mixers preserved, surgery adds/modifies
+Algebraic Structure
+===================
- This distinction means the monoid action law holds for additive surgeries but
- intentionally fails for replacement surgeries (they have "last-write-wins" semantics).
+**Partial Surgeries form a Monoid (P, ∘, {})**::
-The `init` Field
-================
+ compose_configs : P × P → P (deep merge, overlay wins)
+
+ Identity: compose_configs(p, {}) = compose_configs({}, p) = p
+ Associativity: compose_configs(compose_configs(a, b), c)
+ = compose_configs(a, compose_configs(b, c))
+
+**Surgeries act on States to produce Transition Specs**::
+
+ compose_configs : S × P → T (apply surgery with inheritance)
+ compose_configs : T × P → T (extend transition with more surgery)
+
+**Action Law (for additive surgeries)**::
+
+ compose_configs(compose_configs(s, p₁), p₂) = compose_configs(s, compose_configs(p₁, p₂))
+
+This law holds when surgeries are "additive" (modifying existing structure without
+declaring new types). For "replacement" surgeries (explicitly declaring ``type:``),
+the action law intentionally fails - this is last-write-wins semantics.
+
+**State Extraction**::
+
+ strip_init_fields : T → S (remove init metadata for saving)
+
+Operations Summary
+==================
+
+``compose_configs(base, overlay)`` dispatches based on completeness:
+
+1. **S × P → T** : Apply surgery to state (inheritance, cross-type derivation)
+2. **T × P → T** : Extend transition spec with more surgery
+3. **P × P → P** : Merge partial surgeries (monoid operation)
+4. **S × S → S** : Merge states (deep merge, rare)
+5. **P × S → S** : Overlay wins (complete replaces partial)
+
+``strip_init_fields(config)`` removes all ``init`` fields, converting T → S.
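For illustration (values abridged; assumes ``compose_configs`` and ``strip_init_fields`` from this module)::

    state = {"hidden_size": 256, "decoder": {...}}                                   # S
    surgery = {"decoder": {"block": {"mixer": {"type": "gdn", "init": "random"}}}}   # P

    transition = compose_configs(state, surgery)    # case 1: S × P → T, keeps "init"
    merged = compose_configs(surgery, surgery)      # case 3: P × P → P, deep merge
    new_state = strip_init_fields(transition)       # T → S, safe to save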
+
+Inheritance Semantics
+=====================
+
+When applying a surgery (S × P → T):
+
+- Unspecified fields inherit from source state
+- New decoder blocks inherit from the "default" block
+- Cross-type derivation maps geometry (attention.heads → gdn.value_heads, etc.)
+- Stochastic mixers: additive surgery preserves source mixers, replacement replaces
+
+The ``init`` Field
+==================
+
+The ``init`` field specifies weight initialization mode for ``plan_surgery()``:
+
+- ``init: transfer`` → transfer weights from source (possibly with conversion)
+- ``init: random`` → randomly initialize weights
+
+**Key invariant**: ``init`` is preserved through composition so ``plan_surgery()``
+can read it. Use ``strip_init_fields()`` to obtain a pure state for:
+
+- Saving to disk (config.json should not contain ``init``)
+- Starting the next surgery iteration (current_state should be S, not T)
+
+Typical Usage Pattern
+=====================
+
+::
+
+ current_state: S = load_config(...)
+
+ for surgery in surgery_chain:  # each surgery is a P
+     transition: T = compose_configs(current_state, surgery)  # S × P → T
+     plan = plan_surgery(current_state, transition)  # plan reads init from T
+     current_state: S = strip_init_fields(transition)  # T → S for next iteration
+
+ save_config(current_state) # S has no init fields
+
+Sequential vs Merged Surgery Application
+========================================
+
+**IMPORTANT**: Applying surgeries sequentially (with stripping) differs from merging
+surgeries first then applying once. This affects ``init`` semantics:
+
+**Sequential** (recommended)::
-The `init` field is metadata for the plan builder, NOT for config composition:
-- `init: transfer` โ plan builder creates weight transfer mappings
-- `init: random` โ plan builder creates random initialization
+ t1 = compose_configs(s, p1) # GDN gets init: random
+ s1 = strip_init_fields(t1) # GDN loses init
+ t2 = compose_configs(s1, p2) # GDN has init: None → transfer mode
-After surgery is applied to produce a complete config, ALL `init` fields are stripped.
-This ensures configs are purely structural and plan creation is Markovian (depends only
-on current config + surgery, not on history).
+**Merged**::
+
+ merged = compose_configs(p1, p2) # GDN keeps init: random from p1
+ t = compose_configs(s, merged) # GDN has init: random → random mode
+
+The sequential approach means ``init: random`` applies **only to the surgery that
+introduces a component**. Subsequent surgeries transfer existing weights by default.
+
+This is the intended behavior: if surgery 1 adds GDN with random init, and surgery 2
+adds sliding window (not mentioning GDN), GDN keeps its weights from surgery 1.
+
+The merged approach would re-randomize GDN in every execution, which is rarely desired.
+Always use the sequential pattern shown in "Typical Usage Pattern" above.
"""
from __future__ import annotations
@@ -65,14 +145,42 @@ def is_complete(config: dict) -> bool:
def compose_configs(base: dict, overlay: dict | None) -> dict:
- """Compose two configs.
+ """Compose configs. Dispatches based on completeness of arguments.
+
+ Type Signatures (see module docstring for S, P, T definitions)::
+
+ S × P → T    Apply surgery to state, get transition spec
+ T × P → T    Extend transition spec with more surgery
+ P × P → P    Merge partial surgeries (monoid operation)
+ S × S → S    Merge states (deep merge)
+ P × S → S    Overlay wins
+
+ The ``init`` field is preserved in all cases. Use ``strip_init_fields()``
+ to convert T → S for saving or iteration.
Args:
- base: Base config (complete or partial surgery spec).
- overlay: Overlay config (complete or partial surgery spec).
+ base: State (S), transition spec (T), or partial surgery (P).
+ overlay: Partial surgery (P) or state (S).
Returns:
- Composed config.
+ Composed config. Type depends on inputs (see signatures above).
+
+ Algebraic Properties:
+ Monoid: ``compose(compose(p1, p2), p3) == compose(p1, compose(p2, p3))``
+
+ Action law (additive surgeries):
+ ``compose(compose(s, p1), p2) == compose(s, compose(p1, p2))``
+
+ Example::
+
+ # S × P → T (apply surgery to state)
+ state = {"hidden_size": 256, "decoder": {...}}
+ surgery = {"decoder": {"block": {"mixer": {"init": "random"}}}}
+ transition = compose_configs(state, surgery) # T, has init
+
+ # Build plan, then extract state
+ plan = plan_surgery(state, transition)
+ new_state = strip_init_fields(transition) # S, no init
"""
if not overlay:
return copy.deepcopy(base)
@@ -94,9 +202,8 @@ def compose_configs(base: dict, overlay: dict | None) -> dict:
if not base_complete and overlay_complete:
return copy.deepcopy(overlay)
- # Case 4: Both complete -> deep merge
+ # Case 4: Both complete -> deep merge (init preserved for plan_surgery)
result = _deep_merge(base, overlay)
- _strip_keys(result, {"init"})
return result
@@ -128,26 +235,53 @@ def _strip_keys(config: Any, keys_to_strip: set[str]) -> None:
_strip_keys(item, keys_to_strip)
+def strip_init_fields(config: dict) -> dict:
+ """Return a copy of config with all ``init`` fields stripped (T → S).
+
+ Converts a transition spec (T) to a state (S) by removing ``init`` metadata.
+ Use this:
+
+ 1. Before saving configs to disk (config.json should be purely structural)
+ 2. Between surgery iterations (so subsequent surgeries don't re-randomize)
+
+ See module docstring section "Sequential vs Merged Surgery Application" for
+ why stripping between iterations is critical.
+
+ Args:
+ config: Config dict (not modified). Typically a transition spec (T).
+
+ Returns:
+ A deep copy with all ``init`` fields recursively removed (a state S).
+ """
+ result = copy.deepcopy(config)
+ _strip_keys(result, {"init"})
+ return result
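A small usage sketch (hypothetical config fragment)::

    cfg = {"hidden_size": 256,
           "decoder": {"block": {"mixer": {"type": "gdn", "init": "random"}}}}
    clean = strip_init_fields(cfg)
    assert "init" not in clean["decoder"]["block"]["mixer"]   # T → S
    assert "init" in cfg["decoder"]["block"]["mixer"]         # input left untouched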
+
+
# =============================================================================
# Surgery application with full semantics
# =============================================================================
def apply_surgery(source_config: dict, surgery_config: dict | None) -> dict:
- """Apply surgery specification to a complete source config.
+ """Apply surgery spec to complete config (the monoid action).
- This handles:
- - Top-level scalar overrides
- - Decoder composition (fixed vs pattern)
- - Stochastic mixer sub-mixer inheritance
- - Cross-type derivation (attention → gdn, attention → mamba)
+ This is the internal implementation of the monoid action: surgery specs
+ acting on complete configs. Called by compose_configs when base is complete
+ and overlay is partial.
+
+ Implements inheritance semantics:
+ - Unspecified fields inherit from source
+ - Cross-type derivation maps geometry (attention → gdn, etc.)
+ - Stochastic sub-mixers inherit from source's main mixer
+ - `init` fields are PRESERVED for plan_surgery() to see
Args:
- source_config: Complete Apriel2 config.
- surgery_config: Partial surgery specification.
+ source_config: Complete Apriel2 config (the state being acted on).
+ surgery_config: Partial surgery spec (the monoid element acting).
Returns:
- Complete Apriel2 config with surgery applied.
+ Complete config with surgery applied. `init` fields preserved.
"""
if not surgery_config:
return copy.deepcopy(source_config)
@@ -189,8 +323,9 @@ def apply_surgery(source_config: dict, surgery_config: dict | None) -> dict:
surgery_config["vision_encoder"],
)
- # Strip init keys from final result
- _strip_keys(result, {"init"})
+ # NOTE: We do NOT strip init keys here. The `init` field is preserved through
+ # composition so that plan_surgery() can see it and decide between transfer
+ # vs random initialization. The caller (convert.py) strips init before saving.
return result
@@ -392,6 +527,12 @@ def _compose_single_mixer(source: dict, surgery: dict, hidden_size: int) -> dict
result[key] = surgery[key]
elif key in source:
result[key] = source[key]
+ # Copy per-layer bias settings (query_layer, key_layer, value_layer, dense_layer)
+ for key in ["query_layer", "key_layer", "value_layer", "dense_layer", "add_linear_biases"]:
+ if key in surgery:
+ result[key] = surgery[key]
+ elif key in source:
+ result[key] = copy.deepcopy(source[key])
# Preserve init
if "init" in surgery:
result["init"] = surgery["init"]
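A rough sketch of the inheritance this adds (abridged; the other mixer fields are handled by the surrounding code, and the call to the private helper is only illustrative)::

    source = {"type": "attention", "heads": 32,
              "query_layer": {"bias": {"enabled": True}}}
    surgery = {"init": "transfer"}   # says nothing about per-layer biases
    merged = _compose_single_mixer(source, surgery, hidden_size=256)
    # merged["query_layer"] -> {"bias": {"enabled": True}}   (deep-copied from source)
    # merged["init"]        -> "transfer"                    (preserved for plan_surgery)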
diff --git a/fast_llm_external_models/apriel2/conversion/converters.py b/fast_llm_external_models/apriel2/conversion/converters.py
index 6d1350c54..9c9238bb0 100644
--- a/fast_llm_external_models/apriel2/conversion/converters.py
+++ b/fast_llm_external_models/apriel2/conversion/converters.py
@@ -61,16 +61,7 @@
from __future__ import annotations
-from fast_llm_external_models.apriel2.conversion.expr import (
- Concat,
- Expr,
- ExprPlan,
- Init,
- Ref,
- Slice,
- W,
-)
-
+from fast_llm_external_models.apriel2.conversion.expr import Concat, Expr, ExprPlan, Init, Ref, Slice, W
# =============================================================================
# SECTION 1: Per-Mixer Plan Functions
@@ -79,6 +70,21 @@
# This is the single source of truth for each mixer's weight schema.
+def _get_attention_bias_enabled(config: dict, layer_name: str) -> bool:
+ """Get whether bias is enabled for an attention layer.
+
+ Checks per-layer bias config (e.g., query_layer.bias.enabled).
+ Falls back to add_linear_biases if not set.
+ """
+ layer_cfg = config.get(layer_name, {})
+ bias_cfg = layer_cfg.get("bias", {})
+ enabled = bias_cfg.get("enabled")
+ if enabled is not None:
+ return enabled
+ # Fall back to add_linear_biases
+ return config.get("add_linear_biases", False)
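Illustrative fallback behaviour (hypothetical config dict)::

    cfg = {"add_linear_biases": True,
           "query_layer": {"bias": {"enabled": False}}}
    _get_attention_bias_enabled(cfg, "query_layer")   # False - explicit per-layer setting wins
    _get_attention_bias_enabled(cfg, "key_layer")     # True  - falls back to add_linear_biases
    _get_attention_bias_enabled({}, "dense_layer")    # False - neither is set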
+
+
def _plan_attention_mixer(
*,
prefix: W,
@@ -90,9 +96,13 @@ def _plan_attention_mixer(
Weight schema:
- q_proj.weight: (q_size, hidden_size)
+ - q_proj.bias: (q_size,) [if query_layer.bias.enabled]
- k_proj.weight: (kv_size, hidden_size)
+ - k_proj.bias: (kv_size,) [if key_layer.bias.enabled]
- v_proj.weight: (kv_size, hidden_size)
+ - v_proj.bias: (kv_size,) [if value_layer.bias.enabled]
- o_proj.weight: (hidden_size, q_size)
+ - o_proj.bias: (hidden_size,) [if dense_layer.bias.enabled]
Args:
prefix: Target weight path prefix.
@@ -100,12 +110,28 @@ def _plan_attention_mixer(
hidden_size: Model hidden size.
source_prefix: If provided, passthrough from source. If None, random init.
"""
+ # Check per-layer bias configuration
+ q_bias = _get_attention_bias_enabled(config, "query_layer")
+ k_bias = _get_attention_bias_enabled(config, "key_layer")
+ v_bias = _get_attention_bias_enabled(config, "value_layer")
+ o_bias = _get_attention_bias_enabled(config, "dense_layer")
+
if source_prefix is not None:
- # Passthrough
- return ExprPlan(mappings={
+ # Passthrough weights
+ mappings: dict[W, Expr] = {
prefix / proj / "weight": Ref(key=source_prefix / proj / "weight")
for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]
- })
+ }
+ # Passthrough biases if enabled
+ if q_bias:
+ mappings[prefix / "q_proj" / "bias"] = Ref(key=source_prefix / "q_proj" / "bias")
+ if k_bias:
+ mappings[prefix / "k_proj" / "bias"] = Ref(key=source_prefix / "k_proj" / "bias")
+ if v_bias:
+ mappings[prefix / "v_proj" / "bias"] = Ref(key=source_prefix / "v_proj" / "bias")
+ if o_bias:
+ mappings[prefix / "o_proj" / "bias"] = Ref(key=source_prefix / "o_proj" / "bias")
+ return ExprPlan(mappings=mappings)
# Random init
heads = config["heads"]
@@ -114,12 +140,22 @@ def _plan_attention_mixer(
q_size = heads * head_size
kv_size = head_groups * head_size
- return ExprPlan(mappings={
+ mappings = {
prefix / "q_proj" / "weight": Init(shape=(q_size, hidden_size), init_type="kaiming"),
prefix / "k_proj" / "weight": Init(shape=(kv_size, hidden_size), init_type="kaiming"),
prefix / "v_proj" / "weight": Init(shape=(kv_size, hidden_size), init_type="kaiming"),
prefix / "o_proj" / "weight": Init(shape=(hidden_size, q_size), init_type="kaiming"),
- })
+ }
+ # Random init biases if enabled
+ if q_bias:
+ mappings[prefix / "q_proj" / "bias"] = Init(shape=(q_size,), init_type="zeros")
+ if k_bias:
+ mappings[prefix / "k_proj" / "bias"] = Init(shape=(kv_size,), init_type="zeros")
+ if v_bias:
+ mappings[prefix / "v_proj" / "bias"] = Init(shape=(kv_size,), init_type="zeros")
+ if o_bias:
+ mappings[prefix / "o_proj" / "bias"] = Init(shape=(hidden_size,), init_type="zeros")
+ return ExprPlan(mappings=mappings)
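Concretely, with a typical attention geometry (the numbers are illustrative)::

    # heads=32, head_groups=8, head_size=128, hidden_size=4096
    #   q_size  = 32 * 128 = 4096
    #   kv_size =  8 * 128 = 1024
    # With query_layer.bias.enabled = True, the random-init plan gains
    #   <prefix>/q_proj/bias -> Init(shape=(4096,), init_type="zeros")
    # next to the four kaiming-initialized projection weights.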
def _plan_mamba_mixer(
@@ -150,20 +186,22 @@ def _plan_mamba_mixer(
"""
if source_prefix is not None:
# Passthrough - include all possible weights
- return ExprPlan(mappings={
- prefix / name: Ref(key=source_prefix / name)
- for name in [
- "in_proj.weight",
- "out_proj.weight",
- "dt_in_proj.weight",
- "dt_proj.weight",
- "dt_proj.bias",
- "conv1d.weight",
- "conv1d.bias",
- "A_log",
- "D",
- ]
- })
+ return ExprPlan(
+ mappings={
+ prefix / name: Ref(key=source_prefix / name)
+ for name in [
+ "in_proj.weight",
+ "out_proj.weight",
+ "dt_in_proj.weight",
+ "dt_proj.weight",
+ "dt_proj.bias",
+ "conv1d.weight",
+ "conv1d.bias",
+ "A_log",
+ "D",
+ ]
+ }
+ )
# Random init
d_inner = config["d_inner"]
@@ -181,9 +219,7 @@ def _plan_mamba_mixer(
conv_channels = d_inner if repeat_kv_before_conv else d_xb
mappings: dict[W, Expr] = {
- prefix / "in_proj" / "weight": Init(
- shape=(2 * d_inner + 2 * d_xb, hidden_size), init_type="kaiming"
- ),
+ prefix / "in_proj" / "weight": Init(shape=(2 * d_inner + 2 * d_xb, hidden_size), init_type="kaiming"),
prefix / "out_proj" / "weight": Init(shape=(hidden_size, d_inner), init_type="kaiming"),
prefix / "dt_in_proj" / "weight": Init(shape=(dt_rank, hidden_size), init_type="kaiming"),
prefix / "dt_proj" / "weight": Init(shape=(d_inner, dt_rank), init_type="kaiming"),
@@ -230,18 +266,20 @@ def _plan_gdn_mixer(
"""
if source_prefix is not None:
# Passthrough
- return ExprPlan(mappings={
- prefix / name: Ref(key=source_prefix / name)
- for name in [
- "in_proj_qkvz.weight",
- "in_proj_ba.weight",
- "out_proj.weight",
- "convolution.weight",
- "A_log",
- "dt_bias",
- "norm.weight",
- ]
- })
+ return ExprPlan(
+ mappings={
+ prefix / name: Ref(key=source_prefix / name)
+ for name in [
+ "in_proj_qkvz.weight",
+ "in_proj_ba.weight",
+ "out_proj.weight",
+ "convolution.weight",
+ "A_log",
+ "dt_bias",
+ "norm.weight",
+ ]
+ }
+ )
# Random init
num_v_heads = config["value_heads"]
@@ -255,17 +293,19 @@ def _plan_gdn_mixer(
conv_dim = key_dim * 2 + value_dim
qkvz_size = key_dim * 2 + value_dim * 2 # Q, K both key_dim; V, Z both value_dim
- return ExprPlan(mappings={
- prefix / "in_proj_qkvz" / "weight": Init(shape=(qkvz_size, hidden_size), init_type="kaiming"),
- prefix / "in_proj_ba" / "weight": Init(shape=(num_v_heads * 2, hidden_size), init_type="zeros"),
- prefix / "out_proj" / "weight": Init(shape=(hidden_size, value_dim), init_type="kaiming"),
- prefix / "convolution" / "weight": Init(
- shape=(conv_dim, 1, conv_kernel_size), init_type="scaled_identity_conv"
- ),
- prefix / "A_log": Init(shape=(num_v_heads,), init_type="slow_decay"),
- prefix / "dt_bias": Init(shape=(num_v_heads,), init_type="zeros"),
- prefix / "norm" / "weight": Init(shape=(head_v_dim,), init_type="ones"),
- })
+ return ExprPlan(
+ mappings={
+ prefix / "in_proj_qkvz" / "weight": Init(shape=(qkvz_size, hidden_size), init_type="kaiming"),
+ prefix / "in_proj_ba" / "weight": Init(shape=(num_v_heads * 2, hidden_size), init_type="zeros"),
+ prefix / "out_proj" / "weight": Init(shape=(hidden_size, value_dim), init_type="kaiming"),
+ prefix
+ / "convolution"
+ / "weight": Init(shape=(conv_dim, 1, conv_kernel_size), init_type="scaled_identity_conv"),
+ prefix / "A_log": Init(shape=(num_v_heads,), init_type="slow_decay"),
+ prefix / "dt_bias": Init(shape=(num_v_heads,), init_type="zeros"),
+ prefix / "norm" / "weight": Init(shape=(head_v_dim,), init_type="ones"),
+ }
+ )
def _plan_kda_mixer(
@@ -298,26 +338,28 @@ def _plan_kda_mixer(
"""
if source_prefix is not None:
# Passthrough
- return ExprPlan(mappings={
- prefix / name: Ref(key=source_prefix / name)
- for name in [
- "q_proj.weight",
- "k_proj.weight",
- "v_proj.weight",
- "o_proj.weight",
- "q_conv.weight",
- "k_conv.weight",
- "v_conv.weight",
- "f_a_proj.weight",
- "f_b_proj.weight",
- "g_a_proj.weight",
- "g_b_proj.weight",
- "beta_proj.weight",
- "A_log",
- "dt_bias",
- "norm.weight",
- ]
- })
+ return ExprPlan(
+ mappings={
+ prefix / name: Ref(key=source_prefix / name)
+ for name in [
+ "q_proj.weight",
+ "k_proj.weight",
+ "v_proj.weight",
+ "o_proj.weight",
+ "q_conv.weight",
+ "k_conv.weight",
+ "v_conv.weight",
+ "f_a_proj.weight",
+ "f_b_proj.weight",
+ "g_a_proj.weight",
+ "g_b_proj.weight",
+ "beta_proj.weight",
+ "A_log",
+ "dt_bias",
+ "norm.weight",
+ ]
+ }
+ )
# Random init
num_heads = config["heads"]
@@ -325,36 +367,38 @@ def _plan_kda_mixer(
projection_size = num_heads * head_dim
conv_kernel_size = config.get("convolution_layer", {}).get("kernel_size", 4)
- return ExprPlan(mappings={
- # Main projections
- prefix / "q_proj" / "weight": Init(shape=(projection_size, hidden_size), init_type="kaiming"),
- prefix / "k_proj" / "weight": Init(shape=(projection_size, hidden_size), init_type="kaiming"),
- prefix / "v_proj" / "weight": Init(shape=(projection_size, hidden_size), init_type="kaiming"),
- prefix / "o_proj" / "weight": Init(shape=(hidden_size, projection_size), init_type="kaiming"),
- # Convolutions
- prefix / "q_conv" / "weight": Init(
- shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"
- ),
- prefix / "k_conv" / "weight": Init(
- shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"
- ),
- prefix / "v_conv" / "weight": Init(
- shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"
- ),
- # Gate kernels (low-rank factorization)
- prefix / "f_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"),
- prefix / "f_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"),
- # Output gate (low-rank factorization)
- prefix / "g_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"),
- prefix / "g_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"),
- # Beta projection
- prefix / "beta_proj" / "weight": Init(shape=(num_heads, hidden_size), init_type="kaiming"),
- # Learnable parameters
- prefix / "A_log": Init(shape=(num_heads,), init_type="slow_decay"),
- prefix / "dt_bias": Init(shape=(projection_size,), init_type="zeros"),
- # Normalization
- prefix / "norm" / "weight": Init(shape=(head_dim,), init_type="ones"),
- })
+ return ExprPlan(
+ mappings={
+ # Main projections
+ prefix / "q_proj" / "weight": Init(shape=(projection_size, hidden_size), init_type="kaiming"),
+ prefix / "k_proj" / "weight": Init(shape=(projection_size, hidden_size), init_type="kaiming"),
+ prefix / "v_proj" / "weight": Init(shape=(projection_size, hidden_size), init_type="kaiming"),
+ prefix / "o_proj" / "weight": Init(shape=(hidden_size, projection_size), init_type="kaiming"),
+ # Convolutions
+ prefix
+ / "q_conv"
+ / "weight": Init(shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"),
+ prefix
+ / "k_conv"
+ / "weight": Init(shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"),
+ prefix
+ / "v_conv"
+ / "weight": Init(shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"),
+ # Gate kernels (low-rank factorization)
+ prefix / "f_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"),
+ prefix / "f_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"),
+ # Output gate (low-rank factorization)
+ prefix / "g_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"),
+ prefix / "g_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"),
+ # Beta projection
+ prefix / "beta_proj" / "weight": Init(shape=(num_heads, hidden_size), init_type="kaiming"),
+ # Learnable parameters
+ prefix / "A_log": Init(shape=(num_heads,), init_type="slow_decay"),
+ prefix / "dt_bias": Init(shape=(projection_size,), init_type="zeros"),
+ # Normalization
+ prefix / "norm" / "weight": Init(shape=(head_dim,), init_type="ones"),
+ }
+ )
# Dispatcher for per-mixer plan functions
@@ -409,16 +453,13 @@ def plan_mil_attention_to_mamba(
exprs=(
Init(shape=(d_inner, hidden_size), init_type="kaiming"), # z: random
Slice(
- expr=Ref(key=source_prefix / "v_proj" / "weight"),
- slices=((0, d_xb, None), (None, None, None))
+ expr=Ref(key=source_prefix / "v_proj" / "weight"), slices=((0, d_xb, None), (None, None, None))
), # x <- V
Slice(
- expr=Ref(key=source_prefix / "k_proj" / "weight"),
- slices=((0, d_xb, None), (None, None, None))
+ expr=Ref(key=source_prefix / "k_proj" / "weight"), slices=((0, d_xb, None), (None, None, None))
), # B <- K
Slice(
- expr=Ref(key=source_prefix / "q_proj" / "weight"),
- slices=((0, d_inner, None), (None, None, None))
+ expr=Ref(key=source_prefix / "q_proj" / "weight"), slices=((0, d_inner, None), (None, None, None))
), # C <- Q
),
dim=0,
@@ -532,19 +573,21 @@ def plan_dil_attention_to_gdn(
dim=0,
)
- return ExprPlan(mappings={
- target_prefix / "in_proj_qkvz" / "weight": in_proj_qkvz_expr,
- target_prefix / "in_proj_ba" / "weight": Init(
- shape=(2 * num_v_heads, hidden_size), init_type="zeros"
- ), # b=a=0 → β=0.5
- target_prefix / "out_proj" / "weight": Ref(key=source_prefix / "o_proj" / "weight"),
- target_prefix / "convolution" / "weight": Init(
- shape=(conv_dim, 1, conv_kernel_size), init_type="scaled_identity_conv"
- ),
- target_prefix / "A_log": Init(shape=(num_v_heads,), init_type="slow_decay"),
- target_prefix / "dt_bias": Init(shape=(num_v_heads,), init_type="zeros"),
- target_prefix / "norm" / "weight": Init(shape=(head_v_dim,), init_type="ones"),
- })
+ return ExprPlan(
+ mappings={
+ target_prefix / "in_proj_qkvz" / "weight": in_proj_qkvz_expr,
+ target_prefix
+ / "in_proj_ba"
+ / "weight": Init(shape=(2 * num_v_heads, hidden_size), init_type="zeros"), # b=a=0 → β=0.5
+ target_prefix / "out_proj" / "weight": Ref(key=source_prefix / "o_proj" / "weight"),
+ target_prefix
+ / "convolution"
+ / "weight": Init(shape=(conv_dim, 1, conv_kernel_size), init_type="scaled_identity_conv"),
+ target_prefix / "A_log": Init(shape=(num_v_heads,), init_type="slow_decay"),
+ target_prefix / "dt_bias": Init(shape=(num_v_heads,), init_type="zeros"),
+ target_prefix / "norm" / "weight": Init(shape=(head_v_dim,), init_type="ones"),
+ }
+ )
def plan_kil_attention_to_kda(
@@ -595,9 +638,7 @@ def plan_kil_attention_to_kda(
for h in range(num_heads):
src_h = h % source_num_q_heads
row_start = src_h * source_head_dim
- q_slices.append(
- Slice(expr=q_ref, slices=((row_start, row_start + head_dim, None), (None, None, None)))
- )
+ q_slices.append(Slice(expr=q_ref, slices=((row_start, row_start + head_dim, None), (None, None, None))))
q_expr = Concat(exprs=tuple(q_slices), dim=0)
# K: tile source KV heads to fill target projection_size
@@ -608,9 +649,7 @@ def plan_kil_attention_to_kda(
for h in range(num_heads):
src_h = h % source_num_kv_heads
row_start = src_h * source_head_dim
- k_slices.append(
- Slice(expr=k_ref, slices=((row_start, row_start + head_dim, None), (None, None, None)))
- )
+ k_slices.append(Slice(expr=k_ref, slices=((row_start, row_start + head_dim, None), (None, None, None))))
k_expr = Concat(exprs=tuple(k_slices), dim=0)
# V: tile source KV heads to fill target projection_size
@@ -621,41 +660,41 @@ def plan_kil_attention_to_kda(
for h in range(num_heads):
src_h = h % source_num_kv_heads
row_start = src_h * source_head_dim
- v_slices.append(
- Slice(expr=v_ref, slices=((row_start, row_start + head_dim, None), (None, None, None)))
- )
+ v_slices.append(Slice(expr=v_ref, slices=((row_start, row_start + head_dim, None), (None, None, None))))
v_expr = Concat(exprs=tuple(v_slices), dim=0)
- return ExprPlan(mappings={
- # Transfer main projections
- target_prefix / "q_proj" / "weight": q_expr,
- target_prefix / "k_proj" / "weight": k_expr,
- target_prefix / "v_proj" / "weight": v_expr,
- target_prefix / "o_proj" / "weight": Ref(key=source_prefix / "o_proj" / "weight"),
- # Random init: convolutions (scaled identity for near-passthrough initially)
- target_prefix / "q_conv" / "weight": Init(
- shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"
- ),
- target_prefix / "k_conv" / "weight": Init(
- shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"
- ),
- target_prefix / "v_conv" / "weight": Init(
- shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"
- ),
- # Random init: gate kernels (low-rank factorization)
- target_prefix / "f_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"),
- target_prefix / "f_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"),
- # Random init: output gate (low-rank factorization)
- target_prefix / "g_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"),
- target_prefix / "g_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"),
- # Random init: beta projection
- target_prefix / "beta_proj" / "weight": Init(shape=(num_heads, hidden_size), init_type="kaiming"),
- # Random init: learnable parameters
- target_prefix / "A_log": Init(shape=(num_heads,), init_type="slow_decay"),
- target_prefix / "dt_bias": Init(shape=(projection_size,), init_type="zeros"),
- # Random init: normalization
- target_prefix / "norm" / "weight": Init(shape=(head_dim,), init_type="ones"),
- })
+ return ExprPlan(
+ mappings={
+ # Transfer main projections
+ target_prefix / "q_proj" / "weight": q_expr,
+ target_prefix / "k_proj" / "weight": k_expr,
+ target_prefix / "v_proj" / "weight": v_expr,
+ target_prefix / "o_proj" / "weight": Ref(key=source_prefix / "o_proj" / "weight"),
+ # Random init: convolutions (scaled identity for near-passthrough initially)
+ target_prefix
+ / "q_conv"
+ / "weight": Init(shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"),
+ target_prefix
+ / "k_conv"
+ / "weight": Init(shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"),
+ target_prefix
+ / "v_conv"
+ / "weight": Init(shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"),
+ # Random init: gate kernels (low-rank factorization)
+ target_prefix / "f_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"),
+ target_prefix / "f_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"),
+ # Random init: output gate (low-rank factorization)
+ target_prefix / "g_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"),
+ target_prefix / "g_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"),
+ # Random init: beta projection
+ target_prefix / "beta_proj" / "weight": Init(shape=(num_heads, hidden_size), init_type="kaiming"),
+ # Random init: learnable parameters
+ target_prefix / "A_log": Init(shape=(num_heads,), init_type="slow_decay"),
+ target_prefix / "dt_bias": Init(shape=(projection_size,), init_type="zeros"),
+ # Random init: normalization
+ target_prefix / "norm" / "weight": Init(shape=(head_dim,), init_type="ones"),
+ }
+ )
# =============================================================================
@@ -786,7 +825,70 @@ def plan_surgery(
source_config: dict,
target_config: dict,
) -> ExprPlan:
- """Build plan for Apriel2โApriel2 surgery (MIL, DIL, KIL, stochastic mixers, etc.)."""
+ """Build a weight conversion plan: S × T → Plan.
+
+ Creates an ExprPlan mapping target weight keys to expressions over source weights.
+ Handles same-type passthrough, cross-type conversions (MIL, DIL, KIL), and
+ stochastic mixer routing.
+
+ Type Signature::
+
+ plan_surgery : S × T → Plan
+
+ Where S is a state (source) and T is a transition spec (target with ``init`` fields).
+
+ The ``init`` Field
+ ------------------
+
+ The ``init`` field in ``target_config`` controls weight initialization:
+
+ - ``init: transfer`` (or absent) → create Ref expressions (transfer from source)
+ - ``init: random`` → create Init expressions (random initialization)
+
+ This is why ``target_config`` should be a transition spec (T) from ``compose_configs``,
+ not a stripped state (S). If ``init`` fields are missing, all components default to
+ transfer mode.
+
+ Args:
+ source_config: State (S) - complete config describing source architecture.
+ Must have hidden_size, decoder, etc. No ``init`` fields expected.
+ target_config: Transition spec (T) - complete config with ``init`` fields.
+ Use ``compose_configs(source, surgery)`` to produce this.
+
+ Returns:
+ ExprPlan mapping target weight keys to expressions over source weights.
+
+ Example::
+
+ # Apply a surgery that wraps attention in a stochastic mixer
+ surgery_spec = {
+ "decoder": {"block": {"mixer": {
+ "type": "stochastic",
+ "mixers": {
+ "attention": {"init": "transfer"},
+ "gdn": {"type": "gdn", "init": "random"},
+ }
+ }}}
+ }
+
+ # S × P → T
+ transition = compose_configs(source_config, surgery_spec)
+
+ # S × T → Plan
+ plan = plan_surgery(source_config, transition)
+
+ # Execute
+ new_weights = execute(plan, source_weights, seed=42)
+
+ # T → S for saving
+ target_state = strip_init_fields(transition)
+
+ Note:
+ Both arguments must be complete (have hidden_size and decoder).
+ The target_config should retain ``init`` fields from the surgery spec.
+ Passing a stripped state as target will cause all components to use
+ transfer mode, which may not be intended.
+ """
hidden_size = target_config.get("hidden_size", source_config.get("hidden_size"))
assert hidden_size is not None, "hidden_size must be specified in source or target config"
@@ -804,18 +906,24 @@ def plan_surgery(
target_block = _get_block_config(target_decoder, target_layer_idx)
plan += _plan_mixer(
- target_layer_idx, source_layer_idx,
- source_block.get("mixer", {}), target_block.get("mixer", {}),
+ target_layer_idx,
+ source_layer_idx,
+ source_block.get("mixer", {}),
+ target_block.get("mixer", {}),
hidden_size,
)
plan += _plan_mlp(
- target_layer_idx, source_layer_idx,
- source_block.get("mlp", {}), target_block.get("mlp", {}),
+ target_layer_idx,
+ source_layer_idx,
+ source_block.get("mlp", {}),
+ target_block.get("mlp", {}),
hidden_size,
)
plan += _plan_norms(
- target_layer_idx, source_layer_idx,
- source_block, target_block,
+ target_layer_idx,
+ source_layer_idx,
+ source_block,
+ target_block,
hidden_size,
)
@@ -839,14 +947,16 @@ def _plan_non_decoder_weights(config: dict) -> ExprPlan:
embed = W("model", "embed_tokens", "weight")
mappings[embed] = Ref(key=embed)
- head = W("lm_head", "weight")
- mappings[head] = Ref(key=head)
+ # lm_head only if not tied to embeddings
+ if not config.get("tie_word_embeddings", False):
+ head = W("lm_head", "weight")
+ mappings[head] = Ref(key=head)
norm = W("model", "norm", "weight")
mappings[norm] = Ref(key=norm)
- if "vision_encoder" in config:
- vision_config = config["vision_encoder"]
+ vision_config = config.get("vision_encoder")
+ if vision_config:
vision = W("model", "vision_encoder")
patch_emb = vision / "embeddings" / "patch_embeddings" / "weight"
@@ -950,9 +1060,13 @@ def _plan_mixer(
source_prefix = source_mixer_base
plan += _plan_mixer_transfer(
- matched_source_type, sub_type,
- matched_source, sub_config,
- source_prefix, target_prefix, hidden_size,
+ matched_source_type,
+ sub_type,
+ matched_source,
+ sub_config,
+ source_prefix,
+ target_prefix,
+ hidden_size,
)
# Passthrough source sub-mixers not in target spec
@@ -963,8 +1077,13 @@ def _plan_mixer(
source_prefix = source_layer / "mixer" / "mixers" / sub_name
target_prefix = target_layer / "mixer" / "mixers" / sub_name
plan += _plan_mixer_transfer(
- sub_type, sub_type, sub_config, sub_config,
- source_prefix, target_prefix, hidden_size,
+ sub_type,
+ sub_type,
+ sub_config,
+ sub_config,
+ source_prefix,
+ target_prefix,
+ hidden_size,
)
return plan
@@ -980,12 +1099,34 @@ def _plan_mixer(
source_prefix = source_layer / "mixer"
return _plan_mixer_transfer(
- main_source_type, target_type,
- main_source, target_mixer,
- source_prefix, target_prefix, hidden_size,
+ main_source_type,
+ target_type,
+ main_source,
+ target_mixer,
+ source_prefix,
+ target_prefix,
+ hidden_size,
)
+def _get_mlp_bias_enabled(config: dict, layer_name: str) -> bool:
+ """Get whether bias is enabled for an MLP layer.
+
+ Checks per-layer bias config (e.g., layer_1.bias.enabled, layer_2.bias.enabled).
+ Falls back to add_linear_biases if not set.
+
+ Note: layer_1 corresponds to gate_proj and up_proj (gated MLP) or just up_proj (non-gated);
+ layer_2 corresponds to down_proj.
+ """
+ layer_cfg = config.get(layer_name, {})
+ bias_cfg = layer_cfg.get("bias", {})
+ enabled = bias_cfg.get("enabled")
+ if enabled is not None:
+ return enabled
+ # Fall back to add_linear_biases
+ return config.get("add_linear_biases", False)
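Illustrative mapping to the HF projection names (hypothetical config dict)::

    mlp_cfg = {"layer_2": {"bias": {"enabled": True}}}
    _get_mlp_bias_enabled(mlp_cfg, "layer_1")   # False - no per-layer setting, no add_linear_biases
    _get_mlp_bias_enabled(mlp_cfg, "layer_2")   # True  - down_proj gets a bias entry in the plan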
+
+
def _plan_mlp(
target_layer_idx: int,
source_layer_idx: int,
@@ -1006,7 +1147,7 @@ def _plan_mlp_transfer(
target_mlp: dict,
hidden_size: int,
) -> ExprPlan:
- """Passthrough for MLP weights."""
+ """Passthrough for MLP weights and biases."""
source_mlp_path = W("model", "decoder", "blocks", source_layer_idx, "mlp")
target_mlp_path = W("model", "decoder", "blocks", target_layer_idx, "mlp")
@@ -1019,10 +1160,36 @@ def _plan_mlp_transfer(
f"Use 'init: random' to initialize randomly."
)
- return ExprPlan(mappings={
- target_mlp_path / proj / "weight": Ref(key=source_mlp_path / proj / "weight")
- for proj in ["gate_proj", "up_proj", "down_proj"]
- })
+ # Check per-layer bias configuration
+ layer_1_bias = _get_mlp_bias_enabled(target_mlp, "layer_1")
+ layer_2_bias = _get_mlp_bias_enabled(target_mlp, "layer_2")
+
+ # Check if gated MLP (has gate_proj) or non-gated (only up_proj)
+ gated = target_mlp.get("gated", True) # Default to gated for backwards compatibility
+
+ # Passthrough weights
+ # layer_1 = gate_proj + up_proj (gated) or just up_proj (non-gated)
+ # layer_2 = down_proj
+ if gated:
+ weight_projs = ["gate_proj", "up_proj", "down_proj"]
+ else:
+ weight_projs = ["up_proj", "down_proj"]
+
+ mappings: dict[W, Expr] = {
+ target_mlp_path / proj / "weight": Ref(key=source_mlp_path / proj / "weight") for proj in weight_projs
+ }
+
+ # Passthrough biases if enabled
+ if layer_1_bias:
+ if gated:
+ mappings[target_mlp_path / "gate_proj" / "bias"] = Ref(key=source_mlp_path / "gate_proj" / "bias")
+ mappings[target_mlp_path / "up_proj" / "bias"] = Ref(key=source_mlp_path / "up_proj" / "bias")
+
+ # layer_2 = down_proj
+ if layer_2_bias:
+ mappings[target_mlp_path / "down_proj" / "bias"] = Ref(key=source_mlp_path / "down_proj" / "bias")
+
+ return ExprPlan(mappings=mappings)
def _plan_random_mlp(
@@ -1030,20 +1197,41 @@ def _plan_random_mlp(
target_mlp: dict,
hidden_size: int,
) -> ExprPlan:
- """Random initialization for MLP."""
+ """Random initialization for MLP weights and biases."""
target_mlp_path = W("model", "decoder", "blocks", target_layer_idx, "mlp")
intermediate_size = target_mlp["intermediate_size"]
- return ExprPlan(mappings={
- target_mlp_path / "gate_proj" / "weight": Init(
- shape=(intermediate_size, hidden_size), init_type="kaiming"
- ),
- target_mlp_path / "up_proj" / "weight": Init(
+
+ # Check per-layer bias configuration
+ layer_1_bias = _get_mlp_bias_enabled(target_mlp, "layer_1")
+ layer_2_bias = _get_mlp_bias_enabled(target_mlp, "layer_2")
+
+ # Check if gated MLP (has gate_proj) or non-gated (only up_proj)
+ gated = target_mlp.get("gated", True) # Default to gated for backwards compatibility
+
+ # Random init weights
+ mappings: dict[W, Expr] = {}
+ if gated:
+ mappings[target_mlp_path / "gate_proj" / "weight"] = Init(
shape=(intermediate_size, hidden_size), init_type="kaiming"
- ),
- target_mlp_path / "down_proj" / "weight": Init(
- shape=(hidden_size, intermediate_size), init_type="kaiming"
- ),
- })
+ )
+ mappings[target_mlp_path / "up_proj" / "weight"] = Init(
+ shape=(intermediate_size, hidden_size), init_type="kaiming"
+ )
+ mappings[target_mlp_path / "down_proj" / "weight"] = Init(
+ shape=(hidden_size, intermediate_size), init_type="kaiming"
+ )
+
+ # Random init biases if enabled
+ if layer_1_bias:
+ if gated:
+ mappings[target_mlp_path / "gate_proj" / "bias"] = Init(shape=(intermediate_size,), init_type="zeros")
+ mappings[target_mlp_path / "up_proj" / "bias"] = Init(shape=(intermediate_size,), init_type="zeros")
+
+ # layer_2 = down_proj
+ if layer_2_bias:
+ mappings[target_mlp_path / "down_proj" / "bias"] = Init(shape=(hidden_size,), init_type="zeros")
+
+ return ExprPlan(mappings=mappings)
def _plan_norms(
@@ -1083,10 +1271,12 @@ def _plan_norms_transfer(
f"Use 'init: random' to initialize randomly."
)
- return ExprPlan(mappings={
- target_layer / norm_name / "weight": Ref(key=source_layer / norm_name / "weight")
- for norm_name in ["input_layernorm", "post_attention_layernorm"]
- })
+ return ExprPlan(
+ mappings={
+ target_layer / norm_name / "weight": Ref(key=source_layer / norm_name / "weight")
+ for norm_name in ["input_layernorm", "post_attention_layernorm"]
+ }
+ )
def _plan_random_norms(
@@ -1095,7 +1285,9 @@ def _plan_random_norms(
) -> ExprPlan:
"""Random initialization for normalization layers."""
target_layer = W("model", "decoder", "blocks", target_layer_idx)
- return ExprPlan(mappings={
- target_layer / norm_name / "weight": Init(shape=(hidden_size,), init_type="ones")
- for norm_name in ["input_layernorm", "post_attention_layernorm"]
- })
+ return ExprPlan(
+ mappings={
+ target_layer / norm_name / "weight": Init(shape=(hidden_size,), init_type="ones")
+ for norm_name in ["input_layernorm", "post_attention_layernorm"]
+ }
+ )
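
A minimal usage sketch of the per-layer bias resolution added above; the helper body is copied from the hunk and the config dicts are illustrative only:

    def _get_mlp_bias_enabled(config: dict, layer_name: str) -> bool:
        # Same logic as the helper introduced in the hunk above.
        enabled = config.get(layer_name, {}).get("bias", {}).get("enabled")
        return enabled if enabled is not None else config.get("add_linear_biases", False)

    qwen_style_mlp = {
        "intermediate_size": 512,
        "gated": False,
        "layer_1": {"bias": {"enabled": True}},   # up_proj gets a bias
        "layer_2": {"bias": {"enabled": False}},  # down_proj does not
    }
    legacy_mlp = {"intermediate_size": 512, "add_linear_biases": True}

    assert _get_mlp_bias_enabled(qwen_style_mlp, "layer_1") is True
    assert _get_mlp_bias_enabled(qwen_style_mlp, "layer_2") is False
    # No per-layer override, so both layers fall back to add_linear_biases.
    assert _get_mlp_bias_enabled(legacy_mlp, "layer_1") is True
    assert _get_mlp_bias_enabled(legacy_mlp, "layer_2") is True
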
diff --git a/fast_llm_external_models/apriel2/conversion/executor.py b/fast_llm_external_models/apriel2/conversion/executor.py
index a6c5672f0..b0779c97f 100644
--- a/fast_llm_external_models/apriel2/conversion/executor.py
+++ b/fast_llm_external_models/apriel2/conversion/executor.py
@@ -29,7 +29,8 @@
from __future__ import annotations
import hashlib
-from typing import Callable, Iterator
+from collections.abc import Iterator
+from typing import Callable
import torch
from torch import Tensor
@@ -81,8 +82,7 @@ def execute(
break
else:
raise ValueError(
- "Cannot infer device/dtype: plan has no source references. "
- "Provide device and dtype explicitly."
+ "Cannot infer device/dtype: plan has no source references. " "Provide device and dtype explicitly."
)
generator = torch.Generator(device=device)
@@ -94,10 +94,7 @@ def execute(
# Verify device/dtype consistency
for key, tensor in sources.items():
if tensor.device != device or tensor.dtype != dtype:
- raise ValueError(
- f"Source {key} has {tensor.device}/{tensor.dtype}, "
- f"expected {device}/{dtype}"
- )
+ raise ValueError(f"Source {key} has {tensor.device}/{tensor.dtype}, " f"expected {device}/{dtype}")
# Deterministic per-target seed
key_offset = int(hashlib.md5(str(target_key).encode()).hexdigest()[:8], 16)
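
For context, a sketch of the deterministic per-target seeding visible in the executor hunk above: the md5-derived offset matches the diff, while combining it additively with the run seed is an assumption made here for illustration.

    import hashlib

    import torch

    def per_target_generator(target_key: str, seed: int, device: str = "cpu") -> torch.Generator:
        # Stable 32-bit offset derived from the target key (same expression as in executor.py).
        key_offset = int(hashlib.md5(str(target_key).encode()).hexdigest()[:8], 16)
        generator = torch.Generator(device=device)
        generator.manual_seed(seed + key_offset)  # assumption: run seed combined with per-key offset
        return generator
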
diff --git a/fast_llm_external_models/apriel2/conversion/expr.py b/fast_llm_external_models/apriel2/conversion/expr.py
index 4867a27ae..34ea106fc 100644
--- a/fast_llm_external_models/apriel2/conversion/expr.py
+++ b/fast_llm_external_models/apriel2/conversion/expr.py
@@ -52,7 +52,8 @@
import math
from collections import defaultdict
-from typing import Annotated, Any, Callable, Iterator, Literal, TypedDict, Union, Unpack
+from collections.abc import Iterator
+from typing import Annotated, Any, Callable, Literal, TypedDict, Union, Unpack
import torch
from pydantic import BaseModel, ConfigDict, Field, GetCoreSchemaHandler, TypeAdapter
@@ -60,7 +61,6 @@
from pydantic_core import CoreSchema, core_schema
from torch import Tensor
-
# =============================================================================
# Weight Path Builder
# =============================================================================
@@ -78,7 +78,7 @@ class W(str):
mappings[q] = Ref(key=source_q)
"""
- def __new__(cls, *parts) -> "W":
+ def __new__(cls, *parts) -> W:
# Join parts, stripping any leading/trailing dots from each
cleaned = []
for p in parts:
@@ -89,12 +89,12 @@ def __new__(cls, *parts) -> "W":
cleaned.append(s)
return super().__new__(cls, ".".join(cleaned))
- def __truediv__(self, other) -> "W":
+ def __truediv__(self, other) -> W:
if isinstance(other, (list, tuple)):
return W(self, *other)
return W(self, other)
- def __rtruediv__(self, other) -> "W":
+ def __rtruediv__(self, other) -> W:
return W(other, self)
@classmethod
@@ -156,7 +156,7 @@ class Slice(BaseModel):
model_config = ConfigDict(frozen=True)
type: Literal["slice"] = "slice"
- expr: "Expr"
+ expr: Expr
slices: tuple[tuple[int | None, int | None, int | None], ...]
def find_refs(self) -> set[W]:
@@ -184,7 +184,7 @@ class Concat(BaseModel):
model_config = ConfigDict(frozen=True)
type: Literal["concat"] = "concat"
- exprs: tuple["Expr", ...]
+ exprs: tuple[Expr, ...]
dim: int = 0
def find_refs(self) -> set[W]:
@@ -303,7 +303,7 @@ class Reshape(BaseModel):
model_config = ConfigDict(frozen=True)
type: Literal["reshape"] = "reshape"
- expr: "Expr"
+ expr: Expr
shape: tuple[int, ...]
def find_refs(self) -> set[W]:
@@ -442,10 +442,10 @@ def __getitem__(self, key: W) -> Expr:
def __contains__(self, key: W) -> bool:
return key in self.mappings
- def __or__(self, other: "ExprPlan") -> "ExprPlan":
+ def __or__(self, other: ExprPlan) -> ExprPlan:
return compose(self, other)
- def __add__(self, other: "ExprPlan") -> "ExprPlan":
+ def __add__(self, other: ExprPlan) -> ExprPlan:
return merge(self, other)
def source_keys(self) -> set[str]:
@@ -471,7 +471,7 @@ def summary(self) -> dict[str, Any]:
"metadata": self.metadata,
}
- def fuse(self) -> "ExprPlan":
+ def fuse(self) -> ExprPlan:
return ExprPlan(
mappings={k: fuse(v) for k, v in self.mappings.items()},
source_format=self.source_format,
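
A short sketch of the W path builder and the plan operators typed above (| composes plans via compose(), + merges them via merge()); the rename below is illustrative:

    from fast_llm_external_models.apriel2.conversion.expr import ExprPlan, Ref, W

    q = W("model", "decoder", "blocks", 0) / "mixer" / "q_proj" / "weight"
    assert str(q) == "model.decoder.blocks.0.mixer.q_proj.weight"

    rename = ExprPlan(mappings={q: Ref(key=W("model", "layers", 0, "self_attn", "q_proj", "weight"))})
    head = ExprPlan(mappings={W("lm_head", "weight"): Ref(key=W("lm_head", "weight"))})
    combined = rename + head  # merge of disjoint targets; rename | other_plan would compose sequentially
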
diff --git a/fast_llm_external_models/apriel2/conversion/io.py b/fast_llm_external_models/apriel2/conversion/io.py
index e1a261d7e..1f64df0b9 100644
--- a/fast_llm_external_models/apriel2/conversion/io.py
+++ b/fast_llm_external_models/apriel2/conversion/io.py
@@ -62,7 +62,7 @@ def __init__(self, files: list[Path], device: str = "cpu"):
self._handles: dict[Path, Any] = {}
self._key_index: dict[str, Path] = {}
- def __enter__(self) -> "SafetensorLoader":
+ def __enter__(self) -> SafetensorLoader:
        # Pre-build index: key -> file (one-time O(n×m), then O(1) lookups)
for f in self.files:
handle = safe_open(f, framework="pt", device=self.device)
@@ -128,7 +128,7 @@ def __init__(
self._finalized: bool = False
self._result_path: Path | None = None
- def __enter__(self) -> "ShardedSafetensorWriter":
+ def __enter__(self) -> ShardedSafetensorWriter:
return self
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
@@ -180,8 +180,7 @@ def _flush(self) -> None:
shard_file = self.output_dir / f"{self.base_name}-{self._shard_index:05d}.safetensors.tmp"
logger.debug(
- f"Writing shard {self._shard_index}: {len(self._buffer)} tensors, "
- f"{self._buffer_bytes / 1e9:.2f} GB"
+ f"Writing shard {self._shard_index}: {len(self._buffer)} tensors, " f"{self._buffer_bytes / 1e9:.2f} GB"
)
save_file(self._buffer, shard_file)
self._shard_files.append(shard_file)
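
A hedged usage sketch of the I/O context managers touched above; the paths and the tensor dict are placeholders:

    from pathlib import Path

    import torch

    from fast_llm_external_models.apriel2.conversion import DEFAULT_MAX_SHARD_SIZE
    from fast_llm_external_models.apriel2.conversion.io import SafetensorLoader, ShardedSafetensorWriter

    files = sorted(Path("/path/to/source").glob("*.safetensors"))
    with SafetensorLoader(files, device="cpu") as loader:
        # __enter__ pre-builds the key -> file index so later lookups are O(1).
        pass

    with ShardedSafetensorWriter(Path("/path/to/output"), max_shard_size=DEFAULT_MAX_SHARD_SIZE) as writer:
        for key, tensor in {"model.norm.weight": torch.ones(64)}.items():  # illustrative tensors
            writer.add(key, tensor)
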
diff --git a/fast_llm_external_models/apriel2/conversion/llava/plan.py b/fast_llm_external_models/apriel2/conversion/llava/plan.py
index df485efbd..a97e46c1a 100644
--- a/fast_llm_external_models/apriel2/conversion/llava/plan.py
+++ b/fast_llm_external_models/apriel2/conversion/llava/plan.py
@@ -1,11 +1,6 @@
"""Llava to Apriel2 weight conversion plan."""
-from fast_llm_external_models.apriel2.conversion.expr import (
- Expr,
- ExprPlan,
- Ref,
- W,
-)
+from fast_llm_external_models.apriel2.conversion.expr import Expr, ExprPlan, Ref, W
def plan_llava_to_apriel2(llava_config: dict) -> ExprPlan:
diff --git a/fast_llm_external_models/apriel2/conversion/qwen2/__init__.py b/fast_llm_external_models/apriel2/conversion/qwen2/__init__.py
new file mode 100644
index 000000000..d0a0b8e6e
--- /dev/null
+++ b/fast_llm_external_models/apriel2/conversion/qwen2/__init__.py
@@ -0,0 +1,6 @@
+"""Qwen2/Qwen2.5 to Apriel2 conversion module."""
+
+from fast_llm_external_models.apriel2.conversion.qwen2.config import convert_config
+from fast_llm_external_models.apriel2.conversion.qwen2.plan import plan_qwen2_to_apriel2
+
+__all__ = ["convert_config", "plan_qwen2_to_apriel2"]
diff --git a/fast_llm_external_models/apriel2/conversion/qwen2/config.py b/fast_llm_external_models/apriel2/conversion/qwen2/config.py
new file mode 100644
index 000000000..70629fe0e
--- /dev/null
+++ b/fast_llm_external_models/apriel2/conversion/qwen2/config.py
@@ -0,0 +1,79 @@
+"""Qwen2/Qwen2.5 to Apriel2 config conversion."""
+
+
+def convert_config(qwen2_config: dict) -> dict:
+ """Convert Qwen2/Qwen2.5 config to Apriel2TextConfig format.
+
+ Qwen2.5 architecture:
+ - Standard transformer with GQA (grouped query attention)
+ - QKV bias enabled, O bias disabled
+ - MLP bias disabled
+ - Gated SwiGLU MLP
+ - RMSNorm
+ - RoPE embeddings
+
+ Args:
+ qwen2_config: HuggingFace Qwen2Config as dict
+
+ Returns:
+ Apriel2TextConfig-compatible dict
+ """
+ hidden_size = qwen2_config["hidden_size"]
+ num_attention_heads = qwen2_config["num_attention_heads"]
+ num_key_value_heads = qwen2_config.get("num_key_value_heads", num_attention_heads)
+ head_dim = hidden_size // num_attention_heads
+
+ # Qwen2 uses QKV bias but not O bias - mirror Fast-LLM's per-layer config
+ return {
+ "model_type": "apriel2_text",
+ "architectures": ["Apriel2ForCausalLM"],
+ "auto_map": {
+ "AutoConfig": "configuration_apriel2.Apriel2TextConfig",
+ "AutoModel": "modeling_apriel2.Apriel2TextModel",
+ "AutoModelForCausalLM": "modeling_apriel2.Apriel2ForCausalLM",
+ },
+ "hidden_size": hidden_size,
+ "vocab_size": qwen2_config["vocab_size"],
+ "tie_word_embeddings": qwen2_config.get("tie_word_embeddings", False),
+ "decoder": {
+ "type": "fixed",
+ "num_blocks": qwen2_config["num_hidden_layers"],
+ "block": {
+ "mixer": {
+ "type": "attention",
+ "heads": num_attention_heads,
+ "head_groups": num_key_value_heads,
+ "head_size": head_dim,
+ # Per-layer bias config matching Fast-LLM structure
+ "query_layer": {"bias": {"enabled": True}},
+ "key_layer": {"bias": {"enabled": True}},
+ "value_layer": {"bias": {"enabled": True}},
+ "dense_layer": {"bias": {"enabled": False}},
+ "rotary": {
+ "type": "mistral_1d",
+ "theta": qwen2_config.get("rope_theta", 1000000.0),
+ },
+ },
+ "mlp": {
+ "type": "mlp",
+ "intermediate_size": qwen2_config["intermediate_size"],
+ "activation": qwen2_config.get("hidden_act", "silu"),
+ "gated": True,
+ "add_linear_biases": False,
+ },
+ "normalization": {
+ "type": "rms_norm",
+ "epsilon": qwen2_config.get("rms_norm_eps", 1e-6),
+ },
+ },
+ },
+ "head": {
+ "normalization": {
+ "type": "rms_norm",
+ "epsilon": qwen2_config.get("rms_norm_eps", 1e-6),
+ }
+ },
+ "embeddings": {
+ "max_position_embeddings": qwen2_config.get("max_position_embeddings", 32768),
+ },
+ }
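
A minimal round-trip sketch of the new config converter; the Qwen2 values below are illustrative rather than copied from a real checkpoint:

    from fast_llm_external_models.apriel2.conversion.qwen2 import convert_config

    qwen2_config = {
        "hidden_size": 896,
        "num_attention_heads": 14,
        "num_key_value_heads": 2,
        "num_hidden_layers": 24,
        "intermediate_size": 4864,
        "vocab_size": 151936,
        "tie_word_embeddings": True,
        "rope_theta": 1000000.0,
        "rms_norm_eps": 1e-6,
    }
    apriel2 = convert_config(qwen2_config)
    assert apriel2["decoder"]["block"]["mixer"]["head_size"] == 896 // 14  # 64
    assert apriel2["decoder"]["block"]["mixer"]["query_layer"]["bias"]["enabled"] is True
    assert apriel2["decoder"]["block"]["mixer"]["dense_layer"]["bias"]["enabled"] is False
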
diff --git a/fast_llm_external_models/apriel2/conversion/qwen2/plan.py b/fast_llm_external_models/apriel2/conversion/qwen2/plan.py
new file mode 100644
index 000000000..c1ec4af8b
--- /dev/null
+++ b/fast_llm_external_models/apriel2/conversion/qwen2/plan.py
@@ -0,0 +1,100 @@
+"""Qwen2/Qwen2.5 to Apriel2 weight conversion plan."""
+
+from fast_llm_external_models.apriel2.conversion.expr import Expr, ExprPlan, Ref, W
+
+
+def plan_qwen2_to_apriel2(qwen2_config: dict) -> ExprPlan:
+ """Build an expression plan for Qwen2/Qwen2.5 to Apriel2 conversion.
+
+    This is a pure mapping (all Ref expressions) since Qwen2→Apriel2
+ is just renaming keys. The weight tensors are identical.
+
+ Key mapping (source keys have "model." prefix in safetensors):
+ Qwen2 (safetensor key) Apriel2
+ ---------------------- -------
+ model.embed_tokens.weight -> model.embed_tokens.weight
+ model.norm.weight -> model.norm.weight
+ model.layers.{i}.input_layernorm.weight -> model.decoder.blocks.{i}.input_layernorm.weight
+ model.layers.{i}.post_attention_layernorm.weight -> model.decoder.blocks.{i}.post_attention_layernorm.weight
+ model.layers.{i}.self_attn.q_proj.weight -> model.decoder.blocks.{i}.mixer.q_proj.weight
+ model.layers.{i}.self_attn.q_proj.bias -> model.decoder.blocks.{i}.mixer.q_proj.bias
+ model.layers.{i}.self_attn.k_proj.weight -> model.decoder.blocks.{i}.mixer.k_proj.weight
+ model.layers.{i}.self_attn.k_proj.bias -> model.decoder.blocks.{i}.mixer.k_proj.bias
+ model.layers.{i}.self_attn.v_proj.weight -> model.decoder.blocks.{i}.mixer.v_proj.weight
+ model.layers.{i}.self_attn.v_proj.bias -> model.decoder.blocks.{i}.mixer.v_proj.bias
+ model.layers.{i}.self_attn.o_proj.weight -> model.decoder.blocks.{i}.mixer.o_proj.weight
+ model.layers.{i}.mlp.gate_proj.weight -> model.decoder.blocks.{i}.mlp.gate_proj.weight
+ model.layers.{i}.mlp.up_proj.weight -> model.decoder.blocks.{i}.mlp.up_proj.weight
+ model.layers.{i}.mlp.down_proj.weight -> model.decoder.blocks.{i}.mlp.down_proj.weight
+
+ Note: Qwen2 has QKV biases but no O bias. The Apriel2 config uses per-layer
+ bias settings (query_layer.bias.enabled=True, dense_layer.bias.enabled=False)
+ to match this exactly - no workarounds needed.
+
+ Args:
+ qwen2_config: HuggingFace Qwen2Config as dict
+
+ Returns:
+ ExprPlan with Ref mappings
+ """
+ mappings: dict[str, Expr] = {}
+
+ num_layers = qwen2_config["num_hidden_layers"]
+
+ # Static mappings (embeddings and final norm)
+ # Note: Qwen2 safetensor keys have "model." prefix
+ static_mappings = [
+ (W("model", "embed_tokens", "weight"), W("model", "embed_tokens", "weight")),
+ (W("model", "norm", "weight"), W("model", "norm", "weight")),
+ ]
+
+ # lm_head - only if not tied
+ if not qwen2_config.get("tie_word_embeddings", False):
+ static_mappings.append((W("lm_head", "weight"), W("lm_head", "weight")))
+
+ for src, tgt in static_mappings:
+ mappings[tgt] = Ref(key=src)
+
+ # Layer mappings
+ for layer in range(num_layers):
+ # Source has "model.layers.{i}" prefix
+ qwen_layer = W("model", "layers", layer)
+ apriel_layer = W("model", "decoder", "blocks", layer)
+
+ # Attention projection weights
+ for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]:
+ src = qwen_layer / "self_attn" / proj / "weight"
+ tgt = apriel_layer / "mixer" / proj / "weight"
+ mappings[tgt] = Ref(key=src)
+
+ # QKV biases (Qwen2 has these, but not O bias)
+ for proj in ["q_proj", "k_proj", "v_proj"]:
+ src = qwen_layer / "self_attn" / proj / "bias"
+ tgt = apriel_layer / "mixer" / proj / "bias"
+ mappings[tgt] = Ref(key=src)
+
+ # Note: o_proj has no bias in Qwen2, and Apriel2 config has dense_layer.bias.enabled=False
+
+ # MLP projections
+ for proj in ["gate_proj", "up_proj", "down_proj"]:
+ src = qwen_layer / "mlp" / proj / "weight"
+ tgt = apriel_layer / "mlp" / proj / "weight"
+ mappings[tgt] = Ref(key=src)
+
+ # Layer norms
+ mappings[apriel_layer / "input_layernorm" / "weight"] = Ref(key=qwen_layer / "input_layernorm" / "weight")
+ mappings[apriel_layer / "post_attention_layernorm" / "weight"] = Ref(
+ key=qwen_layer / "post_attention_layernorm" / "weight"
+ )
+
+ return ExprPlan(
+ mappings=mappings,
+ source_format="qwen2",
+ target_format="apriel2",
+ metadata={
+ "num_layers": num_layers,
+ "hidden_size": qwen2_config["hidden_size"],
+ "num_attention_heads": qwen2_config["num_attention_heads"],
+ "num_key_value_heads": qwen2_config.get("num_key_value_heads", qwen2_config["num_attention_heads"]),
+ },
+ )
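
A quick sanity check for the plan builder: per layer, the mapping table above yields 4 attention weights + 3 QKV biases + 3 MLP weights + 2 norms = 12 entries, plus embeddings, final norm, and (when untied) lm_head. A self-contained sketch with a tiny illustrative config:

    from fast_llm_external_models.apriel2.conversion.qwen2 import plan_qwen2_to_apriel2

    tiny_qwen2 = {
        "hidden_size": 64,
        "num_attention_heads": 4,
        "num_key_value_heads": 2,
        "num_hidden_layers": 2,
        "tie_word_embeddings": False,
    }
    plan = plan_qwen2_to_apriel2(tiny_qwen2)
    assert len(plan) == 12 * 2 + 3  # 24 per-layer refs + embed_tokens, norm, lm_head
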
diff --git a/fast_llm_external_models/apriel2/conversion/render.py b/fast_llm_external_models/apriel2/conversion/render.py
index d71fa03e1..f9a0c8ac1 100644
--- a/fast_llm_external_models/apriel2/conversion/render.py
+++ b/fast_llm_external_models/apriel2/conversion/render.py
@@ -8,17 +8,11 @@
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
+from fast_llm_external_models.apriel2.conversion.expr import Concat, Init, Ref, Reshape, Slice
+
if TYPE_CHECKING:
from fast_llm_external_models.apriel2.conversion.expr import Expr, ExprPlan
-from fast_llm_external_models.apriel2.conversion.expr import (
- Concat,
- Init,
- Ref,
- Reshape,
- Slice,
-)
-
@dataclass
class PlanTreeNode:
@@ -28,10 +22,10 @@ class PlanTreeNode:
After merging, leaf nodes contain aggregated values from multiple siblings.
"""
- children: dict[str, "PlanTreeNode"] = field(default_factory=dict)
+ children: dict[str, PlanTreeNode] = field(default_factory=dict)
# For leaf nodes: list of (sibling_key, expr) pairs
# Before merge: single item, after merge: multiple items from merged siblings
- values: list[tuple[str, "Expr"]] = field(default_factory=list)
+ values: list[tuple[str, Expr]] = field(default_factory=list)
def is_leaf(self) -> bool:
return len(self.children) == 0
@@ -61,7 +55,7 @@ def _build_plan_tree(plan: ExprPlan) -> PlanTreeNode:
return root
-def _expr_signature(expr: "Expr") -> tuple:
+def _expr_signature(expr: Expr) -> tuple:
"""Get a signature for an expression that determines merge compatibility.
Expressions with different signatures should not be merged together.
@@ -453,7 +447,7 @@ def _render_plan_tree(
)
-def _format_aggregated_leaf(values: list[tuple[str, "Expr"]]) -> str:
+def _format_aggregated_leaf(values: list[tuple[str, Expr]]) -> str:
"""Format a leaf with aggregated values using pattern discovery.
Args:
@@ -494,7 +488,7 @@ def _format_aggregated_leaf(values: list[tuple[str, "Expr"]]) -> str:
return _format_single_expr(first_expr)
-def _format_single_expr(expr: "Expr") -> str:
+def _format_single_expr(expr: Expr) -> str:
"""Format a single expression using ML notation."""
match expr:
case Ref(key=key):
@@ -531,7 +525,7 @@ def _format_single_expr(expr: "Expr") -> str:
return f"= {type(expr).__name__}"
-def _format_concat_part(expr: "Expr") -> str:
+def _format_concat_part(expr: Expr) -> str:
"""Format a single part of a concat (for short display)."""
match expr:
case Ref(key=key):
@@ -570,7 +564,7 @@ def _format_slice_notation(slices: tuple) -> str:
return f"[{', '.join(slice_strs)}]"
-def _format_aggregated_concat(values: list[tuple[str, "Expr"]]) -> str:
+def _format_aggregated_concat(values: list[tuple[str, Expr]]) -> str:
"""Format aggregated Concat expressions with pattern discovery."""
# Get the first concat to understand structure
first_concat = values[0][1]
@@ -590,7 +584,7 @@ def _format_aggregated_concat(values: list[tuple[str, "Expr"]]) -> str:
return f"= [{sep.join(formatted_parts)}]"
-def _format_aggregated_concat_part(values: list[tuple[str, "Expr"]]) -> str:
+def _format_aggregated_concat_part(values: list[tuple[str, Expr]]) -> str:
"""Format a single part of an aggregated concat."""
if len(values) == 1:
return _format_concat_part(values[0][1])
@@ -619,7 +613,7 @@ def _format_aggregated_concat_part(values: list[tuple[str, "Expr"]]) -> str:
return _format_concat_part(first_expr)
-def _format_aggregated_slice(values: list[tuple[str, "Expr"]]) -> str:
+def _format_aggregated_slice(values: list[tuple[str, Expr]]) -> str:
"""Format aggregated Slice expressions with pattern discovery."""
first_slice = values[0][1]
if not isinstance(first_slice, Slice):
diff --git a/fast_llm_external_models/apriel2/convert.py b/fast_llm_external_models/apriel2/convert.py
index cbf921b31..66c419dfd 100644
--- a/fast_llm_external_models/apriel2/convert.py
+++ b/fast_llm_external_models/apriel2/convert.py
@@ -15,6 +15,7 @@
Supported source formats:
- llava: Llava/Pixtral models
+- qwen2: Qwen2/Qwen2.5 models
- apriel2: Apriel2 models (surgery-only mode - no conversion, just apply surgeries)
"""
@@ -29,10 +30,7 @@
import yaml
from tqdm import tqdm
-# Allow running as script or module
-if __name__ == "__main__":
- sys.path.insert(0, str(Path(__file__).parent.parent.parent))
-
+# Import source-specific converters
from fast_llm_external_models.apriel2.conversion import (
DEFAULT_MAX_SHARD_SIZE,
ExprPlan,
@@ -41,11 +39,16 @@
StreamingExecutor,
compose,
compose_configs,
- plan_surgery,
)
-
-# Import source-specific converters
from fast_llm_external_models.apriel2.conversion import llava as llava_converter
+from fast_llm_external_models.apriel2.conversion import plan_surgery
+from fast_llm_external_models.apriel2.conversion import qwen2 as qwen2_converter
+from fast_llm_external_models.apriel2.conversion import strip_init_fields
+
+# Allow running as script or module
+if __name__ == "__main__":
+ sys.path.insert(0, str(Path(__file__).parent.parent.parent))
+
logger = logging.getLogger(__name__)
@@ -73,6 +76,7 @@ def _identity_plan(config: dict) -> ExprPlan:
# Each entry maps format name to (config_converter, plan_builder)
SOURCE_FORMATS: dict[str, tuple[Callable[[dict], dict], Callable[[dict], ExprPlan]]] = {
"llava": (llava_converter.convert_config, llava_converter.plan_llava_to_apriel2),
+ "qwen2": (qwen2_converter.convert_config, qwen2_converter.plan_qwen2_to_apriel2),
"apriel2": (_identity_config, _identity_plan),
}
@@ -88,8 +92,12 @@ def detect_source_format(config: dict) -> str | None:
if model_type in ("llava", "pixtral") or "text_config" in config:
return "llava"
+ # Qwen2/Qwen2.5 detection
+ if model_type == "qwen2":
+ return "qwen2"
+
# Apriel2 detection - check for Apriel2-specific structure
- if model_type == "apriel2" or "decoder" in config:
+ if model_type in ("apriel2", "apriel2_text") or "decoder" in config:
return "apriel2"
return None
@@ -142,15 +150,21 @@ def build_plan(
# Apply surgery chain if requested
if surgery_configs:
for i, surgery_config in enumerate(surgery_configs, 1):
- surgery_plan = plan_surgery(current_config, surgery_config)
- logger.info(f"Built surgery plan [{i}/{len(surgery_configs)}]: {surgery_plan.summary()['num_targets']} targets")
+            # S × P → T: compose state with surgery to get transition spec
+ target_config = compose_configs(current_config, surgery_config)
+
+            # S × T → Plan: build plan from source state and transition spec
+ surgery_plan = plan_surgery(current_config, target_config)
+ logger.info(
+ f"Built surgery plan [{i}/{len(surgery_configs)}]: {surgery_plan.summary()['num_targets']} targets"
+ )
- # Compose: current -> surgery
+ # Compose plans
current_plan = compose(current_plan, surgery_plan)
logger.info(f"Composed plan [{i}/{len(surgery_configs)}]: {current_plan.summary()['num_targets']} targets")
- # Compose configs: merge surgery spec into current config
- current_config = compose_configs(current_config, surgery_config)
+            # T → S: strip init for next iteration (init is consumed by plan_surgery)
+ current_config = strip_init_fields(target_config)
return current_plan, current_config
@@ -211,9 +225,7 @@ def convert(
executor = StreamingExecutor(full_plan, loader)
with ShardedSafetensorWriter(output_dir, max_shard_size=max_shard_size) as writer:
- for target_key, tensor in tqdm(
- executor.execute(seed), desc="Converting", total=len(full_plan)
- ):
+ for target_key, tensor in tqdm(executor.execute(seed), desc="Converting", total=len(full_plan)):
writer.add(target_key, tensor)
return final_config
@@ -282,9 +294,7 @@ def resolve_input(input_path: str) -> Path:
def main():
- parser = argparse.ArgumentParser(
- description="Convert HuggingFace checkpoint to Apriel2 HF format"
- )
+ parser = argparse.ArgumentParser(description="Convert HuggingFace checkpoint to Apriel2 HF format")
parser.add_argument(
"input",
type=str,
@@ -384,8 +394,7 @@ def main():
safetensor_files = sorted(input_dir.glob("*.safetensors"))
if not safetensor_files:
raise ValueError(
- f"No safetensor files found in {input_dir}. "
- "Plan-based conversion requires safetensor files."
+ f"No safetensor files found in {input_dir}. " "Plan-based conversion requires safetensor files."
)
# Convert using plan-based approach with streaming sharded output
@@ -400,11 +409,11 @@ def main():
show_plan=args.show_plan or args.verbose,
)
- # Save config
+ # Save config (build_plan returns S which has no init, but strip defensively)
output_config_file = args.output_dir / "config.json"
logger.info(f"Saving config to {output_config_file}")
with open(output_config_file, "w") as f:
- json.dump(apriel2_config, f, indent=2)
+ json.dump(strip_init_fields(apriel2_config), f, indent=2)
# Copy tokenizer files
copy_tokenizer_files(input_dir, args.output_dir)
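
For reference, a hedged sketch of how the format registry added above dispatches a Qwen2 checkpoint (the config dict is a minimal illustration; the example YAMLs below show the actual CLI entry point):

    from fast_llm_external_models.apriel2.convert import SOURCE_FORMATS, detect_source_format

    hf_config = {
        "model_type": "qwen2",
        "hidden_size": 64,
        "num_attention_heads": 4,
        "num_hidden_layers": 2,
        "intermediate_size": 128,
        "vocab_size": 1000,
    }
    source_format = detect_source_format(hf_config)  # "qwen2"
    config_converter, plan_builder = SOURCE_FORMATS[source_format]
    apriel2_config = config_converter(hf_config)
    plan = plan_builder(hf_config)
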
diff --git a/fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml b/fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml
new file mode 100644
index 000000000..34672916c
--- /dev/null
+++ b/fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml
@@ -0,0 +1,103 @@
+# Dataset preparation config for Tulu 3 SFT mixture with Qwen2 tokenizer
+#
+# This config converts the Tulu 3 SFT dataset (conversation format) to
+# Fast-LLM's memmap format, with automatic loss masking span computation
+# to train only on assistant responses.
+#
+# =============================================================================
+# TOKENIZER SETUP (one-time)
+# =============================================================================
+#
+# The tokenizer must have a chat template with {% generation %} markers.
+# Qwen2's default template doesn't have these, so we need to patch it.
+#
+# IMPORTANT: The entire assistant turn (opening tag + content + closing tag)
+# must be inside the {% generation %} block. This ensures the model learns to
+# produce the full assistant response including special tokens like <|im_end|>.
+# Reference: https://huggingface.co/HuggingFaceTB/SmolLM3-3B/blob/main/chat_template.jinja
+#
+# Run this Python script to create a patched tokenizer:
+#
+# from transformers import AutoTokenizer
+#
+# tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
+#
+# # Patch chat template: wrap ENTIRE assistant turn in generation markers
+# tokenizer.chat_template = '''{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system
+# You are a helpful assistant.<|im_end|>
+# ' }}{% endif %}{% if message['role'] == 'assistant' %}{% generation %}{{ '<|im_start|>assistant
+# ' + message['content'] + '<|im_end|>
+# ' }}{% endgeneration %}{% else %}{{ '<|im_start|>' + message['role'] + '
+# ' + message['content'] + '<|im_end|>
+# ' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant
+# ' }}{% endif %}'''
+#
+# tokenizer.save_pretrained("/path/to/qwen2-instruct-with-markers")
+#
+# =============================================================================
+# DATA PREPARATION
+# =============================================================================
+#
+# Small dataset (for testing):
+#
+# fast-llm prepare gpt_memmap \
+# -c fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml \
+# dataset.split=train[:1000] \
+# output_path=/path/to/tulu3-prepared-small
+#
+# Full dataset (~939K samples, ~6 minutes):
+#
+# fast-llm prepare gpt_memmap \
+# -c fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml
+#
+# =============================================================================
+# VERIFICATION
+# =============================================================================
+#
+# To verify the prepared dataset has loss masking spans:
+#
+# import pathlib
+# from fast_llm.data.dataset.memmap import MemmapDataset
+# from fast_llm.data.sample.language_model import LanguageModelSample
+# from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig
+#
+# dataset = MemmapDataset[LanguageModelSample](
+# 'tulu3',
+# pathlib.Path('/path/to/tulu3-prepared/shard_0_0.fast_llm_dataset'),
+# LanguageModelPreprocessingConfig(use_loss_masking_spans=True)
+# )
+#
+# doc = dataset.get_document(0)
+# print(f'Tokens: {len(doc.tokens.tokens)}')
+# print(f'Loss masking spans: {doc.loss_masking_spans.ranges}')
+#
+# =============================================================================
+
+# Dataset configuration
+dataset:
+ # Tulu 3 SFT mixture from AllenAI
+ path: allenai/tulu-3-sft-mixture
+ split: train
+
+ # Source schema for conversation format
+ source_schema:
+ # Use conversation type (vs default "document" type)
+ type: conversation
+
+ # Column containing the messages list
+ messages: messages
+
+# Tokenizer configuration
+# IMPORTANT: Must use a tokenizer with {% generation %} markers in its chat template.
+# See instructions above to create a patched Qwen2 tokenizer.
+tokenizer:
+ path: /path/to/qwen2-instruct-with-markers
+ # Qwen2 doesn't have a BOS token by default, use <|endoftext|> as BOS
+ bos_token: "<|endoftext|>"
+
+# Output configuration
+output_path: /path/to/tulu3-prepared
+
+# Processing configuration
+num_workers: 8
+documents_per_shard: 100000
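
As a follow-up check, the patched chat template can be inspected directly to confirm that only assistant turns are marked. A sketch, assuming a transformers version that supports return_assistant_tokens_mask:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("/path/to/qwen2-instruct-with-markers")
    out = tok.apply_chat_template(
        [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}],
        tokenize=True,
        return_dict=True,
        return_assistant_tokens_mask=True,
    )
    # 1s mark tokens inside {% generation %} blocks, i.e. the full assistant turn.
    print(out["assistant_masks"])
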
diff --git a/fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml b/fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml
new file mode 100644
index 000000000..aad168713
--- /dev/null
+++ b/fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml
@@ -0,0 +1,193 @@
+# Training config for Qwen2-based Apriel2 stochastic supernet on Tulu 3 SFT data
+#
+# This config trains a stochastic supernet where each layer can sample from
+# multiple mixer types (attention, sliding window, gated delta net, KDA).
+# Only the mixer weights are trained; all other weights are frozen.
+# Activation-level distillation from a teacher model guides the training.
+#
+# =============================================================================
+# PREREQUISITES
+# =============================================================================
+#
+# 1. TOKENIZER SETUP
+#
+# Qwen2's default chat template doesn't have generation markers needed for
+# loss masking. Create a patched tokenizer following the SmolLM3 pattern:
+# https://huggingface.co/HuggingFaceTB/SmolLM3-3B/blob/main/chat_template.jinja
+#
+# IMPORTANT: The ENTIRE assistant turn (opening tag + content + closing tag)
+# must be inside {% generation %}...{% endgeneration %} markers.
+#
+# from transformers import AutoTokenizer
+# tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
+# # Wrap entire assistant turn in generation markers (NOT just content!)
+# tokenizer.chat_template = '''{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system
+# You are a helpful assistant.<|im_end|>
+# ' }}{% endif %}{% if message['role'] == 'assistant' %}{% generation %}{{ '<|im_start|>assistant
+# ' + message['content'] + '<|im_end|>
+# ' }}{% endgeneration %}{% else %}{{ '<|im_start|>' + message['role'] + '
+# ' + message['content'] + '<|im_end|>
+# ' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant
+# ' }}{% endif %}'''
+# tokenizer.save_pretrained("/path/to/qwen2-instruct-with-markers")
+#
+# 2. PREPARE TULU 3 DATASET
+#
+# Small dataset (for testing):
+#
+# fast-llm prepare gpt_memmap \
+# -c fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml \
+# tokenizer.path=/path/to/qwen2-instruct-with-markers \
+# dataset.split=train[:1000] \
+# output_path=/path/to/tulu3-prepared-small
+#
+# Full dataset (~939K samples, ~6 minutes):
+#
+# fast-llm prepare gpt_memmap \
+# -c fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml \
+# tokenizer.path=/path/to/qwen2-instruct-with-markers \
+# output_path=/path/to/tulu3-prepared
+#
+# 3. CONVERT QWEN2 TO APRIEL2 SUPERNET (student model)
+#
+# This creates a stochastic supernet with multiple mixer types per layer:
+#
+# python fast_llm_external_models/apriel2/convert.py \
+# Qwen/Qwen2.5-0.5B-Instruct \
+# /path/to/qwen2-supernet \
+# --surgery fast_llm_external_models/apriel2/examples/stochastic_supernet.yaml
+#
+# 4. CONVERT QWEN2 TO APRIEL2 (teacher model)
+#
+# The teacher is the original model without surgery, used for distillation:
+#
+# python fast_llm_external_models/apriel2/convert.py \
+# Qwen/Qwen2.5-0.5B-Instruct \
+# /path/to/qwen2-teacher
+#
+# 5. RUN TRAINING
+#
+# Update paths below and run:
+#
+# fast-llm train gpt \
+# -c fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml
+#
+# For long runs, use nohup:
+#
+# nohup fast-llm train gpt \
+# -c fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml \
+# > training.log 2>&1 &
+# tail -f training.log
+#
+# =============================================================================
+# PERFORMANCE TUNING
+# =============================================================================
+#
+# Default config uses seq=2048, micro_batch=2, batch=64 (~131k tokens/iter).
+# Adjust settings based on your GPU memory:
+# - Reduce micro_batch_size or sequence_length if OOM
+# - Increase micro_batch_size or sequence_length if memory available
+#
+# =============================================================================
+# OUTPUT
+# =============================================================================
+#
+# Checkpoints: /path/to/qwen2-supernet-trained/checkpoints/{iteration}/
+# Exports: /path/to/qwen2-supernet-trained/export/apriel2_text/{iteration}/
+#
+# =============================================================================
+
+# Load pretrained model (Qwen2 converted to Apriel2 supernet)
+pretrained:
+ path: /path/to/qwen2-supernet
+ format: apriel2_text
+ model_weights: true
+ load_config: model
+
+# Model config
+model:
+ base_model:
+ # Freeze all components except the mixer
+ decoder:
+ block:
+ mlp:
+ lr_scale: 0.0 # Freeze MLP
+ normalization:
+ lr_scale: 0.0 # Freeze layer norms
+ distillation_model: teacher
+ activation_distillation_factor: 0.5
+ embeddings:
+ lr_scale: 0.0 # Freeze word embeddings
+ head:
+ lr_scale: 0.0 # Freeze output head
+ # cross_entropy_implementation: torch
+ distillation_model: teacher
+ distillation_loss_factor: 1.0
+ distillation_loss_implementation: reverse_kl
+ multi_stage:
+ zero_stage: 2
+ distributed:
+ compute_dtype: bf16
+ seed: 42
+
+# Teacher model for activation-level distillation
+reference_models:
+ teacher:
+ model:
+ type: gpt
+ pretrained:
+ path: /path/to/qwen2-teacher
+ format: apriel2_text
+ model_weights: true
+ load_config: model
+
+# Batch configuration
+batch:
+ sequence_length: 2048
+ micro_batch_size: 2
+ batch_size: 64
+ truncate_documents: false
+ use_loss_masking_spans: true
+
+# Data configuration (prepared Tulu 3 dataset)
+data:
+ datasets:
+ training:
+ type: file
+ path: /path/to/tulu3-prepared/fast_llm_config.yaml
+
+# Optimizer configuration
+optimizer:
+ learning_rate:
+ base: 3.0e-05
+ decay_style: cosine
+ warmup_iterations: 100
+ decay_iterations: 10000
+ minimum: 1.0e-06
+ weight_decay: 0.1
+ beta_1: 0.9
+ beta_2: 0.95
+
+# Training configuration
+# At seq=2048, batch=64: ~131k tokens/iter
+training:
+ train_iters: 10000
+ num_workers: 4
+ logs:
+ interval: 10
+ checkpoint:
+ interval: 100
+ export:
+ interval: 100
+ format: apriel2_text
+ test_iters: 0
+ evaluators: {}
+ # Weights & Biases configuration (optional, uncomment to enable)
+ # wandb:
+ # entity_name: your-entity
+ # project_name: your-project
+ # group_name: your-group
+
+# Experiment directory
+run:
+ experiment_dir: /path/to/qwen2-supernet-trained
diff --git a/fast_llm_external_models/apriel2/examples/train_supernet_small.yaml b/fast_llm_external_models/apriel2/examples/train_supernet_small.yaml
index 78c22e57f..be4d06e0a 100644
--- a/fast_llm_external_models/apriel2/examples/train_supernet_small.yaml
+++ b/fast_llm_external_models/apriel2/examples/train_supernet_small.yaml
@@ -107,7 +107,7 @@ model:
lr_scale: 0.0 # Freeze layer norms (norm_1 and norm_2 in each block)
# Activation-level distillation: teach mixers to mimic teacher's attention outputs
distillation_model: teacher
- activation_distillation_factor: 0.1
+ activation_distillation_factor: 0.8
embeddings:
lr_scale: 0.0 # Freeze word embeddings
head:
diff --git a/fast_llm_external_models/apriel2/modeling_apriel2.py b/fast_llm_external_models/apriel2/modeling_apriel2.py
index 4c263b4e2..240240cd6 100644
--- a/fast_llm_external_models/apriel2/modeling_apriel2.py
+++ b/fast_llm_external_models/apriel2/modeling_apriel2.py
@@ -24,8 +24,8 @@
is_torch_flex_attn_available,
)
-from fast_llm_external_models.apriel2.cache import Apriel2Cache
-from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config, Apriel2TextConfig
+from .cache import Apriel2Cache
+from .configuration_apriel2 import Apriel2Config, Apriel2TextConfig
# GDN implementation - matches Fast-LLM's gdn.py exactly
try:
@@ -395,14 +395,30 @@ def __init__(self, d_model: int, mixer_config: dict, layer_idx: int, config):
# cross_document_attention: if False, use cu_seqlens to isolate sequences (e.g., images)
self.cross_document_attention = mixer_config.get("cross_document_attention", True)
- # Whether to add biases to linear projections
- add_bias = mixer_config.get("add_linear_biases", False)
-
- # Projections (Fast-LLM weight names: q_proj, k_proj, v_proj, o_proj)
- self.q_proj = nn.Linear(d_model, self.num_heads * self.head_dim, bias=add_bias)
- self.k_proj = nn.Linear(d_model, self.num_key_value_heads * self.head_dim, bias=add_bias)
- self.v_proj = nn.Linear(d_model, self.num_key_value_heads * self.head_dim, bias=add_bias)
- self.o_proj = nn.Linear(self.num_heads * self.head_dim, d_model, bias=add_bias)
+ # Bias configuration mirroring Fast-LLM's structure:
+ # - add_linear_biases: bool (default for all projections)
+ # - query_layer: {"bias": {"enabled": bool}} (per-layer override)
+ # - key_layer: {"bias": {"enabled": bool}}
+ # - value_layer: {"bias": {"enabled": bool}}
+ # - dense_layer: {"bias": {"enabled": bool}}
+ default_bias = mixer_config.get("add_linear_biases", False)
+
+ def get_layer_bias(layer_name: str) -> bool:
+ layer_cfg = mixer_config.get(layer_name, {})
+ bias_cfg = layer_cfg.get("bias", {})
+ enabled = bias_cfg.get("enabled")
+ return default_bias if enabled is None else enabled
+
+ q_bias = get_layer_bias("query_layer")
+ k_bias = get_layer_bias("key_layer")
+ v_bias = get_layer_bias("value_layer")
+ o_bias = get_layer_bias("dense_layer")
+
+ # Projections
+ self.q_proj = nn.Linear(d_model, self.num_heads * self.head_dim, bias=q_bias)
+ self.k_proj = nn.Linear(d_model, self.num_key_value_heads * self.head_dim, bias=k_bias)
+ self.v_proj = nn.Linear(d_model, self.num_key_value_heads * self.head_dim, bias=v_bias)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, d_model, bias=o_bias)
@classmethod
def setup(
@@ -1017,6 +1033,8 @@ def torch_chunk_gated_delta_rule(
if not output_final_state:
last_recurrent_state = None
+ elif last_recurrent_state is not None:
+ last_recurrent_state = last_recurrent_state.to(initial_dtype)
core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1])
core_attn_out = core_attn_out[:, :, :sequence_length]
core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
@@ -1225,7 +1243,9 @@ def forward(self, hidden_states: torch.Tensor, past_key_values=None, attention_m
mixed_qkv = self.convolution.update(
mixed_qkv.squeeze(2), # [batch, conv_dim, 1] -> [batch, conv_dim]
conv_state,
- ).unsqueeze(2) # [batch, conv_dim] -> [batch, conv_dim, 1]
+ ).unsqueeze(
+ 2
+ ) # [batch, conv_dim] -> [batch, conv_dim, 1]
else:
# Prefill mode
use_cache = past_key_values is not None
@@ -1270,8 +1290,14 @@ def forward(self, hidden_states: torch.Tensor, past_key_values=None, attention_m
output_final_state=past_key_values is not None,
use_qk_l2norm_in_kernel=True,
)
+ # Ensure state is in same dtype as hidden_states (fla kernel may return float32)
+ if last_recurrent_state is not None:
+ last_recurrent_state = last_recurrent_state.to(hidden_states.dtype)
else:
# Recurrent mode for single token decode
+ # Convert recurrent_state to match hidden_states dtype if needed
+ if recurrent_state is not None and recurrent_state.dtype != hidden_states.dtype:
+ recurrent_state = recurrent_state.to(hidden_states.dtype)
output, last_recurrent_state = self._recurrent_gated_delta_rule(
query, key, value, g, beta_gate, recurrent_state
)
@@ -1294,7 +1320,16 @@ def forward(self, hidden_states: torch.Tensor, past_key_values=None, attention_m
return (output,)
def _recurrent_gated_delta_rule(self, query, key, value, g, beta, state):
- """Single-step recurrent update for cached inference."""
+ """Single-step recurrent update for cached inference.
+
+ Input shapes: [batch, seq=1, heads, dim]
+ Need shapes: [batch, heads, dim] for einsum operations
+ """
+ # Transpose from [batch, seq, heads, dim] to [batch, heads, seq, dim]
+ query = query.transpose(1, 2)
+ key = key.transpose(1, 2)
+ value = value.transpose(1, 2)
+
# L2 normalize query and key
query = _l2norm(query, dim=-1, eps=1e-6)
key = _l2norm(key, dim=-1, eps=1e-6)
@@ -1307,7 +1342,9 @@ def _recurrent_gated_delta_rule(self, query, key, value, g, beta, state):
beta = beta.squeeze(1)
# Update state: S = exp(g) * S + beta * k^T @ v
- decay = g.exp().unsqueeze(-1).unsqueeze(-1) # [batch, heads, 1, 1]
+ # Keep everything in the same dtype as input (exp() returns float32, need to convert back)
+ input_dtype = query.dtype
+ decay = g.exp().to(input_dtype).unsqueeze(-1).unsqueeze(-1) # [batch, heads, 1, 1]
k_outer_v = torch.einsum("bhk,bhv->bhkv", key * beta.unsqueeze(-1), value)
state = decay * state + k_outer_v
@@ -1315,6 +1352,12 @@ def _recurrent_gated_delta_rule(self, query, key, value, g, beta, state):
output = torch.einsum("bhk,bhkv->bhv", query, state)
output = output.unsqueeze(2) # [batch, heads, 1, v_dim]
+ # Transpose back to [batch, seq=1, heads, v_dim]
+ output = output.transpose(1, 2)
+
+ # Ensure state matches output dtype
+ state = state.to(output.dtype)
+
return output, state
@classmethod
@@ -1447,9 +1490,7 @@ def __init__(
# Normalization - use GatedRMSNormalization (same wrapper as GDN, with sigmoid activation)
self.norm = GatedRMSNormalization(self.head_dim, eps=self.norm_eps, activation=self.norm_activation)
- def _apply_conv(
- self, x: torch.Tensor, conv: CausalConv1d, conv_state: torch.Tensor | None, use_cache: bool
- ):
+ def _apply_conv(self, x: torch.Tensor, conv: CausalConv1d, conv_state: torch.Tensor | None, use_cache: bool):
"""
Apply causal convolution with cache support.
@@ -1828,16 +1869,36 @@ def __init__(
self.post_attention_layernorm = self._create_norm(norm_config, hidden_size, rms_norm_eps)
def _create_mlp(self, mlp_config: dict, hidden_size: int):
- """Create MLP based on config."""
+ """Create MLP based on config.
+
+ Supports per-layer bias configuration mirroring Fast-LLM:
+ - add_linear_biases: default bias setting for all layers
+ - layer_1.bias.enabled: override for up_proj/gate_proj
+ - layer_2.bias.enabled: override for down_proj
+ """
mlp_type = mlp_config.get("type", "mlp")
if mlp_type == "mlp":
intermediate_size = mlp_config["intermediate_size"]
activation = mlp_config.get("activation", "silu")
- gated = mlp_config["gated"]
- bias = mlp_config.get("add_linear_biases", False)
+ gated = mlp_config.get("gated", False)
+
+ # Per-layer bias configuration (mirrors Fast-LLM structure)
+ default_bias = mlp_config.get("add_linear_biases", False)
+
+ def get_layer_bias(layer_name: str) -> bool:
+ layer_cfg = mlp_config.get(layer_name, {})
+ bias_cfg = layer_cfg.get("bias", {})
+ enabled = bias_cfg.get("enabled")
+ return default_bias if enabled is None else enabled
+
+ layer_1_bias = get_layer_bias("layer_1")
+ layer_2_bias = get_layer_bias("layer_2")
if gated:
+ # MistralMLP uses gate_proj, up_proj, down_proj (all bias controlled together)
+ # For now, we use the default bias setting for gated MLPs
+ # TODO: Add per-layer bias support to gated MLP
mlp_cfg = SimpleNamespace(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
@@ -1845,7 +1906,13 @@ def _create_mlp(self, mlp_config: dict, hidden_size: int):
)
return MistralMLP(mlp_cfg)
else:
- return SimpleMLP(hidden_size, intermediate_size, activation, bias)
+ return SimpleMLP(
+ hidden_size,
+ intermediate_size,
+ activation,
+ layer_1_bias=layer_1_bias,
+ layer_2_bias=layer_2_bias,
+ )
else:
raise ValueError(f"Unknown MLP type: {mlp_type}")
@@ -2179,6 +2246,8 @@ def forward(
class Apriel2ForCausalLM(Apriel2PreTrainedModel, GenerationMixin):
"""Apriel2 model with a language modeling head (text-only)."""
+ _tied_weights_keys = ["lm_head.weight"]
+
def __init__(self, config: Apriel2TextConfig):
super().__init__(config)
self.model = Apriel2TextModel(config)
@@ -2186,6 +2255,7 @@ def __init__(self, config: Apriel2TextConfig):
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
+ # post_init() calls init_weights() which calls tie_weights() if config.tie_word_embeddings
self.post_init()
def get_input_embeddings(self):
@@ -2583,14 +2653,26 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
class SimpleMLP(nn.Module):
- """Non-gated MLP: up_proj -> activation -> down_proj."""
+ """Non-gated MLP: up_proj -> activation -> down_proj.
+
+ Supports per-layer bias configuration mirroring Fast-LLM:
+ - layer_1_bias: bias for up_proj (layer_1 in Fast-LLM naming)
+ - layer_2_bias: bias for down_proj (layer_2 in Fast-LLM naming)
+ """
- def __init__(self, hidden_size: int, intermediate_size: int, activation: str = "silu", bias: bool = False):
+ def __init__(
+ self,
+ hidden_size: int,
+ intermediate_size: int,
+ activation: str = "silu",
+ layer_1_bias: bool = False,
+ layer_2_bias: bool = False,
+ ):
super().__init__()
from transformers.activations import ACT2FN
- self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
- self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=bias)
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=layer_1_bias)
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=layer_2_bias)
self.act_fn = ACT2FN[activation]
def forward(self, x):
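
A small sketch checking the per-layer bias plumbing of SimpleMLP shown above (constructor arguments as in the hunk):

    from fast_llm_external_models.apriel2.modeling_apriel2 import SimpleMLP

    mlp = SimpleMLP(hidden_size=64, intermediate_size=256, activation="silu",
                    layer_1_bias=True, layer_2_bias=False)
    assert mlp.up_proj.bias is not None  # layer_1 bias enabled
    assert mlp.down_proj.bias is None    # layer_2 bias disabled
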
diff --git a/fast_llm_external_models/tests/test_apriel2/conftest.py b/fast_llm_external_models/tests/test_apriel2/conftest.py
index 8585aec65..21b90b097 100644
--- a/fast_llm_external_models/tests/test_apriel2/conftest.py
+++ b/fast_llm_external_models/tests/test_apriel2/conftest.py
@@ -1,23 +1,44 @@
"""Test fixtures for Apriel2 model tests."""
+from collections.abc import Generator
from pathlib import Path
-from typing import Generator
import pytest
import torch
from transformers import LlavaConfig, LlavaForConditionalGeneration, MistralConfig
+from fast_llm_external_models.apriel2.cache import _AttentionCache, _SSMCache
+
+
+# Register custom marks
+def pytest_configure(config):
+ config.addinivalue_line("markers", "slow: marks tests as slow (deselect with '-m \"not slow\"')")
+
+
+def _can_import_fast_llm():
+    """Check if Fast-LLM is available."""
+    try:
+        import fast_llm  # noqa: F401
+
+        return True
+    except ImportError:
+        return False
+
# Skip marker for tests that require CUDA for Mamba forward pass
requires_cuda = pytest.mark.skipif(
- not torch.cuda.is_available(),
- reason="SSM mixers (Mamba) require CUDA for forward pass"
+ not torch.cuda.is_available(), reason="SSM mixers (Mamba) require CUDA for forward pass"
)
+# Skip marker for tests that require Fast-LLM
+requires_fastllm = pytest.mark.skipif(not _can_import_fast_llm(), reason="Fast-LLM not available")
-@pytest.fixture(autouse=True)
+
+@pytest.fixture(scope="module", autouse=True)
def set_default_device():
- """Set default device to CUDA for all tests (Mamba requires CUDA)."""
+ """Set default device to CUDA for all tests (Mamba requires CUDA).
+
+ Module-scoped to ensure it runs before any module-scoped fixtures
+ that load models (e.g., qwen2_model_and_tokenizer).
+ """
if torch.cuda.is_available():
old_device = torch.get_default_device()
torch.set_default_device("cuda")
@@ -27,9 +48,12 @@ def set_default_device():
yield
-@pytest.fixture(autouse=True)
+@pytest.fixture(scope="module", autouse=True)
def set_default_dtype():
- """Set default dtype to float32 for numerical comparison tests."""
+ """Set default dtype to float32 for numerical comparison tests.
+
+ Module-scoped to ensure it runs before any module-scoped fixtures.
+ """
old_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.float32)
yield
@@ -135,14 +159,11 @@ def model_pair(request, small_pixtral_model, tmp_path):
tuple: (source_model, target_model, expected_atol, variant_name)
"""
import json
+
from safetensors import safe_open
from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
- from fast_llm_external_models.apriel2.conversion import (
- convert_llava_config,
- execute,
- plan_llava_to_apriel2,
- )
+ from fast_llm_external_models.apriel2.conversion import convert_llava_config, execute, plan_llava_to_apriel2
from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForConditionalGeneration
source = small_pixtral_model
@@ -638,12 +659,12 @@ def apriel2_config_comprehensive():
"type": "pattern",
"num_blocks": 6,
"pattern": [
- "attn", # 0: pure full attention
- "swa", # 1: pure sliding window attention
- "mamba", # 2: pure mamba
- "gdn", # 3: pure gated delta net
- "stoch_attn_mamba", # 4: stochastic attention + mamba
- "stoch_swa_gdn", # 5: stochastic swa + gated delta net
+ "attn", # 0: pure full attention
+ "swa", # 1: pure sliding window attention
+ "mamba", # 2: pure mamba
+ "gdn", # 3: pure gated delta net
+ "stoch_attn_mamba", # 4: stochastic attention + mamba
+ "stoch_swa_gdn", # 5: stochastic swa + gated delta net
],
"blocks": {
"attn": {
@@ -761,6 +782,52 @@ def apriel2_config_comprehensive():
)
+@pytest.fixture
+def apriel2_config_with_bias():
+ """Apriel2 config with Qwen-style per-layer bias and non-gated MLP.
+
+ This config exercises:
+ - Per-layer attention bias (QKV bias enabled, O bias disabled)
+ - Non-gated MLP with per-layer bias (layer_1 enabled, layer_2 disabled)
+ - Config structure parity with Fast-LLM's AffineLinearConfig
+
+ Critical for testing bias inheritance through surgery operations.
+ """
+ from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
+
+ return Apriel2Config(
+ vocab_size=100,
+ hidden_size=64,
+ decoder={
+ "type": "fixed",
+ "num_blocks": 2,
+ "block": {
+ "mixer": {
+ "type": "attention",
+ "heads": 4,
+ "head_groups": 2,
+ "head_size": 16,
+ "rotary": {"type": "mistral_1d", "theta": 10000.0},
+ # Qwen-style: QKV bias enabled, O bias disabled
+ "query_layer": {"bias": {"enabled": True}},
+ "key_layer": {"bias": {"enabled": True}},
+ "value_layer": {"bias": {"enabled": True}},
+ "dense_layer": {"bias": {"enabled": False}},
+ },
+ "mlp": {
+ "type": "mlp",
+ "intermediate_size": 256,
+ "gated": False, # Non-gated MLP (SimpleMLP)
+ # Per-layer MLP bias
+ "layer_1": {"bias": {"enabled": True}},
+ "layer_2": {"bias": {"enabled": False}},
+ },
+ "normalization": {"type": "rms_norm", "epsilon": 1e-5},
+ },
+ },
+ )
+
+
@pytest.fixture
def apriel2_cache(apriel2_config_tiny):
"""Create empty Apriel2Cache from tiny config."""
@@ -863,6 +930,77 @@ def additive_surgery_chain():
]
+@pytest.fixture
+def bias_surgery_chain():
+ """Surgery chain that exercises bias inheritance through surgery operations.
+
+ Designed to be used with apriel2_config_with_bias as the source config.
+ Tests that per-layer bias settings (Qwen-style QKV bias, non-gated MLP bias)
+ are correctly inherited through:
+ - Stochastic wrapper creation
+ - Adding new sub-mixers that inherit from source
+    - Cross-type derivation (attention → sliding_window)
+
+ Source config has:
+ - Attention: query/key/value bias enabled, dense bias disabled
+ - MLP: layer_1 bias enabled, layer_2 bias disabled (non-gated)
+ """
+ return [
+ # S1: Wrap in stochastic - bias should transfer to attention sub-mixer
+ {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "type": "stochastic",
+ "main_mixer_name": "attention",
+ "mixers": {
+ "attention": {"init": "transfer"},
+ },
+ },
+ "mlp": {"init": "transfer"},
+ "normalization": {"init": "transfer"},
+ },
+ },
+ },
+ # S2: Add sliding_window - should inherit bias from source attention
+ {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "mixers": {
+ "sliding_window": {
+ "type": "attention",
+ "init": "transfer",
+ "window_size": 512,
+ },
+ },
+ },
+ },
+ },
+ },
+ # S3: Add new attention with DIFFERENT bias config (random init)
+ {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "mixers": {
+ "full_bias_attn": {
+ "type": "attention",
+ "init": "random",
+ "heads": 4,
+ "head_groups": 2,
+ "head_size": 16,
+ "rotary": {"type": "mistral_1d", "theta": 10000.0},
+ "add_linear_biases": True, # All biases enabled
+ },
+ },
+ },
+ },
+ },
+ },
+ ]
+
+
@pytest.fixture
def comprehensive_torture_chain():
"""Comprehensive torture chain exercising ALL conversion paths.
@@ -885,7 +1023,7 @@ def comprehensive_torture_chain():
# MIL requires: d_inner <= Q rows (256), d_xb <= K/V rows (128)
mamba_params = {
"d_inner": 256, # Must be <= heads*head_size = 256
- "d_xb": 64, # Must be <= head_groups*head_size = 128
+ "d_xb": 64, # Must be <= head_groups*head_size = 128
"dt_rank": 16,
"d_state": 16,
"d_conv": 4,
@@ -1532,3 +1670,330 @@ def torture_surgery_chain():
},
},
]
+
+
+# =============================================================================
+# Shared Config Dict Fixtures (for compose_configs / plan_surgery tests)
+# =============================================================================
+
+
+@pytest.fixture
+def base_config_dict():
+ """Complete Apriel2 config dict without biases (Llama-style).
+
+ Use this as the base config for testing compose_configs and plan_surgery.
+ Returns a dict (not Apriel2Config) for direct use with compose_configs.
+ """
+ return {
+ "model_type": "apriel2",
+ "hidden_size": 256,
+ "vocab_size": 1000,
+ "bos_token_id": 1,
+ "eos_token_id": 2,
+ "tie_word_embeddings": False,
+ "decoder": {
+ "type": "fixed",
+ "num_blocks": 2,
+ "block": {
+ "mixer": {
+ "type": "attention",
+ "heads": 8,
+ "head_groups": 4,
+ "head_size": 32,
+ "rotary": {"type": "mistral_1d", "theta": 10000.0},
+ },
+ "mlp": {"type": "mlp", "intermediate_size": 512, "gated": True},
+ "normalization": {"type": "rms_norm", "epsilon": 1e-5},
+ },
+ },
+ }
+
+
+@pytest.fixture
+def base_config_with_bias_dict():
+ """Complete Apriel2 config dict with Qwen-style biases.
+
+ - QKV bias enabled, O bias disabled
+ - Gated MLP (no per-layer bias control in this style)
+
+ Use this for testing bias inheritance through surgery operations.
+ Returns a dict (not Apriel2Config) for direct use with compose_configs.
+ """
+ return {
+ "model_type": "apriel2",
+ "hidden_size": 256,
+ "vocab_size": 1000,
+ "bos_token_id": 1,
+ "eos_token_id": 2,
+ "tie_word_embeddings": False,
+ "decoder": {
+ "type": "fixed",
+ "num_blocks": 2,
+ "block": {
+ "mixer": {
+ "type": "attention",
+ "heads": 8,
+ "head_groups": 4,
+ "head_size": 32,
+ "rotary": {"type": "mistral_1d", "theta": 10000.0},
+ "query_layer": {"bias": {"enabled": True}},
+ "key_layer": {"bias": {"enabled": True}},
+ "value_layer": {"bias": {"enabled": True}},
+ "dense_layer": {"bias": {"enabled": False}},
+ },
+ "mlp": {"type": "mlp", "intermediate_size": 512, "gated": True},
+ "normalization": {"type": "rms_norm", "epsilon": 1e-5},
+ },
+ },
+ }
+
+
+def make_weights_for_config(config: dict) -> dict:
+ """Create random weights matching a config's expected schema.
+
+ This is a helper function (not a fixture) for creating test weights.
+ Use it in tests that need weights for plan execution.
+
+ Args:
+ config: Complete Apriel2 config dict
+
+ Returns:
+ Dict mapping weight key strings to torch tensors
+ """
+ from fast_llm_external_models.apriel2.conversion import W
+
+ hidden = config["hidden_size"]
+ vocab = config["vocab_size"]
+ decoder = config["decoder"]
+ num_blocks = decoder["num_blocks"]
+ block = decoder["block"]
+ mixer = block["mixer"]
+ mlp = block["mlp"]
+
+ heads = mixer["heads"]
+ head_groups = mixer["head_groups"]
+ head_size = mixer["head_size"]
+ inter = mlp["intermediate_size"]
+
+ # Check bias settings
+ has_q_bias = mixer.get("query_layer", {}).get("bias", {}).get("enabled", False)
+ has_k_bias = mixer.get("key_layer", {}).get("bias", {}).get("enabled", False)
+ has_v_bias = mixer.get("value_layer", {}).get("bias", {}).get("enabled", False)
+
+ weights = {}
+ weights["model.embed_tokens.weight"] = torch.randn(vocab, hidden)
+
+ for i in range(num_blocks):
+ p = f"model.decoder.blocks.{i}"
+
+ # Attention
+ weights[f"{p}.mixer.q_proj.weight"] = torch.randn(heads * head_size, hidden)
+ weights[f"{p}.mixer.k_proj.weight"] = torch.randn(head_groups * head_size, hidden)
+ weights[f"{p}.mixer.v_proj.weight"] = torch.randn(head_groups * head_size, hidden)
+ weights[f"{p}.mixer.o_proj.weight"] = torch.randn(hidden, heads * head_size)
+
+ if has_q_bias:
+ weights[f"{p}.mixer.q_proj.bias"] = torch.randn(heads * head_size)
+ if has_k_bias:
+ weights[f"{p}.mixer.k_proj.bias"] = torch.randn(head_groups * head_size)
+ if has_v_bias:
+ weights[f"{p}.mixer.v_proj.bias"] = torch.randn(head_groups * head_size)
+
+ # MLP
+ weights[f"{p}.mlp.up_proj.weight"] = torch.randn(inter, hidden)
+ weights[f"{p}.mlp.gate_proj.weight"] = torch.randn(inter, hidden)
+ weights[f"{p}.mlp.down_proj.weight"] = torch.randn(hidden, inter)
+
+ # Norms
+ weights[f"{p}.input_layernorm.weight"] = torch.randn(hidden)
+ weights[f"{p}.post_attention_layernorm.weight"] = torch.randn(hidden)
+
+ weights["model.norm.weight"] = torch.randn(hidden)
+ weights["lm_head.weight"] = torch.randn(vocab, hidden)
+
+ return {W(k): v for k, v in weights.items()}
+
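+# Illustrative sketch (comment only, not an actual test in this suite) of how
+# the shared config dicts pair with make_weights_for_config:
+#
+#     def test_weights_follow_bias_schema(base_config_with_bias_dict):
+#         weights = make_weights_for_config(base_config_with_bias_dict)
+#         # The Qwen-style schema enables q/k/v projection biases and disables the
+#         # dense bias, so each block gets q_proj/k_proj/v_proj bias entries and
+#         # no o_proj bias entry.
+#         assert len(weights) > 0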
+
+# =============================================================================
+# Cache Test Fixtures - Tensor Dimensions
+# =============================================================================
+
+
+@pytest.fixture
+def batch_size():
+ """Default batch size for cache tests."""
+ return 2
+
+
+@pytest.fixture
+def num_heads():
+ """Default number of attention heads for cache tests."""
+ return 4
+
+
+@pytest.fixture
+def head_dim():
+ """Default head dimension for cache tests."""
+ return 16
+
+
+@pytest.fixture
+def make_kv(batch_size, num_heads, head_dim):
+ """Factory fixture for creating KV tensors."""
+
+ def _make_kv(seq_len):
+ return (
+ torch.randn(batch_size, num_heads, seq_len, head_dim),
+ torch.randn(batch_size, num_heads, seq_len, head_dim),
+ )
+
+ return _make_kv
+
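+# Illustrative sketch (comment only) of the intended factory usage; it mirrors
+# how the cache tests call update() and get_seq_length():
+#
+#     def test_prefill_example(make_kv, attention_config):
+#         cache = Apriel2Cache(attention_config)  # Apriel2Cache import assumed
+#         k, v = make_kv(seq_len=10)
+#         cache.update(k, v, layer_idx=0)
+#         assert cache.get_seq_length(0) == 10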
+
+# =============================================================================
+# Cache Test Fixtures - HuggingFace Cache Layers
+# =============================================================================
+
+
+@pytest.fixture
+def hf_dynamic_layer():
+ """HuggingFace DynamicLayer for full attention contract testing."""
+ from transformers.cache_utils import DynamicLayer
+
+ return DynamicLayer()
+
+
+@pytest.fixture
+def hf_sliding_layer(window_size):
+ """HuggingFace DynamicSlidingWindowLayer for sliding window contract testing."""
+ from transformers.cache_utils import DynamicSlidingWindowLayer
+
+ return DynamicSlidingWindowLayer(sliding_window=window_size)
+
+
+# =============================================================================
+# Cache Test Fixtures - Apriel2 Low-level Caches
+# =============================================================================
+
+
+@pytest.fixture
+def apriel_attention_cache():
+ """Apriel2 attention cache without window (full attention)."""
+ return _AttentionCache(window=None)
+
+
+@pytest.fixture
+def apriel_sliding_cache(window_size):
+ """Apriel2 attention cache with sliding window."""
+ return _AttentionCache(window=window_size)
+
+
+@pytest.fixture
+def ssm_cache():
+ """Apriel2 SSM cache for Mamba/GDN/KDA layers."""
+ return _SSMCache()
+
+
+# =============================================================================
+# Cache Test Fixtures - Apriel2 Configs (Simple Versions)
+# =============================================================================
+
+
+@pytest.fixture
+def attention_config():
+ """Pure attention config (2 layers, no sliding window)."""
+ from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
+
+ return Apriel2Config(
+ vocab_size=100,
+ hidden_size=64,
+ decoder={
+ "type": "fixed",
+ "num_blocks": 2,
+ "block": {
+ "mixer": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16},
+ "mlp": {"type": "mlp", "intermediate_size": 256},
+ "normalization": {"type": "rms_norm", "epsilon": 1e-5},
+ },
+ },
+ )
+
+
+@pytest.fixture
+def swa_config():
+ """Sliding window attention config (2 layers, window=8)."""
+ from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
+
+ return Apriel2Config(
+ vocab_size=100,
+ hidden_size=64,
+ decoder={
+ "type": "fixed",
+ "num_blocks": 2,
+ "block": {
+ "mixer": {
+ "type": "attention",
+ "heads": 4,
+ "head_groups": 2,
+ "head_size": 16,
+ "window_size": 8,
+ },
+ "mlp": {"type": "mlp", "intermediate_size": 256},
+ "normalization": {"type": "rms_norm", "epsilon": 1e-5},
+ },
+ },
+ )
+
+
+@pytest.fixture
+def ssm_config():
+ """Pure SSM config (2 layers)."""
+ from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
+
+ return Apriel2Config(
+ vocab_size=100,
+ hidden_size=64,
+ decoder={
+ "type": "fixed",
+ "num_blocks": 2,
+ "block": {
+ "mixer": {"type": "mamba", "state_size": 16},
+ "mlp": {"type": "mlp", "intermediate_size": 256},
+ "normalization": {"type": "rms_norm", "epsilon": 1e-5},
+ },
+ },
+ )
+
+
+@pytest.fixture
+def stochastic_config():
+ """Stochastic mixer config with attention and mamba (2 layers)."""
+ from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
+
+ return Apriel2Config(
+ vocab_size=100,
+ hidden_size=64,
+ decoder={
+ "type": "fixed",
+ "num_blocks": 2,
+ "block": {
+ "mixer": {
+ "type": "stochastic",
+ "main_mixer_name": "attention",
+ "mixers": {
+ "attention": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16},
+ "mamba": {"type": "mamba", "state_size": 16},
+ },
+ },
+ "mlp": {"type": "mlp", "intermediate_size": 256},
+ "normalization": {"type": "rms_norm", "epsilon": 1e-5},
+ },
+ },
+ )
+
+
+# Parameterized window size fixture (used by hf_sliding_layer and apriel_sliding_cache)
+@pytest.fixture(params=[4, 8, 16, 32])
+def window_size(request):
+ """Parameterized window sizes for sliding window tests."""
+ return request.param
diff --git a/fast_llm_external_models/tests/test_apriel2/test_cache.py b/fast_llm_external_models/tests/test_apriel2/test_cache.py
deleted file mode 100644
index ca8158b4f..000000000
--- a/fast_llm_external_models/tests/test_apriel2/test_cache.py
+++ /dev/null
@@ -1,1258 +0,0 @@
-"""Comprehensive tests for Apriel2Cache.
-
-Architecture Overview
-=====================
-Apriel2Cache manages state for autoregressive generation across different mixer types:
-
-1. **Attention Cache** (_AttentionCache): Stores key/value states
- - Supports sliding window (window_size) for SWA
- - Efficient roll optimization for single-token decode
-
-2. **SSM Cache** (_SSMCache): Stores conv and recurrent states
- - Used by Mamba, GDN, KDA
- - KDA uses tuple conv states (q, k, v), others use single tensor
-
-3. **Stochastic Mixer Routing**: For layers with multiple mixer options
- - Each mixer has independent cache (no sharing)
- - active_mixer pointer routes operations to correct sub-cache
- - Switching mixers preserves each mixer's independent state
-
-Cache Invalidation Semantics
-============================
-When switching between mixers in a stochastic layer:
-- Each mixer maintains its OWN independent history
-- Switching does NOT invalidate the previous mixer's cache
-- Switching does NOT copy state between mixers
-- To invalidate: call reset() explicitly
-
-This is intentional for training with stochastic sampling where each mixer
-should learn from its own history. For inference, main_mixer_name is fixed.
-
-Test Organization
-=================
-1. CREATION & PROPERTIES - Cache initialization, config parsing
-2. ATTENTION CACHE - Updates, sliding window, concatenation
-3. SSM CACHE - Conv states, recurrent states, KDA tuples
-4. STOCHASTIC ROUTING - Active mixer, isolation, switching
-5. CACHE INVALIDATION - Reset, per-mixer reset, coherence
-6. BEAM SEARCH - batch_repeat, reorder, select
-7. HF INTEGRATION - get_mask_sizes, indexing, properties
-8. GENERATION PATTERNS - Prefill→decode, crop→continue
-9. ERROR HANDLING - Guards, bounds, invalid operations
-"""
-
-import pytest
-import torch
-
-from fast_llm_external_models.apriel2.cache import (
- Apriel2Cache,
- _AttentionCache,
- _SSMCache,
-)
-
-
-# =============================================================================
-# FIXTURES - Configs and Sample Data
-# =============================================================================
-
-
-@pytest.fixture
-def tiny_attention_config():
- """Minimal config with pure attention layers."""
- from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
-
- return Apriel2Config(
- vocab_size=100,
- hidden_size=64,
- decoder={
- "type": "fixed",
- "num_blocks": 2,
- "block": {
- "mixer": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16},
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- },
- )
-
-
-@pytest.fixture
-def swa_config():
- """Config with sliding window attention."""
- from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
-
- return Apriel2Config(
- vocab_size=100,
- hidden_size=64,
- decoder={
- "type": "fixed",
- "num_blocks": 2,
- "block": {
- "mixer": {
- "type": "attention",
- "heads": 4,
- "head_groups": 2,
- "head_size": 16,
- "window_size": 8, # Small for testing
- },
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- },
- )
-
-
-@pytest.fixture
-def ssm_config():
- """Config with pure SSM layers (mamba)."""
- from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
-
- return Apriel2Config(
- vocab_size=100,
- hidden_size=64,
- decoder={
- "type": "fixed",
- "num_blocks": 2,
- "block": {
- "mixer": {
- "type": "mamba",
- "d_inner": 128,
- "d_state": 16,
- "dt_rank": 4,
- "d_conv": 4,
- },
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- },
- )
-
-
-@pytest.fixture
-def kda_config():
- """Config with pure KDA layers."""
- from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
-
- return Apriel2Config(
- vocab_size=100,
- hidden_size=64,
- decoder={
- "type": "fixed",
- "num_blocks": 2,
- "block": {
- "mixer": {
- "type": "kda",
- "heads": 4,
- "head_dim": 16,
- "convolution_layer": {"kernel_size": 4},
- "normalization": {"epsilon": 1e-5},
- },
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- },
- )
-
-
-@pytest.fixture
-def stochastic_config():
- """Config with stochastic mixer (attention + mamba)."""
- from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
-
- return Apriel2Config(
- vocab_size=100,
- hidden_size=64,
- decoder={
- "type": "pattern",
- "num_blocks": 2,
- "pattern": ["attn", "stochastic"],
- "blocks": {
- "attn": {
- "mixer": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16},
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- "stochastic": {
- "mixer": {
- "type": "stochastic",
- "main_mixer_name": "attention",
- "mixers": {
- "attention": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16},
- "mamba": {"type": "mamba", "d_inner": 128, "d_state": 16, "dt_rank": 4, "d_conv": 4},
- },
- },
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- },
- },
- )
-
-
-@pytest.fixture
-def all_mixers_config():
- """Config with stochastic mixer containing all 5 mixer types."""
- from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
-
- return Apriel2Config(
- vocab_size=100,
- hidden_size=64,
- decoder={
- "type": "pattern",
- "num_blocks": 2,
- "pattern": ["attn", "all_mixers"],
- "blocks": {
- "attn": {
- "mixer": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16},
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- "all_mixers": {
- "mixer": {
- "type": "stochastic",
- "main_mixer_name": "attention",
- "mixers": {
- "attention": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16},
- "swa": {
- "type": "attention",
- "heads": 4,
- "head_groups": 2,
- "head_size": 16,
- "window_size": 1024,
- },
- "mamba": {"type": "mamba", "d_inner": 128, "d_state": 16, "dt_rank": 4, "d_conv": 4},
- "gdn": {
- "type": "gdn",
- "value_heads": 4,
- "key_heads": 2,
- "key_head_dim": 16,
- "value_head_dim": 16,
- "convolution_layer": {"kernel_size": 4},
- },
- "kda": {
- "type": "kda",
- "heads": 4,
- "head_dim": 16,
- "convolution_layer": {"kernel_size": 4},
- "normalization": {"epsilon": 1e-5},
- },
- },
- },
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- },
- },
- )
-
-
-@pytest.fixture
-def multi_window_config():
- """Config with multiple different window sizes."""
- from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
-
- return Apriel2Config(
- vocab_size=100,
- hidden_size=64,
- decoder={
- "type": "pattern",
- "num_blocks": 3,
- "pattern": ["full", "small_window", "large_window"],
- "blocks": {
- "full": {
- "mixer": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16},
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- "small_window": {
- "mixer": {
- "type": "attention",
- "heads": 4,
- "head_groups": 2,
- "head_size": 16,
- "window_size": 512,
- },
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- "large_window": {
- "mixer": {
- "type": "attention",
- "heads": 4,
- "head_groups": 2,
- "head_size": 16,
- "window_size": 2048,
- },
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- },
- },
- )
-
-
-@pytest.fixture
-def sample_kv():
- """Sample key/value tensors: [batch=2, heads=4, seq=10, head_dim=16]."""
- return torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16)
-
-
-@pytest.fixture
-def sample_conv_single():
- """Sample single-tensor conv state: [batch=2, d_inner=128, kernel=4]."""
- return torch.randn(2, 128, 4)
-
-
-@pytest.fixture
-def sample_conv_tuple():
- """Sample tuple conv state for KDA: (q, k, v) each [batch=2, d=64, kernel=3]."""
- return (torch.randn(2, 64, 3), torch.randn(2, 64, 3), torch.randn(2, 64, 3))
-
-
-@pytest.fixture
-def sample_recurrent():
- """Sample recurrent state: [batch=2, heads=4, head_dim=16, d_state=16]."""
- return torch.randn(2, 4, 16, 16)
-
-
-# =============================================================================
-# SECTION 1: CACHE CREATION & PROPERTIES
-# =============================================================================
-
-
-class TestCacheCreation:
- """Test cache initialization from config."""
-
- def test_attention_cache_creation(self, tiny_attention_config):
- """Create cache for pure attention config."""
- cache = Apriel2Cache(tiny_attention_config)
-
- assert len(cache) == 2
- assert cache.mixer_types == ["attention", "attention"]
- assert all(isinstance(l, _AttentionCache) for l in cache.layers)
-
- def test_ssm_cache_creation(self, ssm_config):
- """Create cache for pure SSM config."""
- cache = Apriel2Cache(ssm_config)
-
- assert len(cache) == 2
- assert cache.mixer_types == ["mamba", "mamba"]
- assert all(isinstance(l, _SSMCache) for l in cache.layers)
-
- def test_kda_cache_creation(self, kda_config):
- """Create cache for pure KDA config."""
- cache = Apriel2Cache(kda_config)
-
- assert len(cache) == 2
- assert cache.mixer_types == ["kda", "kda"]
- assert all(isinstance(l, _SSMCache) for l in cache.layers)
-
- def test_stochastic_cache_creation(self, stochastic_config):
- """Create cache for stochastic mixer config."""
- cache = Apriel2Cache(stochastic_config)
-
- assert len(cache) == 2
- # Layer 0: pure attention, Layer 1: stochastic (dict)
- assert isinstance(cache.layers[0], _AttentionCache)
- assert isinstance(cache.layers[1], dict)
- assert set(cache.layers[1].keys()) == {"attention", "mamba"}
-
- def test_swa_window_captured(self, swa_config):
- """Verify sliding window size is captured."""
- cache = Apriel2Cache(swa_config)
-
- assert cache.layers[0].window == 8
- assert cache.is_sliding == [True, True]
-
- def test_active_mixers_initialized_none(self, stochastic_config):
- """Verify active_mixers starts as None for all layers."""
- cache = Apriel2Cache(stochastic_config)
-
- assert cache.active_mixers == [None, None]
-
-
-class TestCacheProperties:
- """Test cache property accessors."""
-
- def test_empty_cache_properties(self, tiny_attention_config):
- """Test properties of uninitialized cache."""
- cache = Apriel2Cache(tiny_attention_config)
-
- assert cache.is_initialized == False
- assert cache.has_previous_state == False
- assert cache.max_batch_size is None
- assert cache.max_cache_len is None
- assert cache.is_compileable == False
-
- def test_is_initialized_attention(self, tiny_attention_config, sample_kv):
- """is_initialized detects attention cache."""
- cache = Apriel2Cache(tiny_attention_config)
- cache.update(*sample_kv, layer_idx=0)
-
- assert cache.is_initialized == True
-
- def test_is_initialized_ssm(self, ssm_config, sample_conv_single):
- """is_initialized detects SSM cache."""
- cache = Apriel2Cache(ssm_config)
- cache.conv_states[0] = sample_conv_single
-
- assert cache.is_initialized == True
-
- def test_has_previous_state_ssm_only(self, ssm_config, sample_conv_single):
- """has_previous_state only looks at SSM conv states."""
- cache = Apriel2Cache(ssm_config)
-
- assert cache.has_previous_state == False
- cache.conv_states[0] = sample_conv_single
- assert cache.has_previous_state == True
-
- def test_has_previous_state_ignores_attention(self, tiny_attention_config, sample_kv):
- """has_previous_state ignores attention cache."""
- cache = Apriel2Cache(tiny_attention_config)
- cache.update(*sample_kv, layer_idx=0)
-
- # Attention cache is set, but has_previous_state only checks SSM
- assert cache.has_previous_state == False
-
- def test_max_batch_size_from_attention(self, tiny_attention_config, sample_kv):
- """max_batch_size from attention cache."""
- cache = Apriel2Cache(tiny_attention_config)
- cache.update(*sample_kv, layer_idx=0)
-
- assert cache.max_batch_size == 2
-
- def test_max_batch_size_from_ssm(self, ssm_config, sample_conv_single):
- """max_batch_size from SSM cache."""
- cache = Apriel2Cache(ssm_config)
- cache.conv_states[0] = sample_conv_single
-
- assert cache.max_batch_size == 2
-
- def test_max_batch_size_from_kda_tuple(self, kda_config, sample_conv_tuple):
- """max_batch_size from KDA tuple conv state."""
- cache = Apriel2Cache(kda_config)
- cache.conv_states[0] = sample_conv_tuple
-
- assert cache.max_batch_size == 2
-
- def test_max_cache_len_single_window(self, swa_config):
- """max_cache_len with single window size."""
- cache = Apriel2Cache(swa_config)
- assert cache.max_cache_len == 8
-
- def test_max_cache_len_multiple_windows(self, multi_window_config):
- """max_cache_len returns minimum window."""
- cache = Apriel2Cache(multi_window_config)
- assert cache.max_cache_len == 512 # min(512, 2048)
-
- def test_max_cache_len_no_windows(self, tiny_attention_config):
- """max_cache_len is None when no windows."""
- cache = Apriel2Cache(tiny_attention_config)
- assert cache.max_cache_len is None
-
- def test_is_sliding_mixed(self, multi_window_config):
- """is_sliding reflects per-layer window presence."""
- cache = Apriel2Cache(multi_window_config)
- assert cache.is_sliding == [False, True, True]
-
-
-# =============================================================================
-# SECTION 2: ATTENTION CACHE OPERATIONS
-# =============================================================================
-
-
-class TestAttentionCacheBasics:
- """Test basic attention cache operations."""
-
- def test_update_stores_kv(self, tiny_attention_config, sample_kv):
- """update() stores key/value states."""
- cache = Apriel2Cache(tiny_attention_config)
- key, value = sample_kv
-
- k_out, v_out = cache.update(key, value, layer_idx=0)
-
- torch.testing.assert_close(k_out, key)
- torch.testing.assert_close(v_out, value)
- assert cache.get_seq_length(0) == 10
-
- def test_update_concatenates(self, tiny_attention_config, sample_kv):
- """Subsequent updates concatenate."""
- cache = Apriel2Cache(tiny_attention_config)
- key, value = sample_kv
-
- cache.update(key, value, layer_idx=0)
- k_out, v_out = cache.update(key, value, layer_idx=0)
-
- assert k_out.shape[-2] == 20
- assert cache.get_seq_length(0) == 20
-
- def test_key_value_cache_accessors(self, tiny_attention_config, sample_kv):
- """Test key_cache and value_cache accessors."""
- cache = Apriel2Cache(tiny_attention_config)
- cache.update(*sample_kv, layer_idx=0)
-
- assert cache.key_cache[0] is not None
- assert cache.value_cache[0] is not None
- torch.testing.assert_close(cache.key_cache[0], sample_kv[0])
-
-
-class TestSlidingWindowAttention:
- """Test sliding window attention behavior."""
-
- def test_initial_within_window(self, swa_config):
- """Initial sequence within window is kept."""
- cache = Apriel2Cache(swa_config)
- key = torch.randn(2, 4, 5, 16) # seq=5 < window=8
- value = torch.randn(2, 4, 5, 16)
-
- cache.update(key, value, layer_idx=0)
-
- assert cache.get_seq_length(0) == 5
-
- def test_initial_exceeds_window(self, swa_config):
- """Initial sequence > window is truncated to last window tokens."""
- cache = Apriel2Cache(swa_config)
- key = torch.arange(12).float().view(1, 1, 12, 1).expand(2, 4, 12, 16)
- value = key.clone()
-
- k_out, v_out = cache.update(key, value, layer_idx=0)
-
- assert cache.get_seq_length(0) == 8
- # Should keep tokens 4-11 (last 8)
- assert k_out[0, 0, 0, 0].item() == 4.0
-
- def test_single_token_roll_path(self, swa_config):
- """Single token decode with full window uses efficient roll."""
- cache = Apriel2Cache(swa_config)
-
- # Fill window exactly
- key1 = torch.arange(8).float().view(1, 1, 8, 1).expand(2, 4, 8, 16)
- cache.update(key1, key1.clone(), layer_idx=0)
-
- # Decode single token
- key2 = torch.full((2, 4, 1, 16), 8.0)
- k_out, _ = cache.update(key2, key2.clone(), layer_idx=0)
-
- assert cache.get_seq_length(0) == 8
- assert k_out[0, 0, 0, 0].item() == 1.0 # Token 0 rolled out
- assert k_out[0, 0, 7, 0].item() == 8.0 # New token at end
-
- def test_multi_token_cat_slice_path(self, swa_config):
- """Multiple tokens use cat+slice path."""
- cache = Apriel2Cache(swa_config)
-
- # Fill window
- key1 = torch.randn(2, 4, 8, 16)
- cache.update(key1, key1.clone(), layer_idx=0)
-
- # Add 3 tokens
- key2 = torch.randn(2, 4, 3, 16)
- k_out, _ = cache.update(key2, key2.clone(), layer_idx=0)
-
- assert cache.get_seq_length(0) == 8
- torch.testing.assert_close(k_out[..., -3:, :], key2)
-
- def test_partial_then_fill_then_overflow(self, swa_config):
- """Progressive filling: partial โ full โ overflow."""
- cache = Apriel2Cache(swa_config)
-
- cache.update(torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16), layer_idx=0)
- assert cache.get_seq_length(0) == 5
-
- cache.update(torch.randn(2, 4, 3, 16), torch.randn(2, 4, 3, 16), layer_idx=0)
- assert cache.get_seq_length(0) == 8
-
- cache.update(torch.randn(2, 4, 2, 16), torch.randn(2, 4, 2, 16), layer_idx=0)
- assert cache.get_seq_length(0) == 8
-
- def test_contiguous_output(self, swa_config):
- """Outputs are contiguous after windowing."""
- cache = Apriel2Cache(swa_config)
-
- cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0)
- cache.update(torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16), layer_idx=0)
-
- assert cache.layers[0].key.is_contiguous()
- assert cache.layers[0].value.is_contiguous()
-
-
-# =============================================================================
-# SECTION 3: SSM CACHE OPERATIONS
-# =============================================================================
-
-
-class TestSSMCacheBasics:
- """Test basic SSM cache operations."""
-
- def test_conv_states_accessor(self, ssm_config, sample_conv_single):
- """Test conv_states accessor."""
- cache = Apriel2Cache(ssm_config)
-
- cache.conv_states[0] = sample_conv_single
- torch.testing.assert_close(cache.conv_states[0], sample_conv_single)
-
- def test_recurrent_states_accessor(self, ssm_config, sample_recurrent):
- """Test recurrent_states accessor."""
- cache = Apriel2Cache(ssm_config)
-
- cache.recurrent_states[0] = sample_recurrent
- torch.testing.assert_close(cache.recurrent_states[0], sample_recurrent)
-
- def test_ssm_seq_length_always_zero(self, ssm_config, sample_conv_single):
- """get_seq_length returns 0 for SSM (no KV cache)."""
- cache = Apriel2Cache(ssm_config)
- cache.conv_states[0] = sample_conv_single
-
- assert cache.get_seq_length(0) == 0
-
-
-class TestKDACache:
- """Test KDA-specific cache operations with tuple conv states."""
-
- def test_tuple_conv_storage(self, kda_config, sample_conv_tuple):
- """KDA stores tuple conv states."""
- cache = Apriel2Cache(kda_config)
-
- cache.conv_states[0] = sample_conv_tuple
-
- assert isinstance(cache.conv_states[0], tuple)
- assert len(cache.conv_states[0]) == 3
- for i in range(3):
- torch.testing.assert_close(cache.conv_states[0][i], sample_conv_tuple[i])
-
- def test_tuple_with_recurrent(self, kda_config, sample_conv_tuple, sample_recurrent):
- """KDA can have both tuple conv and recurrent states."""
- cache = Apriel2Cache(kda_config)
-
- cache.conv_states[0] = sample_conv_tuple
- cache.recurrent_states[0] = sample_recurrent
-
- assert isinstance(cache.conv_states[0], tuple)
- assert cache.recurrent_states[0] is not None
-
- def test_has_previous_state_detects_tuple(self, kda_config, sample_conv_tuple):
- """has_previous_state works with tuple conv states."""
- cache = Apriel2Cache(kda_config)
-
- assert cache.has_previous_state == False
- cache.conv_states[0] = sample_conv_tuple
- assert cache.has_previous_state == True
-
-
-# =============================================================================
-# SECTION 4: STOCHASTIC ROUTING
-# =============================================================================
-
-
-class TestStochasticRouting:
- """Test stochastic mixer cache routing."""
-
- def test_set_active_mixer(self, stochastic_config):
- """set_active_mixer sets the pointer."""
- cache = Apriel2Cache(stochastic_config)
-
- cache.set_active_mixer(1, "attention")
- assert cache.active_mixers[1] == "attention"
-
- cache.set_active_mixer(1, "mamba")
- assert cache.active_mixers[1] == "mamba"
-
- def test_operations_route_to_active(self, stochastic_config, sample_kv):
- """Operations route to currently active mixer."""
- cache = Apriel2Cache(stochastic_config)
-
- cache.set_active_mixer(1, "attention")
- cache.update(*sample_kv, layer_idx=1)
- attn_len = cache.get_seq_length(1)
-
- cache.set_active_mixer(1, "mamba")
- mamba_len = cache.get_seq_length(1)
-
- assert attn_len == 10
- assert mamba_len == 0 # Mamba cache is separate and empty
-
- def test_each_mixer_independent_cache(self, stochastic_config, sample_kv, sample_conv_single):
- """Each mixer maintains independent cache."""
- cache = Apriel2Cache(stochastic_config)
-
- # Fill attention cache
- cache.set_active_mixer(1, "attention")
- cache.update(*sample_kv, layer_idx=1)
-
- # Fill mamba cache
- cache.set_active_mixer(1, "mamba")
- cache.conv_states[1] = sample_conv_single
-
- # Both preserved
- cache.set_active_mixer(1, "attention")
- assert cache.get_seq_length(1) == 10
-
- cache.set_active_mixer(1, "mamba")
- torch.testing.assert_close(cache.conv_states[1], sample_conv_single)
-
-
-class TestMixerSwitching:
- """Test behavior when switching between mixers mid-generation."""
-
- def test_switch_preserves_previous_state(self, stochastic_config, sample_kv):
- """Switching mixers preserves previous mixer's state."""
- cache = Apriel2Cache(stochastic_config)
-
- cache.set_active_mixer(1, "attention")
- cache.update(*sample_kv, layer_idx=1)
- original_key = cache.layers[1]["attention"].key.clone()
-
- # Switch to mamba, do something
- cache.set_active_mixer(1, "mamba")
- cache.conv_states[1] = torch.randn(2, 128, 4)
-
- # Switch back - attention unchanged
- cache.set_active_mixer(1, "attention")
- torch.testing.assert_close(cache.layers[1]["attention"].key, original_key)
-
- def test_switch_does_not_copy_state(self, stochastic_config, sample_kv):
- """Switching does NOT copy state between mixers."""
- cache = Apriel2Cache(stochastic_config)
-
- # Fill attention with 10 tokens
- cache.set_active_mixer(1, "attention")
- cache.update(*sample_kv, layer_idx=1)
-
- # Switch to mamba - it has NO history from attention
- cache.set_active_mixer(1, "mamba")
- assert cache.conv_states[1] is None
- assert cache.recurrent_states[1] is None
-
- def test_has_previous_state_checks_all_sub_caches(self, stochastic_config):
- """has_previous_state checks ALL sub-caches, not just active."""
- cache = Apriel2Cache(stochastic_config)
-
- cache.set_active_mixer(1, "mamba")
- cache.conv_states[1] = torch.randn(2, 128, 4)
-
- # Even if we switch away, has_previous_state still detects it
- cache.set_active_mixer(1, "attention")
- assert cache.has_previous_state == True
-
-
-class TestAllMixerTypes:
- """Test cache isolation across all 5 mixer types."""
-
- def test_all_five_mixer_types_isolated(self, all_mixers_config):
- """All 5 mixer types maintain isolated caches."""
- cache = Apriel2Cache(all_mixers_config)
- layer_idx = 1 # Stochastic layer
-
- # Fill each mixer's cache
- cache.set_active_mixer(layer_idx, "attention")
- attn_kv = (torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16))
- cache.update(*attn_kv, layer_idx=layer_idx)
-
- cache.set_active_mixer(layer_idx, "swa")
- swa_kv = (torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16))
- cache.update(*swa_kv, layer_idx=layer_idx)
-
- cache.set_active_mixer(layer_idx, "mamba")
- mamba_conv = torch.randn(2, 128, 4)
- cache.conv_states[layer_idx] = mamba_conv
-
- cache.set_active_mixer(layer_idx, "gdn")
- gdn_conv = torch.randn(2, 64, 3)
- cache.conv_states[layer_idx] = gdn_conv
-
- cache.set_active_mixer(layer_idx, "kda")
- kda_conv = (torch.randn(2, 64, 3), torch.randn(2, 64, 3), torch.randn(2, 64, 3))
- cache.conv_states[layer_idx] = kda_conv
-
- # Verify all preserved
- cache.set_active_mixer(layer_idx, "attention")
- assert cache.get_seq_length(layer_idx) == 10
-
- cache.set_active_mixer(layer_idx, "swa")
- assert cache.get_seq_length(layer_idx) == 5
-
- cache.set_active_mixer(layer_idx, "mamba")
- torch.testing.assert_close(cache.conv_states[layer_idx], mamba_conv)
-
- cache.set_active_mixer(layer_idx, "gdn")
- torch.testing.assert_close(cache.conv_states[layer_idx], gdn_conv)
-
- cache.set_active_mixer(layer_idx, "kda")
- assert isinstance(cache.conv_states[layer_idx], tuple)
-
-
-# =============================================================================
-# SECTION 5: CACHE INVALIDATION
-# =============================================================================
-
-
-class TestCacheInvalidation:
- """Test cache invalidation and reset semantics.
-
- Key principle: Each mixer maintains independent state. To invalidate:
- - reset() clears ALL caches across ALL layers and mixers
- - There is no per-mixer reset (by design - each mixer is independent)
- """
-
- def test_reset_clears_attention(self, tiny_attention_config, sample_kv):
- """reset() clears attention cache."""
- cache = Apriel2Cache(tiny_attention_config)
- cache.update(*sample_kv, layer_idx=0)
-
- cache.reset()
-
- assert cache.is_initialized == False
- assert cache.get_seq_length(0) == 0
-
- def test_reset_clears_ssm(self, ssm_config, sample_conv_single, sample_recurrent):
- """reset() clears SSM cache."""
- cache = Apriel2Cache(ssm_config)
- cache.conv_states[0] = sample_conv_single
- cache.recurrent_states[0] = sample_recurrent
-
- cache.reset()
-
- assert cache.has_previous_state == False
- assert cache.conv_states[0] is None
- assert cache.recurrent_states[0] is None
-
- def test_reset_clears_kda_tuple(self, kda_config, sample_conv_tuple):
- """reset() clears KDA tuple conv states."""
- cache = Apriel2Cache(kda_config)
- cache.conv_states[0] = sample_conv_tuple
-
- cache.reset()
-
- assert cache.conv_states[0] is None
-
- def test_reset_clears_all_stochastic_mixers(self, all_mixers_config):
- """reset() clears ALL mixer caches in stochastic layer."""
- cache = Apriel2Cache(all_mixers_config)
- layer_idx = 1
-
- # Fill all mixers
- cache.set_active_mixer(layer_idx, "attention")
- cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=layer_idx)
-
- cache.set_active_mixer(layer_idx, "mamba")
- cache.conv_states[layer_idx] = torch.randn(2, 128, 4)
-
- cache.set_active_mixer(layer_idx, "kda")
- cache.conv_states[layer_idx] = (torch.randn(2, 64, 3),) * 3
-
- cache.reset()
-
- # All cleared
- assert cache.layers[layer_idx]["attention"].key is None
- assert cache.layers[layer_idx]["mamba"].conv is None
- assert cache.layers[layer_idx]["kda"].conv is None
-
- def test_crop_truncates_attention(self, tiny_attention_config, sample_kv):
- """crop() truncates attention cache to max_length."""
- cache = Apriel2Cache(tiny_attention_config)
- cache.update(*sample_kv, layer_idx=0)
-
- cache.crop(5)
-
- assert cache.get_seq_length(0) == 5
-
- def test_crop_affects_all_layers(self, tiny_attention_config, sample_kv):
- """crop() affects all layers."""
- cache = Apriel2Cache(tiny_attention_config)
- cache.update(*sample_kv, layer_idx=0)
- cache.update(*sample_kv, layer_idx=1)
-
- cache.crop(3)
-
- assert cache.get_seq_length(0) == 3
- assert cache.get_seq_length(1) == 3
-
- def test_crop_ignores_ssm(self, ssm_config, sample_conv_single):
- """crop() only affects attention, not SSM."""
- cache = Apriel2Cache(ssm_config)
- cache.conv_states[0] = sample_conv_single
-
- cache.crop(5) # Should not crash
-
- # Conv state unchanged
- torch.testing.assert_close(cache.conv_states[0], sample_conv_single)
-
-
-# =============================================================================
-# SECTION 6: BEAM SEARCH OPERATIONS
-# =============================================================================
-
-
-class TestBatchRepeatInterleave:
- """Test batch_repeat_interleave for beam search expansion."""
-
- def test_repeat_attention(self, tiny_attention_config, sample_kv):
- """Repeat attention cache for beam search."""
- cache = Apriel2Cache(tiny_attention_config)
- cache.update(*sample_kv, layer_idx=0)
-
- cache.batch_repeat_interleave(3)
-
- assert cache.max_batch_size == 6 # 2 * 3
-
- def test_repeat_ssm(self, ssm_config, sample_conv_single, sample_recurrent):
- """Repeat SSM cache for beam search."""
- cache = Apriel2Cache(ssm_config)
- cache.conv_states[0] = sample_conv_single
- cache.recurrent_states[0] = sample_recurrent
-
- cache.batch_repeat_interleave(4)
-
- assert cache.conv_states[0].shape[0] == 8 # 2 * 4
- assert cache.recurrent_states[0].shape[0] == 8
-
- def test_repeat_kda_tuple(self, kda_config, sample_conv_tuple):
- """Repeat KDA tuple conv states."""
- cache = Apriel2Cache(kda_config)
- cache.conv_states[0] = sample_conv_tuple
-
- cache.batch_repeat_interleave(3)
-
- for c in cache.conv_states[0]:
- assert c.shape[0] == 6
-
- def test_repeat_stochastic_all_mixers(self, all_mixers_config):
- """Repeat all mixer caches in stochastic layer."""
- cache = Apriel2Cache(all_mixers_config)
- layer_idx = 1
-
- cache.set_active_mixer(layer_idx, "attention")
- cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=layer_idx)
-
- cache.set_active_mixer(layer_idx, "mamba")
- cache.conv_states[layer_idx] = torch.randn(2, 128, 4)
-
- cache.batch_repeat_interleave(2)
-
- cache.set_active_mixer(layer_idx, "attention")
- assert cache.layers[layer_idx]["attention"].key.shape[0] == 4
-
- cache.set_active_mixer(layer_idx, "mamba")
- assert cache.conv_states[layer_idx].shape[0] == 4
-
- def test_repeat_skips_none(self, tiny_attention_config):
- """Repeat gracefully skips None caches."""
- cache = Apriel2Cache(tiny_attention_config)
- # Don't fill anything
-
- cache.batch_repeat_interleave(3) # Should not crash
-
- assert cache.max_batch_size is None
-
-
-class TestReorderCache:
- """Test reorder_cache for beam search hypothesis selection."""
-
- def test_reorder_attention(self, tiny_attention_config, sample_kv):
- """Reorder attention cache."""
- cache = Apriel2Cache(tiny_attention_config)
- key, value = sample_kv
- # Make batches distinguishable
- key = torch.arange(2).float().view(2, 1, 1, 1).expand(2, 4, 10, 16)
- cache.update(key, key.clone(), layer_idx=0)
-
- beam_idx = torch.tensor([1, 0])
- cache.reorder_cache(beam_idx)
-
- assert cache.layers[0].key[0, 0, 0, 0].item() == 1.0
- assert cache.layers[0].key[1, 0, 0, 0].item() == 0.0
-
- def test_reorder_ssm(self, ssm_config):
- """Reorder SSM cache."""
- cache = Apriel2Cache(ssm_config)
- conv = torch.arange(2).float().view(2, 1, 1).expand(2, 128, 4)
- cache.conv_states[0] = conv.clone()
-
- beam_idx = torch.tensor([1, 0])
- cache.reorder_cache(beam_idx)
-
- assert cache.conv_states[0][0, 0, 0].item() == 1.0
-
- def test_reorder_kda_tuple(self, kda_config):
- """Reorder KDA tuple conv states."""
- cache = Apriel2Cache(kda_config)
- conv_q = torch.arange(2).float().view(2, 1, 1).expand(2, 64, 3)
- cache.conv_states[0] = (conv_q.clone(), conv_q.clone(), conv_q.clone())
-
- beam_idx = torch.tensor([1, 0])
- cache.reorder_cache(beam_idx)
-
- for c in cache.conv_states[0]:
- assert c[0, 0, 0].item() == 1.0
-
-
-class TestBatchSelectIndices:
- """Test batch_select_indices for beam selection."""
-
- def test_select_attention(self, tiny_attention_config, sample_kv):
- """Select subset of attention cache."""
- cache = Apriel2Cache(tiny_attention_config)
- key = torch.arange(4).float().view(4, 1, 1, 1).expand(4, 4, 10, 16)
- cache.update(key, key.clone(), layer_idx=0)
-
- indices = torch.tensor([0, 3])
- cache.batch_select_indices(indices)
-
- assert cache.max_batch_size == 2
- assert cache.layers[0].key[0, 0, 0, 0].item() == 0.0
- assert cache.layers[0].key[1, 0, 0, 0].item() == 3.0
-
- def test_select_kda_tuple(self, kda_config):
- """Select subset of KDA tuple conv states."""
- cache = Apriel2Cache(kda_config)
- conv = tuple(torch.arange(4).float().view(4, 1, 1).expand(4, 64, 3).clone() for _ in range(3))
- cache.conv_states[0] = conv
-
- indices = torch.tensor([1, 2])
- cache.batch_select_indices(indices)
-
- for c in cache.conv_states[0]:
- assert c.shape[0] == 2
- assert c[0, 0, 0].item() == 1.0
-
-
-# =============================================================================
-# SECTION 7: HUGGINGFACE INTEGRATION
-# =============================================================================
-
-
-class TestGetMaskSizes:
- """Test get_mask_sizes() for attention mask computation."""
-
- def test_empty_cache(self, tiny_attention_config):
- """Mask sizes with empty cache."""
- cache = Apriel2Cache(tiny_attention_config)
- cache_position = torch.arange(10)
-
- kv_length, kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0)
-
- assert kv_length == 10
- assert kv_offset == 0
-
- def test_with_cached_tokens(self, tiny_attention_config, sample_kv):
- """Mask sizes with cached tokens."""
- cache = Apriel2Cache(tiny_attention_config)
- cache.update(*sample_kv, layer_idx=0) # 10 tokens
-
- cache_position = torch.arange(5)
- kv_length, kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0)
-
- assert kv_length == 15 # 10 + 5
- assert kv_offset == 10
-
- def test_single_token_decode(self, tiny_attention_config, sample_kv):
- """Mask sizes for single token decode."""
- cache = Apriel2Cache(tiny_attention_config)
- cache.update(*sample_kv, layer_idx=0)
-
- cache_position = torch.arange(1)
- kv_length, kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0)
-
- assert kv_length == 11
- assert kv_offset == 10
-
- def test_ssm_returns_query_only(self, ssm_config, sample_conv_single):
- """SSM layers return query_length (no KV cache)."""
- cache = Apriel2Cache(ssm_config)
- cache.conv_states[0] = sample_conv_single
-
- cache_position = torch.arange(5)
- kv_length, kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0)
-
- assert kv_length == 5
- assert kv_offset == 0
-
-
-class TestCacheIndexing:
- """Test cache[idx] indexing."""
-
- def test_attention_returns_kv(self, tiny_attention_config, sample_kv):
- """Indexing attention layer returns (key, value)."""
- cache = Apriel2Cache(tiny_attention_config)
- cache.update(*sample_kv, layer_idx=0)
-
- result = cache[0]
-
- assert isinstance(result, tuple)
- torch.testing.assert_close(result[0], sample_kv[0])
-
- def test_empty_returns_empty_tensors(self, tiny_attention_config):
- """Indexing empty layer returns empty tensors."""
- cache = Apriel2Cache(tiny_attention_config)
-
- result = cache[0]
-
- assert result[0].numel() == 0
- assert result[1].numel() == 0
-
- def test_ssm_returns_empty(self, ssm_config, sample_conv_single):
- """Indexing SSM layer returns empty (no KV)."""
- cache = Apriel2Cache(ssm_config)
- cache.conv_states[0] = sample_conv_single
-
- result = cache[0]
-
- assert result[0].numel() == 0
-
- def test_stochastic_attention_returns_kv(self, stochastic_config, sample_kv):
- """Indexing stochastic with attention active returns KV."""
- cache = Apriel2Cache(stochastic_config)
- cache.set_active_mixer(1, "attention")
- cache.update(*sample_kv, layer_idx=1)
-
- result = cache[1]
-
- torch.testing.assert_close(result[0], sample_kv[0])
-
-
-# =============================================================================
-# SECTION 8: GENERATION PATTERNS
-# =============================================================================
-
-
-class TestGenerationPatterns:
- """Test real-world generation patterns."""
-
- def test_prefill_then_decode(self, tiny_attention_config, sample_kv):
- """Prefill with long prompt, then decode token-by-token."""
- cache = Apriel2Cache(tiny_attention_config)
- cache.update(*sample_kv, layer_idx=0) # Prefill 10 tokens
-
- for _ in range(5):
- new_kv = (torch.randn(2, 4, 1, 16), torch.randn(2, 4, 1, 16))
- cache.update(*new_kv, layer_idx=0)
-
- assert cache.get_seq_length(0) == 15
-
- def test_crop_then_continue(self, tiny_attention_config, sample_kv):
- """Crop old context, continue generation."""
- cache = Apriel2Cache(tiny_attention_config)
- cache.update(*sample_kv, layer_idx=0)
- cache.update(*sample_kv, layer_idx=0) # 20 tokens
-
- cache.crop(5) # Keep last 5
- cache.update(torch.randn(2, 4, 3, 16), torch.randn(2, 4, 3, 16), layer_idx=0)
-
- assert cache.get_seq_length(0) == 8
-
- def test_reset_between_generations(self, tiny_attention_config, sample_kv):
- """Reset between independent generations."""
- cache = Apriel2Cache(tiny_attention_config)
-
- # First generation
- cache.update(*sample_kv, layer_idx=0)
- assert cache.is_initialized == True
-
- # Reset
- cache.reset()
- assert cache.is_initialized == False
-
- # Second generation
- cache.update(*sample_kv, layer_idx=0)
- assert cache.get_seq_length(0) == 10
-
- def test_multi_layer_consistency(self, tiny_attention_config, sample_kv):
- """All layers updated consistently."""
- cache = Apriel2Cache(tiny_attention_config)
-
- for layer_idx in range(2):
- cache.update(*sample_kv, layer_idx=layer_idx)
- cache.update(torch.randn(2, 4, 1, 16), torch.randn(2, 4, 1, 16), layer_idx=layer_idx)
-
- for layer_idx in range(2):
- assert cache.get_seq_length(layer_idx) == 11
-
-
-# =============================================================================
-# SECTION 9: ERROR HANDLING
-# =============================================================================
-
-
-class TestErrorHandling:
- """Test error conditions and guards."""
-
- def test_stochastic_update_without_active_mixer(self, stochastic_config):
- """update() on stochastic without active_mixer raises."""
- cache = Apriel2Cache(stochastic_config)
-
- with pytest.raises(RuntimeError, match="needs active_mixer set"):
- cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=1)
-
- def test_stochastic_accessor_without_active_mixer(self, stochastic_config):
- """Accessing stochastic cache without active_mixer raises."""
- cache = Apriel2Cache(stochastic_config)
-
- with pytest.raises(RuntimeError, match="requires set_active_mixer"):
- _ = cache.conv_states[1]
-
- def test_accessor_error_lists_available_mixers(self, stochastic_config):
- """Error message lists available mixers."""
- cache = Apriel2Cache(stochastic_config)
-
- with pytest.raises(RuntimeError, match="Available mixers:"):
- _ = cache.key_cache[1]
-
- def test_invalid_mixer_name(self, stochastic_config):
- """Invalid mixer name raises KeyError on access."""
- cache = Apriel2Cache(stochastic_config)
- cache.set_active_mixer(1, "nonexistent")
-
- with pytest.raises(KeyError):
- cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=1)
-
- def test_layer_idx_out_of_bounds(self, tiny_attention_config):
- """Out-of-bounds layer_idx raises IndexError."""
- cache = Apriel2Cache(tiny_attention_config)
-
- with pytest.raises(IndexError):
- cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=999)
-
-
-# =============================================================================
-# SECTION 10: INTERNAL CLASSES
-# =============================================================================
-
-
-class TestAttentionCacheInternal:
- """Test internal _AttentionCache class directly."""
-
- def test_unbounded_growth(self):
- """No window allows unbounded growth."""
- cache = _AttentionCache(window=None)
-
- for _ in range(10):
- cache.update(torch.randn(2, 4, 100, 16), torch.randn(2, 4, 100, 16))
-
- assert cache.key.shape[-2] == 1000
-
- def test_window_enforced(self):
- """Window caps cache size."""
- cache = _AttentionCache(window=50)
-
- for _ in range(10):
- cache.update(torch.randn(2, 4, 100, 16), torch.randn(2, 4, 100, 16))
-
- assert cache.key.shape[-2] == 50
-
-
-class TestSSMCacheInternal:
- """Test internal _SSMCache class directly."""
-
- def test_initial_none(self):
- """Initial states are None."""
- cache = _SSMCache()
-
- assert cache.conv is None
- assert cache.recurrent is None
-
- def test_stores_tuple(self):
- """Can store tuple (for KDA)."""
- cache = _SSMCache()
- cache.conv = (torch.randn(2, 64, 3),) * 3
-
- assert isinstance(cache.conv, tuple)
diff --git a/fast_llm_external_models/tests/test_apriel2/test_cache_apriel2_specific.py b/fast_llm_external_models/tests/test_apriel2/test_cache_apriel2_specific.py
new file mode 100644
index 000000000..b45779454
--- /dev/null
+++ b/fast_llm_external_models/tests/test_apriel2/test_cache_apriel2_specific.py
@@ -0,0 +1,341 @@
+"""Tests for Apriel2-specific cache behaviors with no HuggingFace equivalent.
+
+This module tests features unique to Apriel2Cache that cannot be validated
+against upstream HF implementations:
+
+1. Stochastic mixer routing (switching between attention/SSM per layer)
+2. Multi-mixer layer support
+3. Error handling and guard rails
+4. Beam search operations (batch_repeat, reorder, select)
+5. Crop operation
+
+Fixtures used from conftest.py:
+ - stochastic_config: Stochastic mixer config with attention and mamba
+ - attention_config: Pure attention config
+ - ssm_config: Pure SSM config
+"""
+
+import pytest
+import torch
+
+from fast_llm_external_models.apriel2.cache import Apriel2Cache
+
+# =============================================================================
+# STOCHASTIC MIXER ROUTING
+# =============================================================================
+
+
+class TestStochasticMixerRouting:
+ """Test routing operations to correct sub-cache in stochastic layers."""
+
+ def test_set_active_mixer(self, stochastic_config):
+ """set_active_mixer updates routing for layer."""
+ cache = Apriel2Cache(stochastic_config)
+
+ cache.set_active_mixer(0, "attention")
+ assert cache.active_mixers[0] == "attention"
+
+ cache.set_active_mixer(0, "mamba")
+ assert cache.active_mixers[0] == "mamba"
+
+ def test_update_routes_to_active_mixer(self, stochastic_config):
+ """update() stores in correct sub-cache based on active_mixer."""
+ cache = Apriel2Cache(stochastic_config)
+
+ # Route to attention
+ cache.set_active_mixer(0, "attention")
+ cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0)
+
+ # Attention sub-cache should have data
+ assert cache.layers[0]["attention"].key is not None
+ # Mamba sub-cache should be empty
+ assert cache.layers[0]["mamba"].conv is None
+
+ def test_each_mixer_has_independent_cache(self, stochastic_config):
+ """Each mixer in a stochastic layer has its own independent state."""
+ cache = Apriel2Cache(stochastic_config)
+
+ # Store in attention
+ cache.set_active_mixer(0, "attention")
+ cache.update(torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16), layer_idx=0)
+
+ # Switch to mamba and store
+ cache.set_active_mixer(0, "mamba")
+ cache.layers[0]["mamba"].conv = torch.randn(2, 64, 4)
+
+ # Attention data should be unchanged
+ assert cache.layers[0]["attention"].cumulative_length == 5
+
+ def test_switching_preserves_all_states(self, stochastic_config):
+ """Switching active_mixer doesn't clear other mixer's state."""
+ cache = Apriel2Cache(stochastic_config)
+
+ # Build up attention state
+ cache.set_active_mixer(0, "attention")
+ cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0)
+ attn_key = cache.layers[0]["attention"].key.clone()
+
+ # Switch to mamba
+ cache.set_active_mixer(0, "mamba")
+
+ # Attention state preserved
+ torch.testing.assert_close(cache.layers[0]["attention"].key, attn_key)
+
+
+# =============================================================================
+# ERROR HANDLING
+# =============================================================================
+
+
+class TestErrorHandling:
+ """Test guard rails and error messages."""
+
+ def test_update_without_active_mixer_raises(self, stochastic_config):
+ """update() on stochastic layer without active_mixer raises RuntimeError."""
+ cache = Apriel2Cache(stochastic_config)
+
+ with pytest.raises(RuntimeError, match="needs active_mixer set"):
+ cache.update(torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16), layer_idx=0)
+
+ def test_accessor_without_active_mixer_raises(self, stochastic_config):
+ """Accessing key_cache/value_cache without active_mixer raises RuntimeError."""
+ cache = Apriel2Cache(stochastic_config)
+
+ with pytest.raises(RuntimeError, match="requires set_active_mixer"):
+ _ = cache.key_cache[0]
+
+ def test_error_message_lists_available_mixers(self, stochastic_config):
+ """Error message includes list of available mixers."""
+ cache = Apriel2Cache(stochastic_config)
+
+ with pytest.raises(RuntimeError, match="attention.*mamba|mamba.*attention"):
+ _ = cache.key_cache[0]
+
+
+# =============================================================================
+# BEAM SEARCH OPERATIONS
+# =============================================================================
+
+
+class TestBeamSearchOperations:
+ """Test batch manipulation for beam search."""
+
+ def test_batch_repeat_interleave_attention(self, attention_config):
+ """batch_repeat_interleave expands batch dimension."""
+ cache = Apriel2Cache(attention_config)
+ cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0)
+
+ cache.batch_repeat_interleave(3)
+
+ assert cache.layers[0].key.shape[0] == 6 # 2 * 3
+
+ def test_batch_repeat_interleave_ssm(self, ssm_config):
+ """batch_repeat_interleave works for SSM caches."""
+ cache = Apriel2Cache(ssm_config)
+ cache.layers[0].conv = torch.randn(2, 64, 4)
+
+ cache.batch_repeat_interleave(3)
+
+ assert cache.layers[0].conv.shape[0] == 6
+
+ def test_batch_repeat_interleave_kda_tuple(self, ssm_config):
+ """batch_repeat_interleave handles KDA tuple conv states."""
+ cache = Apriel2Cache(ssm_config)
+ cache.layers[0].conv = (torch.randn(2, 64, 4),) * 3
+
+ cache.batch_repeat_interleave(3)
+
+ assert cache.layers[0].conv[0].shape[0] == 6
+
+ def test_reorder_cache_attention(self, attention_config):
+ """reorder_cache reorders batch dimension."""
+ cache = Apriel2Cache(attention_config)
+ k = torch.arange(4).float().view(4, 1, 1, 1).expand(4, 4, 10, 16)
+ cache.update(k, k.clone(), layer_idx=0)
+
+ beam_idx = torch.tensor([3, 2, 1, 0])
+ cache.reorder_cache(beam_idx)
+
+ # Check reordering
+ assert cache.layers[0].key[0, 0, 0, 0].item() == 3.0
+ assert cache.layers[0].key[3, 0, 0, 0].item() == 0.0
+
+ def test_batch_select_indices(self, attention_config):
+ """batch_select_indices selects subset of batch."""
+ cache = Apriel2Cache(attention_config)
+ cache.update(torch.randn(4, 4, 10, 16), torch.randn(4, 4, 10, 16), layer_idx=0)
+
+ indices = torch.tensor([0, 2])
+ cache.batch_select_indices(indices)
+
+ assert cache.layers[0].key.shape[0] == 2
+
+ def test_reorder_cache_ssm_tuple(self, ssm_config):
+ """reorder_cache handles KDA tuple conv states."""
+ cache = Apriel2Cache(ssm_config)
+ # Create distinguishable tensors for each batch position
+ conv0 = torch.full((1, 64, 4), 0.0)
+ conv1 = torch.full((1, 64, 4), 1.0)
+ conv2 = torch.full((1, 64, 4), 2.0)
+ cache.layers[0].conv = (
+ torch.cat([conv0, conv1, conv2], dim=0),
+ torch.cat([conv0, conv1, conv2], dim=0),
+ torch.cat([conv0, conv1, conv2], dim=0),
+ )
+
+ beam_idx = torch.tensor([2, 1, 0])
+ cache.reorder_cache(beam_idx)
+
+ # Check reordering: batch[0] should now have value 2.0
+ assert cache.layers[0].conv[0][0, 0, 0].item() == 2.0
+ assert cache.layers[0].conv[0][2, 0, 0].item() == 0.0
+
+ def test_batch_select_indices_ssm_tuple(self, ssm_config):
+ """batch_select_indices handles KDA tuple conv states."""
+ cache = Apriel2Cache(ssm_config)
+ cache.layers[0].conv = (torch.randn(4, 64, 4),) * 3
+
+ indices = torch.tensor([0, 2])
+ cache.batch_select_indices(indices)
+
+ assert cache.layers[0].conv[0].shape[0] == 2
+ assert cache.layers[0].conv[1].shape[0] == 2
+ assert cache.layers[0].conv[2].shape[0] == 2
+
+
+# =============================================================================
+# CROP OPERATION
+# =============================================================================
+
+
+class TestCropOperation:
+ """Test cache truncation."""
+
+ def test_crop_truncates_attention(self, attention_config):
+ """crop() truncates attention cache."""
+ cache = Apriel2Cache(attention_config)
+ cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0)
+
+ cache.crop(5)
+
+ assert cache.layers[0].key.shape[-2] == 5
+ assert cache.get_seq_length(0) == 5
+
+ def test_crop_affects_all_layers(self, attention_config):
+ """crop() affects all layers."""
+ cache = Apriel2Cache(attention_config)
+ cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0)
+ cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=1)
+
+ cache.crop(3)
+
+ assert cache.layers[0].key.shape[-2] == 3
+ assert cache.layers[1].key.shape[-2] == 3
+
+ def test_crop_ignores_ssm(self, ssm_config):
+ """crop() doesn't affect SSM caches (they don't have seq dimension)."""
+ cache = Apriel2Cache(ssm_config)
+ cache.layers[0].conv = torch.randn(2, 64, 4)
+
+ # Should not raise
+ cache.crop(5)
+
+ # SSM state unchanged
+ assert cache.layers[0].conv.shape == (2, 64, 4)
+
+
+# =============================================================================
+# CACHE PROPERTIES
+# =============================================================================
+
+
+class TestCacheProperties:
+ """Test cache property methods."""
+
+ def test_is_initialized_attention(self, attention_config):
+ """is_initialized True after update."""
+ cache = Apriel2Cache(attention_config)
+ assert not cache.is_initialized
+
+ cache.update(torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16), layer_idx=0)
+ assert cache.is_initialized
+
+ def test_is_initialized_ssm(self, ssm_config):
+ """is_initialized True after setting conv state."""
+ cache = Apriel2Cache(ssm_config)
+ assert not cache.is_initialized
+
+ cache.layers[0].conv = torch.randn(2, 64, 4)
+ assert cache.is_initialized
+
+ def test_has_previous_state_ssm_only(self, ssm_config):
+ """has_previous_state checks SSM conv states."""
+ cache = Apriel2Cache(ssm_config)
+ assert not cache.has_previous_state
+
+ cache.layers[0].conv = torch.randn(2, 64, 4)
+ assert cache.has_previous_state
+
+ def test_has_previous_state_ignores_attention(self, attention_config):
+ """has_previous_state ignores attention caches."""
+ cache = Apriel2Cache(attention_config)
+ cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0)
+
+ # Attention-only cache returns False for has_previous_state
+ assert not cache.has_previous_state
+
+ def test_reset_clears_ssm_states(self, ssm_config):
+ """reset() clears SSM conv and recurrent states."""
+ cache = Apriel2Cache(ssm_config)
+ cache.layers[0].conv = torch.randn(2, 64, 4)
+ cache.layers[0].recurrent = torch.randn(2, 64, 16)
+
+ cache.reset()
+
+ assert cache.layers[0].conv is None
+ assert cache.layers[0].recurrent is None
+
+ def test_max_batch_size_from_ssm_tuple(self, ssm_config):
+ """max_batch_size works with KDA tuple conv states."""
+ cache = Apriel2Cache(ssm_config)
+ cache.layers[0].conv = (torch.randn(3, 64, 4),) * 3
+
+ assert cache.max_batch_size == 3
+
+ def test_max_batch_size(self, attention_config):
+ """max_batch_size returns batch dimension."""
+ cache = Apriel2Cache(attention_config)
+ cache.update(torch.randn(3, 4, 10, 16), torch.randn(3, 4, 10, 16), layer_idx=0)
+
+ assert cache.max_batch_size == 3
+
+ def test_len_returns_num_layers(self, attention_config):
+ """__len__ returns number of layers."""
+ cache = Apriel2Cache(attention_config)
+ assert len(cache) == 2
+
+
+# =============================================================================
+# INDEXING
+# =============================================================================
+
+
+class TestCacheIndexing:
+ """Test __getitem__ for HF compatibility."""
+
+ def test_getitem_returns_kv_tuple(self, attention_config):
+ """cache[idx] returns (key, value) tuple."""
+ cache = Apriel2Cache(attention_config)
+ cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0)
+
+ k, v = cache[0]
+ assert k.shape == (2, 4, 10, 16)
+ assert v.shape == (2, 4, 10, 16)
+
+ def test_getitem_empty_returns_empty_tensors(self, attention_config):
+ """cache[idx] on empty cache returns empty tensors."""
+ cache = Apriel2Cache(attention_config)
+
+ k, v = cache[0]
+ assert k.numel() == 0
+ assert v.numel() == 0
diff --git a/fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py b/fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py
new file mode 100644
index 000000000..8ceabfb91
--- /dev/null
+++ b/fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py
@@ -0,0 +1,591 @@
+"""Contract tests for Apriel2Cache against HuggingFace cache implementations.
+
+This module tests that Apriel2Cache components behave equivalently to their
+HuggingFace counterparts. This ensures compatibility with HF's generation
+infrastructure (mask creation, beam search, etc.).
+
+Mapping:
+ Apriel2 Component HuggingFace Equivalent
+ ----------------- ----------------------
+ _AttentionCache (no window) -> DynamicLayer
+ _AttentionCache (window) -> DynamicSlidingWindowLayer
+ _SSMCache -> MambaCache (different interface, same concept)
+
+Apriel2-specific features (stochastic routing, multi-mixer layers) are tested
+separately in test_cache_apriel2_specific.py since they have no HF equivalent.
+
+Fixtures used from conftest.py:
+ - batch_size, num_heads, head_dim: Tensor dimensions
+ - hf_dynamic_layer: HuggingFace DynamicLayer
+ - hf_sliding_layer: HuggingFace DynamicSlidingWindowLayer (parameterized by window_size)
+ - apriel_attention_cache: Apriel2 _AttentionCache (no window)
+ - apriel_sliding_cache: Apriel2 _AttentionCache (with window, parameterized)
+ - window_size: Parameterized window sizes [4, 8, 16, 32]
+ - attention_config, swa_config: Apriel2 configs
+"""
+
+import pytest
+import torch
+
+from fast_llm_external_models.apriel2.cache import Apriel2Cache, _AttentionCache
+
+# =============================================================================
+# SECTION 1: FULL ATTENTION - _AttentionCache vs DynamicLayer
+# =============================================================================
+
+
+class TestFullAttentionContract:
+ """Test _AttentionCache (no window) matches HuggingFace DynamicLayer.
+
+ DynamicLayer is the standard cache for full causal attention.
+ We test that our cache produces identical mask parameters.
+ """
+
+ # -------------------------------------------------------------------------
+ # get_seq_length: Must match exactly for generation to work
+ # -------------------------------------------------------------------------
+
+ @pytest.mark.parametrize("seq_len", [1, 5, 10, 50, 100])
+ def test_get_seq_length_after_prefill(
+ self, hf_dynamic_layer, apriel_attention_cache, batch_size, num_heads, head_dim, seq_len
+ ):
+ """After prefill, cumulative_length matches HF get_seq_length."""
+ key = torch.randn(batch_size, num_heads, seq_len, head_dim)
+ value = torch.randn(batch_size, num_heads, seq_len, head_dim)
+
+ hf_dynamic_layer.update(key.clone(), value.clone())
+ apriel_attention_cache.update(key.clone(), value.clone())
+
+ assert apriel_attention_cache.cumulative_length == hf_dynamic_layer.get_seq_length()
+
+ @pytest.mark.parametrize("prefill_len", [1, 5, 10])
+ @pytest.mark.parametrize("decode_steps", [1, 5, 10, 20])
+ def test_get_seq_length_during_decode(
+ self, hf_dynamic_layer, apriel_attention_cache, batch_size, num_heads, head_dim, prefill_len, decode_steps
+ ):
+ """During decode, cumulative_length tracks total tokens seen."""
+ # Prefill
+ key = torch.randn(batch_size, num_heads, prefill_len, head_dim)
+ value = torch.randn(batch_size, num_heads, prefill_len, head_dim)
+ hf_dynamic_layer.update(key.clone(), value.clone())
+ apriel_attention_cache.update(key.clone(), value.clone())
+
+ # Decode
+ for step in range(decode_steps):
+ key = torch.randn(batch_size, num_heads, 1, head_dim)
+ value = torch.randn(batch_size, num_heads, 1, head_dim)
+ hf_dynamic_layer.update(key.clone(), value.clone())
+ apriel_attention_cache.update(key.clone(), value.clone())
+
+ assert (
+ apriel_attention_cache.cumulative_length == hf_dynamic_layer.get_seq_length()
+ ), f"Mismatch at decode step {step}"
+
+ # -------------------------------------------------------------------------
+ # get_mask_sizes: Verify HF behavior for documentation
+ # -------------------------------------------------------------------------
+
+ @pytest.mark.parametrize("prefill_len", [1, 5, 10])
+ @pytest.mark.parametrize("decode_steps", [0, 1, 5, 10])
+ def test_hf_mask_sizes_kv_length(
+ self, hf_dynamic_layer, apriel_attention_cache, batch_size, num_heads, head_dim, prefill_len, decode_steps
+ ):
+ """Document HF's kv_length behavior and verify cumulative_length tracks correctly.
+
+ For full attention, kv_length = cumulative_length + query_length.
+ This test verifies our cache tracks tokens identically to HF.
+ """
+ # Prefill
+ key = torch.randn(batch_size, num_heads, prefill_len, head_dim)
+ value = torch.randn(batch_size, num_heads, prefill_len, head_dim)
+ hf_dynamic_layer.update(key.clone(), value.clone())
+ apriel_attention_cache.update(key.clone(), value.clone())
+
+ # Decode
+ for _ in range(decode_steps):
+ key = torch.randn(batch_size, num_heads, 1, head_dim)
+ value = torch.randn(batch_size, num_heads, 1, head_dim)
+ hf_dynamic_layer.update(key.clone(), value.clone())
+ apriel_attention_cache.update(key.clone(), value.clone())
+
+ # Verify cumulative_length matches HF
+ assert apriel_attention_cache.cumulative_length == hf_dynamic_layer.get_seq_length()
+
+ # Verify HF's kv_length follows the expected formula
+ cache_position = torch.arange(1) # Single token decode
+ hf_kv_len, hf_kv_offset = hf_dynamic_layer.get_mask_sizes(cache_position)
+ expected_kv_len = hf_dynamic_layer.get_seq_length() + cache_position.shape[0]
+ assert hf_kv_len == expected_kv_len
+
+ def test_hf_kv_offset_always_zero(self, hf_dynamic_layer, batch_size, num_heads, head_dim):
+ """Document that HF DynamicLayer always returns kv_offset=0.
+
+ For full attention, all cached KV pairs map to absolute positions
+ starting from 0, so kv_offset is always 0.
+ """
+ # Add many tokens
+ for _ in range(20):
+ key = torch.randn(batch_size, num_heads, 5, head_dim)
+ value = torch.randn(batch_size, num_heads, 5, head_dim)
+ hf_dynamic_layer.update(key.clone(), value.clone())
+
+ cache_position = torch.arange(1)
+ _, hf_kv_offset = hf_dynamic_layer.get_mask_sizes(cache_position)
+
+ assert hf_kv_offset == 0, "DynamicLayer always returns kv_offset=0"
+
+ # -------------------------------------------------------------------------
+ # update: Output shape and values must match
+ # -------------------------------------------------------------------------
+
+ @pytest.mark.parametrize("seq_len", [1, 5, 10])
+ def test_update_returns_same_shape(
+ self, hf_dynamic_layer, apriel_attention_cache, batch_size, num_heads, head_dim, seq_len
+ ):
+ """update() returns tensors with matching shapes."""
+ key = torch.randn(batch_size, num_heads, seq_len, head_dim)
+ value = torch.randn(batch_size, num_heads, seq_len, head_dim)
+
+ hf_k, hf_v = hf_dynamic_layer.update(key.clone(), value.clone())
+ apr_k, apr_v = apriel_attention_cache.update(key.clone(), value.clone())
+
+ assert hf_k.shape == apr_k.shape
+ assert hf_v.shape == apr_v.shape
+
+ def test_update_concatenates_identically(
+ self, hf_dynamic_layer, apriel_attention_cache, batch_size, num_heads, head_dim
+ ):
+ """Multiple updates produce identical concatenated states."""
+ # Use deterministic values for comparison
+ k1 = torch.arange(10).float().view(1, 1, 10, 1).expand(batch_size, num_heads, 10, head_dim)
+ v1 = k1.clone()
+
+ hf_dynamic_layer.update(k1.clone(), v1.clone())
+ apriel_attention_cache.update(k1.clone(), v1.clone())
+
+ k2 = torch.arange(10, 15).float().view(1, 1, 5, 1).expand(batch_size, num_heads, 5, head_dim)
+ v2 = k2.clone()
+
+ hf_k, hf_v = hf_dynamic_layer.update(k2.clone(), v2.clone())
+ apr_k, apr_v = apriel_attention_cache.update(k2.clone(), v2.clone())
+
+ torch.testing.assert_close(hf_k, apr_k)
+ torch.testing.assert_close(hf_v, apr_v)
+
+
+# =============================================================================
+# SECTION 2: SLIDING WINDOW - _AttentionCache vs DynamicSlidingWindowLayer
+# =============================================================================
+
+
+class TestSlidingWindowContract:
+ """Test _AttentionCache (with window) matches HuggingFace DynamicSlidingWindowLayer.
+
+ DynamicSlidingWindowLayer is used for sliding window attention (e.g., Mistral).
+ Critical behaviors:
+ - cumulative_length tracks ALL tokens seen (not just cached)
+ - kv_offset increases once window is exceeded
+ - kv_length is capped at window size
+
+ Uses fixtures from conftest.py:
+ - window_size: parameterized [4, 8, 16, 32]
+ - hf_sliding_layer: DynamicSlidingWindowLayer
+ - apriel_sliding_cache: _AttentionCache with window
+ """
+
+ # -------------------------------------------------------------------------
+ # cumulative_length: Must track total tokens, not cached tokens
+ # -------------------------------------------------------------------------
+
+ @pytest.mark.parametrize("prefill_len", [1, 3, 5, 10, 20])
+ def test_cumulative_length_matches_after_prefill(
+ self, hf_sliding_layer, apriel_sliding_cache, batch_size, num_heads, head_dim, prefill_len
+ ):
+ """cumulative_length matches HF get_seq_length after prefill."""
+ key = torch.randn(batch_size, num_heads, prefill_len, head_dim)
+ value = torch.randn(batch_size, num_heads, prefill_len, head_dim)
+
+ hf_sliding_layer.update(key.clone(), value.clone())
+ apriel_sliding_cache.update(key.clone(), value.clone())
+
+ assert apriel_sliding_cache.cumulative_length == hf_sliding_layer.get_seq_length()
+
+ def test_cumulative_length_continues_past_window(
+ self, hf_sliding_layer, apriel_sliding_cache, window_size, batch_size, num_heads, head_dim
+ ):
+ """cumulative_length keeps growing even after window is full."""
+ total_tokens = window_size * 3 # Way past window
+
+ for i in range(total_tokens):
+ key = torch.randn(batch_size, num_heads, 1, head_dim)
+ value = torch.randn(batch_size, num_heads, 1, head_dim)
+
+ hf_sliding_layer.update(key.clone(), value.clone())
+ apriel_sliding_cache.update(key.clone(), value.clone())
+
+ expected = i + 1
+ assert apriel_sliding_cache.cumulative_length == expected
+ assert hf_sliding_layer.get_seq_length() == expected
+
+ # -------------------------------------------------------------------------
+ # get_mask_sizes: kv_offset must increase once window is exceeded
+ # -------------------------------------------------------------------------
+
+ def test_kv_offset_zero_before_window_full(
+ self, hf_sliding_layer, apriel_sliding_cache, window_size, batch_size, num_heads, head_dim
+ ):
+ """kv_offset is 0 while cumulative < window.
+
+ Before the window is full, kv_offset should be 0 because all cached tokens
+ correspond to absolute positions starting from 0.
+ """
+ # Add tokens up to window-1
+ for i in range(window_size - 1):
+ key = torch.randn(batch_size, num_heads, 1, head_dim)
+ value = torch.randn(batch_size, num_heads, 1, head_dim)
+
+ hf_sliding_layer.update(key.clone(), value.clone())
+ apriel_sliding_cache.update(key.clone(), value.clone())
+
+ cache_position = torch.arange(1)
+ hf_kv_len, hf_kv_offset = hf_sliding_layer.get_mask_sizes(cache_position)
+
+ # Verify HF returns 0 offset before window full
+ assert hf_kv_offset == 0, f"HF offset should be 0 at step {i}"
+ # Verify Apriel cache tracks cumulative correctly
+ assert apriel_sliding_cache.cumulative_length == i + 1
+
+ def test_kv_offset_increases_after_window_full(
+ self, hf_sliding_layer, apriel_sliding_cache, window_size, batch_size, num_heads, head_dim
+ ):
+ """kv_offset increases once cumulative >= window.
+
+ Once the window is full, the cache discards oldest tokens. kv_offset tracks
+ which absolute position KV[0] corresponds to.
+ """
+ # Fill to exactly window
+ for _ in range(window_size):
+ key = torch.randn(batch_size, num_heads, 1, head_dim)
+ value = torch.randn(batch_size, num_heads, 1, head_dim)
+ hf_sliding_layer.update(key.clone(), value.clone())
+ apriel_sliding_cache.update(key.clone(), value.clone())
+
+ cache_position = torch.arange(1)
+ hf_kv_len, hf_kv_offset = hf_sliding_layer.get_mask_sizes(cache_position)
+
+ # At window boundary, offset should be 1
+ assert hf_kv_offset == 1, "HF offset should be 1 at window boundary"
+ assert apriel_sliding_cache.cumulative_length == window_size
+
+ # Add more tokens and verify offset keeps increasing with HF
+ for i in range(5):
+ key = torch.randn(batch_size, num_heads, 1, head_dim)
+ value = torch.randn(batch_size, num_heads, 1, head_dim)
+ hf_sliding_layer.update(key.clone(), value.clone())
+ apriel_sliding_cache.update(key.clone(), value.clone())
+
+ hf_kv_len, hf_kv_offset = hf_sliding_layer.get_mask_sizes(cache_position)
+
+ expected_offset = i + 2
+ assert hf_kv_offset == expected_offset
+ assert apriel_sliding_cache.cumulative_length == window_size + i + 1
+
+ def test_kv_length_capped_at_window(
+ self, hf_sliding_layer, apriel_sliding_cache, window_size, batch_size, num_heads, head_dim
+ ):
+ """kv_length is capped at window size once exceeded.
+
+ For a query of length 1 after the window is full, kv_length = window
+ (window-1 cached tokens + 1 query token).
+ """
+ # Way past window
+ for _ in range(window_size * 2):
+ key = torch.randn(batch_size, num_heads, 1, head_dim)
+ value = torch.randn(batch_size, num_heads, 1, head_dim)
+ hf_sliding_layer.update(key.clone(), value.clone())
+ apriel_sliding_cache.update(key.clone(), value.clone())
+
+ cache_position = torch.arange(1)
+ hf_kv_len, _ = hf_sliding_layer.get_mask_sizes(cache_position)
+
+ # HF returns window (window-1 cached + 1 query)
+ assert hf_kv_len == window_size
+ # Verify our cache tracked cumulative correctly
+ assert apriel_sliding_cache.cumulative_length == window_size * 2
+
+ # -------------------------------------------------------------------------
+ # Full sequence length tracking through generation
+ # -------------------------------------------------------------------------
+
+ @pytest.mark.parametrize("prefill_len", [1, 3, 5, 10, 20])
+ def test_cumulative_length_tracks_all_tokens(
+ self, hf_sliding_layer, apriel_sliding_cache, window_size, batch_size, num_heads, head_dim, prefill_len
+ ):
+ """cumulative_length tracks total tokens seen through prefill + decode.
+
+ This is the foundation for correct mask size computation. We verify that
+ our _AttentionCache tracks tokens identically to HuggingFace's DynamicSlidingWindowLayer.
+ The actual get_mask_sizes computation is tested in TestApriel2CacheIntegration.
+ """
+ # Prefill
+ key = torch.randn(batch_size, num_heads, prefill_len, head_dim)
+ value = torch.randn(batch_size, num_heads, prefill_len, head_dim)
+ hf_sliding_layer.update(key.clone(), value.clone())
+ apriel_sliding_cache.update(key.clone(), value.clone())
+
+ assert apriel_sliding_cache.cumulative_length == hf_sliding_layer.get_seq_length()
+
+ # Decode past window
+ for i in range(window_size + 10):
+ key = torch.randn(batch_size, num_heads, 1, head_dim)
+ value = torch.randn(batch_size, num_heads, 1, head_dim)
+ hf_sliding_layer.update(key.clone(), value.clone())
+ apriel_sliding_cache.update(key.clone(), value.clone())
+
+ assert (
+ apriel_sliding_cache.cumulative_length == hf_sliding_layer.get_seq_length()
+ ), f"cumulative_length mismatch at step {i}"
+
+
+# =============================================================================
+# SECTION 3: SSM CACHE - _SSMCache vs MambaCache concept
+# =============================================================================
+
+
+class TestSSMCacheContract:
+ """Document _SSMCache interface and verify basic contract.
+
+    Unlike the attention caches, which have HF equivalents (DynamicLayer, DynamicSlidingWindowLayer),
+    SSM caches have no direct HF counterpart with a matching interface. HF's MambaCache uses
+    different methods (update_conv_state, update_ssm_state), so we can't compare them directly.
+
+ These tests document the interface contract:
+ 1. `conv` and `recurrent` attributes for storing states
+ 2. Both support None (lazy initialization)
+ 3. `conv` can be tuple (for KDA which has separate q/k/v conv states)
+
+ Higher-level operations (reorder, batch_repeat, reset) are tested in
+ TestBeamSearchOperations in test_cache_apriel2_specific.py.
+ """
+
+ def test_conv_state_storage(self, ssm_cache):
+ """conv attribute stores conv states (batch, intermediate, kernel_size)."""
+ conv = torch.randn(2, 64, 4)
+ ssm_cache.conv = conv
+ torch.testing.assert_close(ssm_cache.conv, conv)
+
+ def test_recurrent_state_storage(self, ssm_cache):
+ """recurrent attribute stores SSM states (batch, intermediate, state_size)."""
+ recurrent = torch.randn(2, 64, 16)
+ ssm_cache.recurrent = recurrent
+ torch.testing.assert_close(ssm_cache.recurrent, recurrent)
+
+ def test_conv_state_tuple_for_kda(self, ssm_cache):
+ """conv can be tuple for KDA's separate q/k/v convolutions."""
+ conv_tuple = (torch.randn(2, 64, 4), torch.randn(2, 64, 4), torch.randn(2, 64, 4))
+ ssm_cache.conv = conv_tuple
+ assert isinstance(ssm_cache.conv, tuple)
+ assert len(ssm_cache.conv) == 3
+
+ def test_initial_states_none(self, ssm_cache):
+ """States are None initially (lazy initialization pattern)."""
+ assert ssm_cache.conv is None
+ assert ssm_cache.recurrent is None
+
+ def test_states_independent(self, ssm_cache):
+ """conv and recurrent states are independent."""
+ ssm_cache.conv = torch.randn(2, 64, 4)
+ assert ssm_cache.recurrent is None # recurrent unchanged
+
+ ssm_cache.recurrent = torch.randn(2, 64, 16)
+ assert ssm_cache.conv is not None # conv unchanged
+
+
+# =============================================================================
+# SECTION 4: APRIEL2CACHE INTEGRATION
+# =============================================================================
+
+
+class TestApriel2CacheIntegration:
+ """Test Apriel2Cache correctly delegates to underlying caches.
+
+ Uses fixtures from conftest.py:
+ - attention_config: Pure attention config
+ - swa_config: Sliding window attention config (window=8)
+ """
+
+ def test_get_seq_length_matches_dynamic_layer(self, attention_config):
+ """Apriel2Cache.get_seq_length matches DynamicLayer for full attention."""
+ from transformers.cache_utils import DynamicLayer
+
+ cache = Apriel2Cache(attention_config)
+ hf_layer = DynamicLayer()
+
+ key = torch.randn(2, 4, 10, 16)
+ value = torch.randn(2, 4, 10, 16)
+
+ cache.update(key.clone(), value.clone(), layer_idx=0)
+ hf_layer.update(key.clone(), value.clone())
+
+ assert cache.get_seq_length(0) == hf_layer.get_seq_length()
+
+ def test_get_mask_sizes_matches_dynamic_layer(self, attention_config):
+ """Apriel2Cache.get_mask_sizes matches DynamicLayer."""
+ from transformers.cache_utils import DynamicLayer
+
+ cache = Apriel2Cache(attention_config)
+ hf_layer = DynamicLayer()
+
+ key = torch.randn(2, 4, 10, 16)
+ value = torch.randn(2, 4, 10, 16)
+
+ cache.update(key.clone(), value.clone(), layer_idx=0)
+ hf_layer.update(key.clone(), value.clone())
+
+ cache_position = torch.arange(1)
+ hf_kv_len, hf_kv_offset = hf_layer.get_mask_sizes(cache_position)
+ apr_kv_len, apr_kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0)
+
+ assert apr_kv_len == hf_kv_len
+ assert apr_kv_offset == hf_kv_offset
+
+ def test_get_mask_sizes_matches_sliding_layer(self, swa_config):
+ """Apriel2Cache.get_mask_sizes matches DynamicSlidingWindowLayer."""
+ from transformers.cache_utils import DynamicSlidingWindowLayer
+
+ cache = Apriel2Cache(swa_config)
+ hf_layer = DynamicSlidingWindowLayer(sliding_window=8)
+
+ # Fill past window
+ for _ in range(15):
+ key = torch.randn(2, 4, 1, 16)
+ value = torch.randn(2, 4, 1, 16)
+ cache.update(key.clone(), value.clone(), layer_idx=0)
+ hf_layer.update(key.clone(), value.clone())
+
+ cache_position = torch.arange(1)
+ hf_kv_len, hf_kv_offset = hf_layer.get_mask_sizes(cache_position)
+ apr_kv_len, apr_kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0)
+
+ assert apr_kv_len == hf_kv_len
+ assert apr_kv_offset == hf_kv_offset
+
+ def test_reset_clears_cumulative_length(self, attention_config):
+ """reset() clears cumulative_length (matches DynamicLayer.reset)."""
+ cache = Apriel2Cache(attention_config)
+
+ cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0)
+ assert cache.get_seq_length(0) == 10
+
+ cache.reset()
+ assert cache.get_seq_length(0) == 0
+
+
+# =============================================================================
+# SECTION 5: MASK CORRECTNESS (SEMANTIC TESTS)
+# =============================================================================
+
+
+class TestMaskCorrectness:
+ """Test that mask parameters produce semantically correct masks.
+
+ These tests verify the END RESULT: masks created with our parameters
+ allow the correct attention patterns.
+ """
+
+ def test_full_attention_decode_can_attend_to_all(self):
+ """During decode, query can attend to all cached positions."""
+ from transformers.masking_utils import causal_mask_function, sdpa_mask
+
+ cache = _AttentionCache(window=None)
+
+ # Prefill + decode
+ for _ in range(10):
+ cache.update(torch.randn(1, 1, 1, 16), torch.randn(1, 1, 1, 16))
+
+ # Mask for decode step
+ cache_position = torch.tensor([10]) # Position of new token
+ kv_length = cache.cumulative_length + 1
+ kv_offset = 0
+
+ mask = sdpa_mask(
+ batch_size=1,
+ cache_position=cache_position,
+ kv_length=kv_length,
+ kv_offset=kv_offset,
+ mask_function=causal_mask_function,
+ )
+
+ if mask is not None:
+ # Query at position 10 should attend to positions 0-10
+ query_mask = mask[0, 0, 0, :]
+ for kv_idx in range(kv_length):
+                assert query_mask[kv_idx].item(), f"Should attend to position {kv_idx}"
+
+ @pytest.mark.parametrize("window_size", [4, 8, 16])
+ def test_sliding_window_decode_respects_window(self, window_size):
+ """During decode, query only attends within sliding window."""
+ from transformers.masking_utils import sdpa_mask, sliding_window_causal_mask_function
+
+ cache = _AttentionCache(window=window_size)
+
+ # Fill way past window
+ total_tokens = window_size * 2
+ for _ in range(total_tokens):
+ cache.update(torch.randn(1, 1, 1, 16), torch.randn(1, 1, 1, 16))
+
+ # Mask for decode step
+ cache_position = torch.tensor([total_tokens])
+ cumulative = cache.cumulative_length
+ kv_offset = max(cumulative - window_size + 1, 0)
+ kv_length = window_size - 1 + 1 # cached + query
+
+ mask = sdpa_mask(
+ batch_size=1,
+ cache_position=cache_position,
+ kv_length=kv_length,
+ kv_offset=kv_offset,
+ mask_function=sliding_window_causal_mask_function(window_size),
+ )
+
+ if mask is not None:
+ query_mask = mask[0, 0, 0, :]
+ query_pos = cache_position[0].item()
+
+ for kv_idx in range(kv_length):
+ abs_pos = kv_offset + kv_idx
+ in_window = abs_pos > query_pos - window_size
+ causal = abs_pos <= query_pos
+ expected = in_window and causal
+
+ assert (
+ query_mask[kv_idx].item() == expected
+ ), f"Position {abs_pos}: expected {expected}, got {query_mask[kv_idx].item()}"
+
+ def test_prefill_has_causal_pattern(self):
+ """During prefill, mask has proper causal (lower triangular) pattern."""
+ from transformers.masking_utils import causal_mask_function, sdpa_mask
+
+ cache = _AttentionCache(window=None)
+ cache.update(torch.randn(1, 1, 5, 16), torch.randn(1, 1, 5, 16))
+
+ cache_position = torch.arange(5)
+ kv_length = cache.cumulative_length
+ kv_offset = 0
+
+ mask = sdpa_mask(
+ batch_size=1,
+ cache_position=cache_position,
+ kv_length=kv_length,
+ kv_offset=kv_offset,
+ mask_function=causal_mask_function,
+ allow_is_causal_skip=False, # Force mask creation
+ )
+
+ if mask is not None:
+ # Check causal pattern
+ for q_idx in range(5):
+ for kv_idx in range(5):
+ expected = kv_idx <= q_idx
+ actual = mask[0, 0, q_idx, kv_idx].item()
+ assert actual == expected, f"q={q_idx}, kv={kv_idx}: expected {expected}"
diff --git a/fast_llm_external_models/tests/test_apriel2/test_causal_conv1d.py b/fast_llm_external_models/tests/test_apriel2/test_causal_conv1d.py
index ec6abc1d2..0567cd76e 100644
--- a/fast_llm_external_models/tests/test_apriel2/test_causal_conv1d.py
+++ b/fast_llm_external_models/tests/test_apriel2/test_causal_conv1d.py
@@ -24,7 +24,6 @@
from fast_llm_external_models.apriel2.modeling_apriel2 import CausalConv1d, _causal_conv1d_fn
-
# =============================================================================
# Fixtures
# =============================================================================
@@ -63,6 +62,7 @@ def kernel_size():
def to_device(conv: CausalConv1d, device: str) -> CausalConv1d:
"""Create a copy of conv on the specified device."""
import copy
+
return copy.deepcopy(conv).to(device)
@@ -71,7 +71,9 @@ def prefill(conv: CausalConv1d, x: torch.Tensor, state: torch.Tensor = None) ->
return conv(x, conv_state=state, return_final_state=True)
-def decode_sequence(conv: CausalConv1d, tokens: torch.Tensor, state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+def decode_sequence(
+ conv: CausalConv1d, tokens: torch.Tensor, state: torch.Tensor
+) -> tuple[torch.Tensor, torch.Tensor]:
"""Decode multiple tokens one-by-one, return (stacked_outputs, final_state).
Args:
@@ -223,7 +225,7 @@ def test_chunked_prefill_cpu(self, conv, dim, total_len, chunk_size):
outputs = []
state = None
for start in range(0, total_len, chunk_size):
- chunk = x[:, :, start:start + chunk_size]
+ chunk = x[:, :, start : start + chunk_size]
out, state = prefill(conv, chunk, state)
outputs.append(out)
@@ -248,7 +250,7 @@ def test_chunked_prefill_cuda(self, conv, dim, total_len, chunk_size):
outputs = []
state = None
for start in range(0, total_len, chunk_size):
- chunk = x[:, :, start:start + chunk_size].cuda()
+ chunk = x[:, :, start : start + chunk_size].cuda()
out, state = prefill(conv_cuda, chunk, state)
outputs.append(out)
@@ -329,7 +331,7 @@ def test_all_cpu_paths_match(self, conv, dim):
outputs = []
state = None
for start in range(0, total_len, chunk_size):
- chunk = x[:, :, start:start + chunk_size]
+ chunk = x[:, :, start : start + chunk_size]
out, state = prefill(conv, chunk, state)
outputs.append(out)
path1 = torch.cat(outputs, dim=-1)
@@ -374,7 +376,7 @@ def test_all_paths_match_cross_device(self, conv, dim):
# CPU chunked
outputs, state = [], None
for start in range(0, total_len, chunk_size):
- out, state = prefill(conv, x[:, :, start:start + chunk_size], state)
+ out, state = prefill(conv, x[:, :, start : start + chunk_size], state)
outputs.append(out)
results["cpu_chunked"] = torch.cat(outputs, dim=-1)
@@ -393,7 +395,7 @@ def test_all_paths_match_cross_device(self, conv, dim):
# CUDA chunked
outputs, state = [], None
for start in range(0, total_len, chunk_size):
- out, state = prefill(conv_cuda, x[:, :, start:start + chunk_size].cuda(), state)
+ out, state = prefill(conv_cuda, x[:, :, start : start + chunk_size].cuda(), state)
outputs.append(out.cpu())
results["cuda_chunked"] = torch.cat(outputs, dim=-1)
@@ -431,8 +433,7 @@ def test_all_paths_match_cross_device(self, conv, dim):
for name, result in results.items():
tol = tolerances[name]
torch.testing.assert_close(
- result, reference, atol=tol, rtol=tol,
- msg=f"Path '{name}' diverged from reference"
+ result, reference, atol=tol, rtol=tol, msg=f"Path '{name}' diverged from reference"
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
@@ -468,8 +469,8 @@ def test_long_decode_no_drift(self, conv, dim):
# Check no systematic drift (errors shouldn't consistently increase)
decode_errors = errors[prefill_len:]
- first_half = decode_errors[:len(decode_errors)//2].mean()
- second_half = decode_errors[len(decode_errors)//2:].mean()
+ first_half = decode_errors[: len(decode_errors) // 2].mean()
+ second_half = decode_errors[len(decode_errors) // 2 :].mean()
assert second_half < first_half * 2, "Errors growing over decode steps (drift detected)"
diff --git a/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py b/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py
index 0bd6ac88d..3413b9d25 100644
--- a/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py
+++ b/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py
@@ -20,7 +20,7 @@
import yaml
from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
-from fast_llm_external_models.apriel2.conversion.config import apply_surgery, compose_configs
+from fast_llm_external_models.apriel2.conversion.config import compose_configs
class TestComposeConfigsLaws:
@@ -75,14 +75,10 @@ def source_config(self):
},
}
- def test_identity_empty_surgery(self, source_config):
- """Law 1: compose_configs(config, {}) == config"""
- result = compose_configs(source_config, {})
- assert result == source_config
-
- def test_identity_none_surgery(self, source_config):
- """Law 1: compose_configs(config, None) == config"""
- result = compose_configs(source_config, None)
+ @pytest.mark.parametrize("empty_surgery", [{}, None])
+ def test_identity(self, source_config, empty_surgery):
+ """Law 1: compose_configs(config, empty) == config for empty in [{}, None]"""
+ result = compose_configs(source_config, empty_surgery)
assert result == source_config
def test_override_explicit_values(self, source_config):
@@ -114,7 +110,7 @@ def test_same_type_inheritance(self, source_config):
assert mixer["head_size"] == 32 # Inherited
assert mixer["rope_theta"] == 10000.0 # Inherited
assert mixer["window_size"] == 512 # Added
- assert "init" not in mixer # Stripped by apply_surgery
+ # init is preserved for plan_surgery to see (stripped only at final output)
def test_cross_type_attention_to_gdn(self, source_config):
"""Law 5: attentionโgdn derives GDN dims from attention geometry."""
@@ -239,8 +235,14 @@ def test_null_deletion(self, source_config):
assert "vision_encoder" not in result
- def test_init_stripped_from_result(self, source_config):
- """Verify `init` keys are stripped from final result."""
+ def test_init_preserved_for_plan_surgery(self, source_config):
+ """Verify `init` keys are preserved so plan_surgery can see them.
+
+ The `init` field controls weight initialization (transfer vs random).
+ It's preserved through composition and only stripped at final output.
+ """
+ from fast_llm_external_models.apriel2.conversion.config import strip_init_fields
+
surgery = {
"decoder": {
"block": {
@@ -252,20 +254,20 @@ def test_init_stripped_from_result(self, source_config):
"gdn": {"type": "gdn", "init": "random", "convolution_layer": {"kernel_size": 4}},
},
},
- "mlp": {"init": "transfer"},
- "normalization": {"init": "transfer"},
},
},
}
result = compose_configs(source_config, surgery)
- def check_no_init(d, path=""):
- assert "init" not in d, f"Found 'init' key at {path}"
- for k, v in d.items():
- if isinstance(v, dict):
- check_no_init(v, f"{path}.{k}")
+ # init is preserved in composed config
+ mixers = result["decoder"]["block"]["mixer"]["mixers"]
+ assert mixers["attention"].get("init") == "transfer"
+ assert mixers["gdn"].get("init") == "random"
- check_no_init(result)
+ # strip_init_fields removes them for final output
+ stripped = strip_init_fields(result)
+ assert "init" not in stripped["decoder"]["block"]["mixer"]["mixers"]["attention"]
+ assert "init" not in stripped["decoder"]["block"]["mixer"]["mixers"]["gdn"]
def test_init_random_still_inherits_config(self, source_config):
"""init: random is for weights only - config params still inherited."""
@@ -287,6 +289,212 @@ def test_init_random_still_inherits_config(self, source_config):
assert mixer["head_groups"] == 4
assert mixer["window_size"] == 512
+ # =========================================================================
+ # Monoid Laws: compose_configs forms a monoid action on configs
+ # =========================================================================
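+
+    # Deep-merge sketch (illustrative): composing two partial surgeries merges
+    # nested dicts, with the overlay winning on conflicting leaves, e.g. merging
+    #   {"decoder": {"block": {"mixer": {"type": "stochastic"}}}}
+    # with
+    #   {"decoder": {"block": {"mixer": {"mixers": {"gdn": {"type": "gdn"}}}}}}
+    # yields a single mixer spec carrying both "type" and "mixers".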
+
+ def test_surgery_monoid_associativity(self):
+ """MONOID: merge(merge(A, B), C) == merge(A, merge(B, C)) for partial configs."""
+ surgery_a = {"decoder": {"block": {"mixer": {"type": "stochastic", "main_mixer_name": "attention"}}}}
+ surgery_b = {"decoder": {"block": {"mixer": {"mixers": {"sliding_window": {"window_size": 512}}}}}}
+ surgery_c = {"decoder": {"block": {"mixer": {"mixers": {"gdn": {"type": "gdn"}}}}}}
+
+ # Left-associated: (A โ B) โ C
+ ab_c = compose_configs(compose_configs(surgery_a, surgery_b), surgery_c)
+ # Right-associated: A โ (B โ C)
+ a_bc = compose_configs(surgery_a, compose_configs(surgery_b, surgery_c))
+
+ assert ab_c == a_bc, "Surgery monoid should be associative"
+
+ @pytest.mark.parametrize("num_surgeries", [2, 3])
+ def test_monoid_action_compatibility(self, source_config, num_surgeries):
+ """MONOID ACTION: apply(apply(c, A), B) == apply(c, merge(A, B))
+
+ This is the key law: applying surgeries sequentially equals merging first.
+ Parameterized to test with 2 and 3 surgeries.
+ """
+ surgeries = [
+ {
+ "decoder": {
+ "block": {
+ "mixer": {"type": "stochastic", "main_mixer_name": "attention", "mixers": {"attention": {}}}
+ }
+ }
+ },
+ {"decoder": {"block": {"mixer": {"mixers": {"sliding_window": {"window_size": 512}}}}}},
+ {"decoder": {"block": {"mixer": {"mixers": {"gdn": {"type": "gdn"}}}}}},
+ ][:num_surgeries]
+
+        # Sequential: ((c ⊳ A) ⊳ B) ⊳ ...
+ result_sequential = source_config
+ for s in surgeries:
+ result_sequential = compose_configs(result_sequential, s)
+
+        # Merged: c ⊳ (A ∘ B ∘ ...)
+ merged = surgeries[0]
+ for s in surgeries[1:]:
+ merged = compose_configs(merged, s)
+ result_merged = compose_configs(source_config, merged)
+
+ assert result_sequential == result_merged, f"Monoid action compatibility failed for {num_surgeries} surgeries"
+
+
+class TestBiasConfigInheritance:
+ """Test per-layer bias inheritance through surgery composition.
+
+ These tests verify that the per-layer bias configuration (mirroring Fast-LLM's
+ AffineLinearConfig) is correctly inherited through surgery operations:
+ - query_layer.bias.enabled, key_layer.bias.enabled, etc. for attention
+ - layer_1.bias.enabled, layer_2.bias.enabled for MLP
+ """
+
+ @pytest.fixture
+ def source_config_with_bias(self):
+ """Source config with Qwen-style bias (QKV enabled, O disabled)."""
+ return {
+ "model_type": "apriel2",
+ "architectures": ["Apriel2ForCausalLM"],
+ "hidden_size": 256,
+ "vocab_size": 1000,
+ "decoder": {
+ "type": "fixed",
+ "num_blocks": 4,
+ "block": {
+ "mixer": {
+ "type": "attention",
+ "heads": 8,
+ "head_groups": 4,
+ "head_size": 32,
+ "rotary": {"type": "mistral_1d", "theta": 10000.0},
+ # Qwen-style per-layer bias
+ "query_layer": {"bias": {"enabled": True}},
+ "key_layer": {"bias": {"enabled": True}},
+ "value_layer": {"bias": {"enabled": True}},
+ "dense_layer": {"bias": {"enabled": False}},
+ },
+ "mlp": {
+ "type": "mlp",
+ "intermediate_size": 512,
+ "gated": False,
+ # Per-layer MLP bias
+ "layer_1": {"bias": {"enabled": True}},
+ "layer_2": {"bias": {"enabled": False}},
+ },
+ "normalization": {"type": "rms_norm", "epsilon": 1e-5},
+ },
+ },
+ }
+
+ def test_same_type_inherits_attention_bias(self, source_config_with_bias):
+ """Same-type surgery inherits per-layer attention bias settings."""
+ surgery = {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "window_size": 512, # Add sliding window behavior
+ },
+ },
+ },
+ }
+ result = compose_configs(source_config_with_bias, surgery)
+
+ mixer = result["decoder"]["block"]["mixer"]
+ assert mixer["query_layer"]["bias"]["enabled"] is True
+ assert mixer["key_layer"]["bias"]["enabled"] is True
+ assert mixer["value_layer"]["bias"]["enabled"] is True
+ assert mixer["dense_layer"]["bias"]["enabled"] is False
+
+ def test_same_type_inherits_mlp_bias(self, source_config_with_bias):
+ """Same-type surgery inherits per-layer MLP bias settings."""
+ surgery = {
+ "decoder": {
+ "block": {
+ "mlp": {
+ "intermediate_size": 1024, # Change size
+ },
+ },
+ },
+ }
+ result = compose_configs(source_config_with_bias, surgery)
+
+ mlp = result["decoder"]["block"]["mlp"]
+ assert mlp["layer_1"]["bias"]["enabled"] is True
+ assert mlp["layer_2"]["bias"]["enabled"] is False
+ assert mlp["intermediate_size"] == 1024
+
+ def test_cross_type_attention_to_sliding_window_preserves_bias(self, source_config_with_bias):
+ """attentionโsliding_window cross-type preserves per-layer bias."""
+ surgery = {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "type": "sliding_window", # Cross-type derivation
+ "window_size": 512,
+ },
+ },
+ },
+ }
+ result = compose_configs(source_config_with_bias, surgery)
+
+ mixer = result["decoder"]["block"]["mixer"]
+ assert mixer["type"] == "sliding_window"
+ # Bias settings preserved through cross-type
+ assert mixer["query_layer"]["bias"]["enabled"] is True
+ assert mixer["key_layer"]["bias"]["enabled"] is True
+ assert mixer["value_layer"]["bias"]["enabled"] is True
+ assert mixer["dense_layer"]["bias"]["enabled"] is False
+
+ def test_stochastic_wrapper_inherits_bias(self, source_config_with_bias):
+ """Wrapping in stochastic inherits bias settings to all sub-mixers."""
+ surgery = {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "type": "stochastic",
+ "main_mixer_name": "attention",
+ "mixers": {
+ "attention": {"init": "transfer"},
+ "sliding_window": {
+ "type": "sliding_window",
+ "window_size": 512,
+ "init": "transfer",
+ },
+ },
+ },
+ },
+ },
+ }
+ result = compose_configs(source_config_with_bias, surgery)
+
+ mixers = result["decoder"]["block"]["mixer"]["mixers"]
+
+ # Attention sub-mixer inherits bias
+ assert mixers["attention"]["query_layer"]["bias"]["enabled"] is True
+ assert mixers["attention"]["dense_layer"]["bias"]["enabled"] is False
+
+ # Sliding window sub-mixer also inherits bias
+ assert mixers["sliding_window"]["query_layer"]["bias"]["enabled"] is True
+ assert mixers["sliding_window"]["dense_layer"]["bias"]["enabled"] is False
+
+ def test_surgery_can_override_bias(self, source_config_with_bias):
+ """Surgery can explicitly override inherited bias settings."""
+ surgery = {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "dense_layer": {"bias": {"enabled": True}}, # Override O bias
+ },
+ },
+ },
+ }
+ result = compose_configs(source_config_with_bias, surgery)
+
+ mixer = result["decoder"]["block"]["mixer"]
+ # Q/K/V unchanged
+ assert mixer["query_layer"]["bias"]["enabled"] is True
+ # O bias overridden
+ assert mixer["dense_layer"]["bias"]["enabled"] is True
+
class TestComposeConfigsRealYAML:
"""Test compose_configs with real YAML surgery files."""
@@ -398,160 +606,12 @@ def test_build_plan_returns_complete_config(self, llava_pixtral_checkpoint):
mixer = config.decoder["block"]["mixer"]
assert mixer["type"] == "stochastic"
- # Each sub-mixer should have complete config (no init keys)
+ # Each sub-mixer should have complete config
+ # (init is preserved for plan_surgery, stripped only at final output)
for name, sub_mixer in mixer["mixers"].items():
- assert "init" not in sub_mixer, f"Sub-mixer {name} still has 'init' key"
assert "type" in sub_mixer
-class TestMonoidLaws:
- """Test the algebraic laws of compose_configs.
-
- Surgery specs form a MONOID under deep-merge:
- - Identity: {}
- - Operation: deep merge (overlay wins)
- - Associativity: merge(merge(A, B), C) == merge(A, merge(B, C))
-
- compose_configs is a MONOID ACTION on configs:
- - Identity action: apply(config, {}) == config
- - Compatibility: apply(apply(c, A), B) == apply(c, merge(A, B))
- """
-
- @pytest.fixture
- def complete_config(self):
- """A complete Apriel2 config."""
- return {
- "model_type": "apriel2",
- "architectures": ["Apriel2ForConditionalGeneration"],
- "hidden_size": 256,
- "vocab_size": 1000,
- "bos_token_id": 1,
- "eos_token_id": 2,
- "tie_word_embeddings": False,
- "image_token_index": 100,
- "decoder": {
- "type": "fixed",
- "num_blocks": 4,
- "block": {
- "mixer": {
- "type": "attention",
- "heads": 8,
- "head_groups": 4,
- "head_size": 32,
- "rope_theta": 10000.0,
- },
- "mlp": {"type": "mlp", "intermediate_size": 512},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- },
- }
-
- @pytest.fixture
- def surgery_a(self):
- """First surgery: wrap in stochastic with attention."""
- return {
- "decoder": {
- "block": {
- "mixer": {
- "type": "stochastic",
- "main_mixer_name": "attention",
- "mixers": {
- "attention": {"init": "transfer"},
- },
- },
- },
- },
- }
-
- @pytest.fixture
- def surgery_b(self):
- """Second surgery: add sliding window mixer."""
- return {
- "decoder": {
- "block": {
- "mixer": {
- "mixers": {
- "sliding_window": {"init": "transfer", "window_size": 512},
- },
- },
- },
- },
- }
-
- def test_identity_action(self, complete_config):
- """apply(config, {}) == config"""
- result = compose_configs(complete_config, {})
- assert result == complete_config
-
- def test_surgery_monoid_associativity(self, surgery_a, surgery_b):
- """merge(merge(A, B), C) == merge(A, merge(B, C)) for partial configs."""
- surgery_c = {
- "decoder": {
- "block": {
- "mixer": {
- "mixers": {
- "gdn": {"type": "gdn", "init": "transfer", "convolution_layer": {"kernel_size": 4}},
- },
- },
- },
- },
- }
-
-        # Left-associated: (A ∘ B) ∘ C
- ab = compose_configs(surgery_a, surgery_b)
- ab_c = compose_configs(ab, surgery_c)
-
-        # Right-associated: A ∘ (B ∘ C)
- bc = compose_configs(surgery_b, surgery_c)
- a_bc = compose_configs(surgery_a, bc)
-
- assert ab_c == a_bc, "Surgery monoid should be associative"
-
- def test_monoid_action_compatibility(self, complete_config, surgery_a, surgery_b):
- """apply(apply(c, A), B) == apply(c, merge(A, B))
-
- This is the key law: applying surgeries sequentially should equal
- merging the surgeries first, then applying once.
- """
-        # Sequential application: (c ⊳ A) ⊳ B
- result_sequential = compose_configs(compose_configs(complete_config, surgery_a), surgery_b)
-
-        # Merged application: c ⊳ (A ∘ B)
- merged_surgery = compose_configs(surgery_a, surgery_b)
- result_merged = compose_configs(complete_config, merged_surgery)
-
- # These should be equivalent
- assert result_sequential == result_merged, "Monoid action should satisfy compatibility law"
-
- def test_three_way_compatibility(self, complete_config, surgery_a, surgery_b):
- """Test with three surgeries for stronger confidence."""
- surgery_c = {
- "decoder": {
- "block": {
- "mixer": {
- "mixers": {
- "gdn": {"type": "gdn", "init": "transfer", "convolution_layer": {"kernel_size": 4}},
- },
- },
- },
- },
- }
-
-        # Sequential: ((c ⊳ A) ⊳ B) ⊳ C
- seq = compose_configs(
- compose_configs(compose_configs(complete_config, surgery_a), surgery_b),
- surgery_c
- )
-
-        # Merged: c ⊳ ((A ∘ B) ∘ C)
- merged = compose_configs(
- complete_config,
- compose_configs(compose_configs(surgery_a, surgery_b), surgery_c)
- )
-
- assert seq == merged, "Three-way monoid action should satisfy compatibility"
-
-
class TestCompositionTortureTest:
"""Comprehensive stress test for config composition.
@@ -650,19 +710,29 @@ def test_final_config_structure(self, complete_config, additive_surgery_chain):
assert mixer["mixers"]["sliding_window"]["window_size"] == 512
assert mixer["mixers"]["gdn"]["value_heads"] == 16
- def test_no_init_keys_in_result(self, complete_config, additive_surgery_chain):
- """Verify no 'init' keys leak through."""
+ def test_init_keys_preserved_for_planning(self, complete_config, additive_surgery_chain):
+ """Verify 'init' keys are preserved for plan_surgery to see.
- def check_no_init(d, path=""):
- if isinstance(d, dict):
- assert "init" not in d, f"Found 'init' key at {path}"
- for k, v in d.items():
- check_no_init(v, f"{path}.{k}")
+ The `init` field is metadata for weight initialization. It's preserved
+ through composition and only stripped when saving final output.
+ """
+ from fast_llm_external_models.apriel2.conversion.config import strip_init_fields
result = complete_config
for i, surgery in enumerate(additive_surgery_chain):
result = compose_configs(result, surgery)
- check_no_init(result, f"step_{i+1}")
+
+ # init should be in the composed config
+ mixer = result["decoder"]["block"]["mixer"]
+ if "mixers" in mixer:
+ has_init = any("init" in m for m in mixer["mixers"].values())
+ assert has_init, "init should be preserved in composed config"
+
+ # strip_init_fields removes them
+ stripped = strip_init_fields(result)
+ mixer = stripped["decoder"]["block"]["mixer"]
+ if "mixers" in mixer:
+ assert all("init" not in m for m in mixer["mixers"].values())
def test_full_torture_chain(self, complete_config, torture_surgery_chain):
"""Test the full 10-step torture chain produces valid configs."""
diff --git a/fast_llm_external_models/tests/test_apriel2/test_plan_composition_torture.py b/fast_llm_external_models/tests/test_apriel2/test_conversion_e2e.py
similarity index 78%
rename from fast_llm_external_models/tests/test_apriel2/test_plan_composition_torture.py
rename to fast_llm_external_models/tests/test_apriel2/test_conversion_e2e.py
index 3b4adc7f5..b91fb7e51 100644
--- a/fast_llm_external_models/tests/test_apriel2/test_plan_composition_torture.py
+++ b/fast_llm_external_models/tests/test_apriel2/test_conversion_e2e.py
@@ -1,6 +1,6 @@
-"""End-to-end torture test for plan composition.
+"""test_conversion_e2e.py - End-to-end conversion integration tests.
-This tests the FULL pipeline at every step of a surgery chain:
+Tests the FULL pipeline at every step of a surgery chain:
1. Config composition produces valid configs
2. Plan building works for each surgery
3. Plan execution produces valid weights
@@ -16,21 +16,12 @@
import pytest
import torch
-from fast_llm_external_models.tests.test_apriel2.conftest import requires_cuda
-
from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
-from fast_llm_external_models.apriel2.conversion import (
- compose,
- compose_configs,
- execute,
- plan_surgery,
-)
-from fast_llm_external_models.apriel2.conversion.llava import (
- convert_config as convert_llava_config,
- plan_llava_to_apriel2,
-)
+from fast_llm_external_models.apriel2.conversion import compose, compose_configs, execute, plan_surgery
+from fast_llm_external_models.apriel2.conversion.llava import convert_config as convert_llava_config
+from fast_llm_external_models.apriel2.conversion.llava import plan_llava_to_apriel2
from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForConditionalGeneration
-
+from fast_llm_external_models.tests.test_apriel2.conftest import requires_cuda
# =============================================================================
# Cycling Surgery Generation
@@ -87,40 +78,20 @@ def generate_cycling_surgeries(config: dict) -> list[tuple[dict, str]]:
if sub_name != main_mixer:
# Build surgery path based on block_path
if block_path == "block":
- surgery = {
- "decoder": {
- "block": {"mixer": {"main_mixer_name": sub_name}}
- }
- }
+ surgery = {"decoder": {"block": {"mixer": {"main_mixer_name": sub_name}}}}
else:
# block_path is "blocks.block_name"
block_name = block_path.split(".")[1]
- surgery = {
- "decoder": {
- "blocks": {
- block_name: {"mixer": {"main_mixer_name": sub_name}}
- }
- }
- }
+ surgery = {"decoder": {"blocks": {block_name: {"mixer": {"main_mixer_name": sub_name}}}}}
surgeries.append((surgery, f"cycle {block_path} to {sub_name}"))
# Restore original main_mixer_name
if any(sub_name != main_mixer for sub_name in sub_mixer_names):
if block_path == "block":
- restore = {
- "decoder": {
- "block": {"mixer": {"main_mixer_name": main_mixer}}
- }
- }
+ restore = {"decoder": {"block": {"mixer": {"main_mixer_name": main_mixer}}}}
else:
block_name = block_path.split(".")[1]
- restore = {
- "decoder": {
- "blocks": {
- block_name: {"mixer": {"main_mixer_name": main_mixer}}
- }
- }
- }
+ restore = {"decoder": {"blocks": {block_name: {"mixer": {"main_mixer_name": main_mixer}}}}}
surgeries.append((restore, f"restore {block_path} to {main_mixer}"))
return surgeries
@@ -194,9 +165,7 @@ def source_config(self, llava_pixtral_checkpoint):
with open(llava_pixtral_checkpoint / "config.json") as f:
return json.load(f)
- def test_initial_conversion_produces_working_model(
- self, source_config, source_weights
- ):
+ def test_initial_conversion_produces_working_model(self, source_config, source_weights):
"""Test that Llava โ Apriel2 conversion produces a working model."""
# Convert config
apriel2_config_dict = convert_llava_config(source_config)
@@ -219,9 +188,7 @@ def test_initial_conversion_produces_working_model(
assert outputs.logits.shape == (1, 8, config.vocab_size)
- def test_each_surgery_step_produces_working_model(
- self, source_config, source_weights, additive_surgery_chain
- ):
+ def test_each_surgery_step_produces_working_model(self, source_config, source_weights, additive_surgery_chain):
"""Test that each surgery step produces a model that can forward pass.
Key insight: Surgery plans reference Apriel2 keys, so we must COMPOSE
@@ -290,9 +257,7 @@ def test_each_surgery_step_produces_working_model(
except Exception as e:
pytest.fail(f"Step {i+1}: Forward pass failed - {e}")
- def test_all_stochastic_submixers_via_cycling(
- self, source_config, source_weights, additive_surgery_chain
- ):
+ def test_all_stochastic_submixers_via_cycling(self, source_config, source_weights, additive_surgery_chain):
"""Test ALL sub-mixers in stochastic blocks, not just the main mixer.
Problem: Forward pass only exercises the main_mixer_name. Other sub-mixers
@@ -312,9 +277,7 @@ def test_all_stochastic_submixers_via_cycling(
conversion_plan = plan_llava_to_apriel2(source_config)
# Expand surgery chain with cycling
- expanded_chain = expand_surgery_chain_with_cycling(
- additive_surgery_chain, apriel2_config
- )
+ expanded_chain = expand_surgery_chain_with_cycling(additive_surgery_chain, apriel2_config)
# Build cumulative plan: conversion | surgery_1 | cycling_1a | ... | restore_1 | surgery_2 | ...
current_plan = conversion_plan
@@ -359,9 +322,7 @@ def test_all_stochastic_submixers_via_cycling(
except Exception as e:
pytest.fail(f"{desc}: Forward pass failed - {e}")
- def test_composed_plan_equals_sequential_execution(
- self, source_config, source_weights, additive_surgery_chain
- ):
+ def test_composed_plan_equals_sequential_execution(self, source_config, source_weights, additive_surgery_chain):
"""Test that composing plans gives same result as sequential execution.
This verifies plan composition associativity:
@@ -399,13 +360,9 @@ def test_composed_plan_equals_sequential_execution(
# Compare weights
for key in seq_weights:
if key in composed_weights:
- assert torch.allclose(
- seq_weights[key], composed_weights[key], atol=1e-5
- ), f"Weight mismatch for {key}"
+ assert torch.allclose(seq_weights[key], composed_weights[key], atol=1e-5), f"Weight mismatch for {key}"
- def test_final_model_structure(
- self, source_config, source_weights, additive_surgery_chain
- ):
+ def test_final_model_structure(self, source_config, source_weights, additive_surgery_chain):
"""Verify the final model has the expected structure."""
# Initial conversion
current_config = convert_llava_config(source_config)
@@ -504,9 +461,7 @@ def base_setup(self, llava_pixtral_checkpoint):
"""Set up base config and weights after Llava conversion."""
from safetensors.torch import load_file
- from fast_llm_external_models.apriel2.conversion.llava import (
- convert_config as convert_llava_config,
- )
+ from fast_llm_external_models.apriel2.conversion.llava import convert_config as convert_llava_config
# Load source config and weights
with open(llava_pixtral_checkpoint / "config.json") as f:
@@ -534,9 +489,7 @@ def _merge_surgeries(self, surgeries: list[dict]) -> dict:
result = _deep_merge(result, s)
return result
- def _build_incremental_plans(
- self, base_config: dict, surgeries: list[dict]
- ) -> tuple[list, list[dict]]:
+ def _build_incremental_plans(self, base_config: dict, surgeries: list[dict]) -> tuple[list, list[dict]]:
"""Build incremental plans for each surgery step.
Returns (plans, configs) where configs[i] is the config after surgery i.
@@ -552,9 +505,7 @@ def _build_incremental_plans(
config = target_config
return plans, configs
- def test_incremental_equals_direct_full_chain(
- self, base_setup, additive_surgery_chain
- ):
+ def test_incremental_equals_direct_full_chain(self, base_setup, additive_surgery_chain):
"""Test that composing all incremental plans equals one direct plan.
compose(P1, P2, ..., Pn) ≡ plan_surgery(base, final)
@@ -575,9 +526,7 @@ def test_incremental_equals_direct_full_chain(
direct_plan = plan_surgery(base_config, final_config)
# Verify same target keys
- assert set(composed_plan.mappings.keys()) == set(
- direct_plan.mappings.keys()
- ), "Plan keys should match"
+ assert set(composed_plan.mappings.keys()) == set(direct_plan.mappings.keys()), "Plan keys should match"
# Execute both and compare weights
composed_weights = execute(composed_plan, base_weights, seed=0)
@@ -611,9 +560,7 @@ def test_every_prefix_consistency(self, base_setup, additive_surgery_chain):
direct = plan_surgery(base_config, configs[k])
# Verify keys match
- assert set(composed.mappings.keys()) == set(
- direct.mappings.keys()
- ), f"Prefix {k}: keys don't match"
+ assert set(composed.mappings.keys()) == set(direct.mappings.keys()), f"Prefix {k}: keys don't match"
# Execute and compare
composed_weights = execute(composed, base_weights, seed=0)
@@ -781,9 +728,7 @@ def torture_setup(self, llava_pixtral_checkpoint):
"""Set up for comprehensive torture tests."""
from safetensors.torch import load_file
- from fast_llm_external_models.apriel2.conversion.llava import (
- convert_config as convert_llava_config,
- )
+ from fast_llm_external_models.apriel2.conversion.llava import convert_config as convert_llava_config
# Load source
with open(llava_pixtral_checkpoint / "config.json") as f:
@@ -801,9 +746,7 @@ def torture_setup(self, llava_pixtral_checkpoint):
return base_config, base_weights
- def test_each_step_produces_valid_config(
- self, torture_setup, comprehensive_torture_chain
- ):
+ def test_each_step_produces_valid_config(self, torture_setup, comprehensive_torture_chain):
"""Test that each surgery step produces a valid config."""
base_config, _ = torture_setup
@@ -818,9 +761,7 @@ def test_each_step_produces_valid_config(
pytest.fail(f"Step {i+1} produced invalid config: {e}")
@requires_cuda
- def test_each_step_produces_working_model(
- self, torture_setup, comprehensive_torture_chain
- ):
+ def test_each_step_produces_working_model(self, torture_setup, comprehensive_torture_chain):
"""Test that each surgery step produces a model that can forward pass.
This is the ultimate integration test - config composition + plan building
@@ -875,9 +816,7 @@ def test_each_step_produces_working_model(
current_weights = new_weights
@requires_cuda
- def test_final_supernet_structure(
- self, torture_setup, comprehensive_torture_chain
- ):
+ def test_final_supernet_structure(self, torture_setup, comprehensive_torture_chain):
"""Verify the final architecture has supernet blocks with all 4 mixer types."""
base_config, base_weights = torture_setup
@@ -914,9 +853,7 @@ def test_final_supernet_structure(
assert outputs.logits.shape == (1, 8, config.vocab_size)
@requires_cuda
- def test_plan_config_consistency_comprehensive(
- self, torture_setup, comprehensive_torture_chain
- ):
+ def test_plan_config_consistency_comprehensive(self, torture_setup, comprehensive_torture_chain):
"""Test that incremental plan composition works for the comprehensive chain.
Note: We cannot compare to a "direct plan" because the comprehensive chain
@@ -1083,66 +1020,6 @@ def mamba_config(self):
},
}
- def test_config_composition_identical_regardless_of_init_mode(self, base_config):
- """Config composition produces same structure with init: transfer vs init: random."""
- # Surgery with init: transfer
- surgery_transfer = {
- "decoder": {
- "block": {
- "mixer": {
- "type": "stochastic",
- "main_mixer_name": "attention",
- "mixers": {
- "attention": {"type": "attention", "init": "transfer"},
- "swa": {
- "type": "attention",
- "init": "transfer",
- "sliding_window": 512,
- },
- },
- },
- },
- },
- }
-
- # Surgery with init: random
- surgery_random = {
- "decoder": {
- "block": {
- "mixer": {
- "type": "stochastic",
- "main_mixer_name": "attention",
- "mixers": {
- "attention": {"type": "attention", "init": "random"},
- "swa": {
- "type": "attention",
- "init": "random",
- "sliding_window": 512,
- },
- },
- },
- },
- },
- }
-
- # Compose configs
- result_transfer = compose_configs(base_config, surgery_transfer)
- result_random = compose_configs(base_config, surgery_random)
-
- # Both should produce identical structure (init is stripped)
- assert result_transfer == result_random, (
- "Config composition should produce identical structure regardless of init mode"
- )
-
- # Verify the structure is correct
- mixer = result_transfer["decoder"]["block"]["mixer"]
- assert mixer["type"] == "stochastic"
- assert "attention" in mixer["mixers"]
- assert "swa" in mixer["mixers"]
- # init should be stripped
- assert "init" not in mixer["mixers"]["attention"]
- assert "init" not in mixer["mixers"]["swa"]
-
def test_plan_surgery_random_succeeds_for_any_type_pair(self, mamba_config):
"""plan_surgery with init: random should succeed even for mamba -> attention."""
# This surgery changes mamba to attention with random init
@@ -1166,7 +1043,7 @@ def test_plan_surgery_random_succeeds_for_any_type_pair(self, mamba_config):
plan = plan_surgery(mamba_config, surgery)
# Verify the plan has the expected target keys
- target_keys = set(str(k) for k in plan.mappings.keys())
+ target_keys = {str(k) for k in plan.mappings.keys()}
assert any("mixer.q_proj" in k for k in target_keys)
def test_plan_surgery_transfer_fails_for_unsupported_type_pair(self, mamba_config):
@@ -1219,7 +1096,7 @@ def test_plan_surgery_transfer_succeeds_for_supported_type_pair(self, base_confi
plan = plan_surgery(base_config, surgery)
# Verify the plan has mamba target keys
- target_keys = set(str(k) for k in plan.mappings.keys())
+ target_keys = {str(k) for k in plan.mappings.keys()}
assert any("mixer.in_proj" in k for k in target_keys)
def test_stochastic_init_random_succeeds_for_any_submixer_type(self, mamba_config):
@@ -1259,7 +1136,7 @@ def test_stochastic_init_random_succeeds_for_any_submixer_type(self, mamba_confi
plan = plan_surgery(mamba_config, surgery)
# Verify both sub-mixers have target keys
- target_keys = set(str(k) for k in plan.mappings.keys())
+ target_keys = {str(k) for k in plan.mappings.keys()}
assert any("mixers.attention.q_proj" in k for k in target_keys)
assert any("mixers.swa.q_proj" in k for k in target_keys)
@@ -1294,7 +1171,7 @@ def test_mixed_init_modes_in_stochastic(self, base_config):
plan = plan_surgery(base_config, surgery)
# Verify both sub-mixers have target keys
- target_keys = set(str(k) for k in plan.mappings.keys())
+ target_keys = {str(k) for k in plan.mappings.keys()}
assert any("mixers.attention.q_proj" in k for k in target_keys)
assert any("mixers.gdn.in_proj_qkvz" in k for k in target_keys)
@@ -1313,8 +1190,8 @@ class TestMarkovianProperty:
"""
@pytest.fixture
- def attention_config(self):
- """Base config with attention."""
+ def attention_config_dict(self):
+ """Base config dict with attention mixer for compose_configs tests."""
return {
"model_type": "apriel2",
"hidden_size": 256,
@@ -1335,43 +1212,7 @@ def attention_config(self):
},
}
- @pytest.fixture
- def stochastic_config(self):
- """Config with stochastic mixer."""
- return {
- "model_type": "apriel2",
- "hidden_size": 256,
- "vocab_size": 1000,
- "decoder": {
- "type": "fixed",
- "num_blocks": 2,
- "block": {
- "mixer": {
- "type": "stochastic",
- "main_mixer_name": "attention",
- "mixers": {
- "attention": {
- "type": "attention",
- "heads": 8,
- "head_groups": 4,
- "head_size": 32,
- },
- "swa": {
- "type": "sliding_window",
- "heads": 8,
- "head_groups": 4,
- "head_size": 32,
- "window_size": 512,
- },
- },
- },
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- },
- }
-
- def test_different_paths_same_config_same_plan(self, attention_config):
+ def test_different_paths_same_config_same_plan(self, attention_config_dict):
"""Two different paths to the same config produce identical plans.
Path A: attention -> stochastic{att, swa}
@@ -1398,7 +1239,7 @@ def test_different_paths_same_config_same_plan(self, attention_config):
},
},
}
- config_a = compose_configs(attention_config, surgery_a)
+ config_a = compose_configs(attention_config_dict, surgery_a)
# Path B: First add attention only, then add swa
surgery_b1 = {
@@ -1414,7 +1255,7 @@ def test_different_paths_same_config_same_plan(self, attention_config):
},
},
}
- intermediate_config = compose_configs(attention_config, surgery_b1)
+ intermediate_config = compose_configs(attention_config_dict, surgery_b1)
surgery_b2 = {
"decoder": {
@@ -1465,11 +1306,11 @@ def test_different_paths_same_config_same_plan(self, attention_config):
plan_from_b = plan_surgery(config_b, final_surgery)
# Compare plan mappings
- keys_a = set(str(k) for k in plan_from_a.mappings.keys())
- keys_b = set(str(k) for k in plan_from_b.mappings.keys())
+ keys_a = {str(k) for k in plan_from_a.mappings.keys()}
+ keys_b = {str(k) for k in plan_from_b.mappings.keys()}
assert keys_a == keys_b, "Plans from same config via different paths should be identical"
- def test_init_in_source_config_does_not_affect_plan(self, attention_config):
+ def test_init_in_source_config_does_not_affect_plan(self, attention_config_dict):
"""Manually injecting init into source config doesn't change the plan.
This tests that plan_surgery reads init from surgery, not source.
@@ -1479,8 +1320,8 @@ def test_init_in_source_config_does_not_affect_plan(self, attention_config):
import copy
# Create two copies of the config
- config_with_init = copy.deepcopy(attention_config)
- config_without_init = copy.deepcopy(attention_config)
+ config_with_init = copy.deepcopy(attention_config_dict)
+ config_without_init = copy.deepcopy(attention_config_dict)
# Manually inject init into one (bypassing compose_configs)
config_with_init["decoder"]["block"]["mixer"]["init"] = "random"
@@ -1504,238 +1345,12 @@ def test_init_in_source_config_does_not_affect_plan(self, attention_config):
plan_with = plan_surgery(config_with_init, surgery)
plan_without = plan_surgery(config_without_init, surgery)
- keys_with = set(str(k) for k in plan_with.mappings.keys())
- keys_without = set(str(k) for k in plan_without.mappings.keys())
+ keys_with = {str(k) for k in plan_with.mappings.keys()}
+ keys_without = {str(k) for k in plan_without.mappings.keys()}
# Plans should be identical - source's init field is ignored
assert keys_with == keys_without, "Plan should not depend on init in source config"
- def test_associativity_of_surgery_composition(self, attention_config):
-        """Verify associativity: (A ∘ B) ∘ C == A ∘ (B ∘ C) for surgery specs.
-
- This tests that composing surgeries is associative, which is
- equivalent to Markovianity for plan creation.
- """
- surgery_a = {
- "decoder": {
- "block": {
- "mixer": {
- "type": "stochastic",
- "main_mixer_name": "attention",
- "mixers": {
- "attention": {"type": "attention", "init": "transfer"},
- },
- },
- },
- },
- }
-
- surgery_b = {
- "decoder": {
- "block": {
- "mixer": {
- "mixers": {
- "swa": {
- "type": "sliding_window",
- "init": "transfer",
- "window_size": 512,
- },
- },
- },
- },
- },
- }
-
- surgery_c = {
- "decoder": {
- "block": {
- "mixer": {
- "mixers": {
- "gdn": {
- "type": "gdn",
- "init": "random",
- "value_heads": 8,
- "key_heads": 4,
- "key_head_dim": 32,
- "value_head_dim": 32,
- "convolution_layer": {"kernel_size": 4},
- },
- },
- },
- },
- },
- }
-
-        # Left association: ((attention_config ∘ A) ∘ B) ∘ C
- left_1 = compose_configs(attention_config, surgery_a)
- left_2 = compose_configs(left_1, surgery_b)
- left_result = compose_configs(left_2, surgery_c)
-
-        # Right association: (attention_config ∘ A) ∘ (B ∘ C)
-        # Note: B ∘ C is partial ∘ partial = deep merge of surgery specs
- bc_merged = compose_configs(surgery_b, surgery_c)
- right_1 = compose_configs(attention_config, surgery_a)
- right_result = compose_configs(right_1, bc_merged)
-
- assert left_result == right_result, "Surgery composition should be associative"
-
- def test_complete_configs_have_no_init_fields(self, attention_config):
- """Verify that compose_configs strips init from complete configs.
-
- This is the key invariant that enables Markovianity:
- - Complete configs (states) have no init fields
- - Surgery specs (transitions) have init fields
- - Plans read init from surgery, not state
- """
- surgery_with_init = {
- "decoder": {
- "block": {
- "mixer": {
- "type": "stochastic",
- "main_mixer_name": "attention",
- "mixers": {
- "attention": {"type": "attention", "init": "transfer"},
- "swa": {"type": "sliding_window", "init": "random", "window_size": 512},
- },
- },
- },
- },
- }
-
- result = compose_configs(attention_config, surgery_with_init)
-
- # Recursively check for init fields
- def has_init(obj):
- if isinstance(obj, dict):
- if "init" in obj:
- return True
- return any(has_init(v) for v in obj.values())
- if isinstance(obj, list):
- return any(has_init(v) for v in obj)
- return False
-
- assert not has_init(result), "Complete configs should have no init fields"
-
- def test_monoid_action_law_additive_surgeries(self):
- """Monoid action law HOLDS for additive surgeries.
-
- Additive surgeries (no type: declaration) support:
-            apply(apply(s, t1), t2) == apply(s, t1 ∘ t2)
-
- This is because additive operations commute nicely:
- "add {a}" then "add {b}" == "add {a, b}"
- """
- # Start with stochastic (additive surgery target)
- s = {
- "model_type": "apriel2",
- "hidden_size": 256,
- "vocab_size": 1000,
- "decoder": {
- "type": "fixed",
- "num_blocks": 2,
- "block": {
- "mixer": {
- "type": "stochastic",
- "main_mixer_name": "attention",
- "mixers": {
- "attention": {"type": "attention", "heads": 8, "head_groups": 4, "head_size": 32},
- },
- },
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- },
- }
-
- # Additive surgeries (no type: declaration)
- t1 = {"decoder": {"block": {"mixer": {"mixers": {"swa": {"type": "sliding_window", "window_size": 512}}}}}}
- t2 = {"decoder": {"block": {"mixer": {"mixers": {"mamba": {"type": "mamba", "d_inner": 512}}}}}}
-
- # Path A: Sequential
- s_prime = compose_configs(s, t1)
- s_double_prime_A = compose_configs(s_prime, t2)
-
- # Path B: Composed
- t1_t2 = compose_configs(t1, t2)
- s_double_prime_B = compose_configs(s, t1_t2)
-
- assert s_double_prime_A == s_double_prime_B, "Monoid action law should hold for additive surgeries"
-
- def test_monoid_action_law_replacement_surgeries_fails(self):
- """Monoid action law FAILS for replacement surgeries (by design).
-
- Replacement surgeries (type: stochastic declared) have:
-            apply(apply(s, t1), t2) != apply(s, t1 ∘ t2)
-
- This is FUNDAMENTAL, not a bug:
-        - Sequential: "set to {a}" then "set to {b}" → {b} (second wins)
-        - Composed: merge({a}, {b}) = {a,b}, then apply → {a,b}
-
- These are genuinely different semantics. The failure documents
- the distinction between declarative composition (merge) and
- operational composition (function application).
- """
- s = {
- "model_type": "apriel2",
- "hidden_size": 256,
- "vocab_size": 1000,
- "decoder": {
- "type": "fixed",
- "num_blocks": 2,
- "block": {
- "mixer": {"type": "attention", "heads": 8, "head_groups": 4, "head_size": 32},
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- },
- }
-
- # Replacement surgeries (both declare type: stochastic)
- t1 = {
- "decoder": {
- "block": {
- "mixer": {
- "type": "stochastic",
- "main_mixer_name": "attention",
- "mixers": {"attention": {"type": "attention"}},
- }
- }
- }
- }
- t2 = {
- "decoder": {
- "block": {
- "mixer": {
- "type": "stochastic",
- "main_mixer_name": "swa",
- "mixers": {"swa": {"type": "sliding_window", "window_size": 512}},
- }
- }
- }
- }
-
- # Path A: Sequential (second replacement wins)
- s_prime = compose_configs(s, t1)
- s_double_prime_A = compose_configs(s_prime, t2)
-
- # Path B: Composed (declarations merged)
- t1_t2 = compose_configs(t1, t2)
- s_double_prime_B = compose_configs(s, t1_t2)
-
- # They should be DIFFERENT (law fails)
- assert s_double_prime_A != s_double_prime_B, (
- "Monoid action law should FAIL for replacement surgeries"
- )
-
- # Verify the specific difference:
- # Sequential: only swa (second replacement wins)
- # Composed: both attention and swa (merged declarations)
- mixers_A = set(s_double_prime_A["decoder"]["block"]["mixer"]["mixers"].keys())
- mixers_B = set(s_double_prime_B["decoder"]["block"]["mixer"]["mixers"].keys())
-
- assert mixers_A == {"swa"}, "Sequential: second replacement wins"
- assert mixers_B == {"attention", "swa"}, "Composed: declarations merged"
-
class TestCyclingSurgeryGeneration:
"""Tests for the cycling surgery generation functions.
@@ -1936,7 +1551,7 @@ def test_expand_surgery_chain_adds_cycling(self):
# Verify restore flag
assert expanded[0][2] is False # surgery - not restore
assert expanded[1][2] is False # cycle - not restore
- assert expanded[2][2] is True # restore
+ assert expanded[2][2] is True # restore
def test_expand_surgery_chain_preserves_invariant(self):
"""Test that cycling leaves the chain state invariant."""
@@ -1980,3 +1595,151 @@ def test_expand_surgery_chain_preserves_invariant(self):
# After cycling and restore, we should be back to the same state
assert current_config == config_after_original
+
+
+class TestBiasSurgeryChain:
+ """Torture tests for per-layer bias inheritance through surgery operations.
+
+ Uses apriel2_config_with_bias + bias_surgery_chain to test that:
+ - Qwen-style per-layer attention bias (QKV enabled, O disabled) survives surgery
+ - Non-gated MLP per-layer bias (layer_1 enabled, layer_2 disabled) survives surgery
+ - Bias settings are correctly inherited by new sub-mixers
+ - Bias is correctly tracked in surgery plans
+ """
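+
+    # Illustrative sketch only: the apriel2_config_with_bias and bias_surgery_chain fixtures are
+    # defined in conftest and are not shown in this diff. The per-layer bias pattern exercised
+    # below is roughly the following (field names follow the assertions in these tests; the exact
+    # fixture values are an assumption):
+    #
+    #   mixer: {"query_layer": {"bias": {"enabled": True}},
+    #           "key_layer": {"bias": {"enabled": True}},
+    #           "value_layer": {"bias": {"enabled": True}},
+    #           "dense_layer": {"bias": {"enabled": False}}}
+    #   mlp:   {"gated": False,
+    #           "layer_1": {"bias": {"enabled": True}},
+    #           "layer_2": {"bias": {"enabled": False}}}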
+
+ @pytest.fixture
+ def bias_source_config(self, apriel2_config_with_bias):
+ """Convert Apriel2Config to dict for surgery operations."""
+ return apriel2_config_with_bias.to_dict()
+
+ def test_bias_survives_stochastic_wrapper(self, bias_source_config, bias_surgery_chain):
+ """Test that bias settings survive wrapping in stochastic mixer."""
+ # Apply first surgery (wrap in stochastic)
+ result = compose_configs(bias_source_config, bias_surgery_chain[0])
+
+ # Check attention sub-mixer inherited bias settings
+ mixer = result["decoder"]["block"]["mixer"]
+ assert mixer["type"] == "stochastic"
+
+ attn_mixer = mixer["mixers"]["attention"]
+ assert attn_mixer["query_layer"]["bias"]["enabled"] is True
+ assert attn_mixer["key_layer"]["bias"]["enabled"] is True
+ assert attn_mixer["value_layer"]["bias"]["enabled"] is True
+ assert attn_mixer["dense_layer"]["bias"]["enabled"] is False
+
+ # Check MLP bias survived
+ mlp = result["decoder"]["block"]["mlp"]
+ assert mlp["layer_1"]["bias"]["enabled"] is True
+ assert mlp["layer_2"]["bias"]["enabled"] is False
+
+ def test_new_submixer_inherits_bias(self, bias_source_config, bias_surgery_chain):
+ """Test that new sub-mixers inherit bias from source attention."""
+ # Apply S1 + S2 (wrap in stochastic, add sliding_window)
+ config = bias_source_config
+ for surgery in bias_surgery_chain[:2]:
+ config = compose_configs(config, surgery)
+
+ # sliding_window should inherit bias from source attention
+ mixer = config["decoder"]["block"]["mixer"]
+ sw_mixer = mixer["mixers"]["sliding_window"]
+
+ assert sw_mixer["query_layer"]["bias"]["enabled"] is True
+ assert sw_mixer["key_layer"]["bias"]["enabled"] is True
+ assert sw_mixer["value_layer"]["bias"]["enabled"] is True
+ assert sw_mixer["dense_layer"]["bias"]["enabled"] is False
+
+ def test_full_bias_chain_produces_valid_config(self, bias_source_config, bias_surgery_chain):
+ """Test that full bias surgery chain produces valid config."""
+ config = bias_source_config
+ for surgery in bias_surgery_chain:
+ config = compose_configs(config, surgery)
+
+ # Verify final config structure
+ mixer = config["decoder"]["block"]["mixer"]
+ assert mixer["type"] == "stochastic"
+ assert "attention" in mixer["mixers"]
+ assert "sliding_window" in mixer["mixers"]
+ assert "full_bias_attn" in mixer["mixers"]
+
+ # attention and sliding_window inherit Qwen-style bias
+ for name in ["attention", "sliding_window"]:
+ sub = mixer["mixers"][name]
+ assert sub["query_layer"]["bias"]["enabled"] is True
+ assert sub["dense_layer"]["bias"]["enabled"] is False
+
+ # full_bias_attn has add_linear_biases=True but per-layer settings inherited from
+ # source take precedence, so O bias is still disabled
+ full_bias = mixer["mixers"]["full_bias_attn"]
+ assert full_bias.get("add_linear_biases") is True
+ # Per-layer dense_layer.bias.enabled=False inherited from source takes precedence
+ assert full_bias["dense_layer"]["bias"]["enabled"] is False
+
+ def test_bias_plan_has_correct_mappings(self, bias_source_config, bias_surgery_chain):
+ """Test that surgery plan correctly includes/excludes bias weight mappings."""
+ # Compose config first to get full target config with inherited bias settings
+ target_config = compose_configs(bias_source_config, bias_surgery_chain[0])
+ plan = plan_surgery(bias_source_config, target_config)
+ mapping_strs = [str(k) for k in plan.mappings.keys()]
+
+ # Should have q_proj.bias (enabled)
+ q_bias = [m for m in mapping_strs if "q_proj.bias" in m]
+ assert len(q_bias) > 0, "Should have q_proj.bias mappings"
+
+ # Should NOT have o_proj.bias (disabled)
+ o_bias = [m for m in mapping_strs if "o_proj.bias" in m]
+ assert len(o_bias) == 0, "Should not have o_proj.bias mappings"
+
+ # Should have up_proj.bias (layer_1 enabled)
+ up_bias = [m for m in mapping_strs if "up_proj.bias" in m]
+ assert len(up_bias) > 0, "Should have up_proj.bias mappings"
+
+ # Should NOT have down_proj.bias (layer_2 disabled)
+ down_bias = [m for m in mapping_strs if "down_proj.bias" in m]
+ assert len(down_bias) == 0, "Should not have down_proj.bias mappings"
+
+ def test_bias_chain_produces_working_model(self, bias_source_config, bias_surgery_chain):
+ """Test that bias surgery chain produces a working model."""
+ from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM
+
+ # Apply full chain
+ config = bias_source_config
+ for surgery in bias_surgery_chain:
+ config = compose_configs(config, surgery)
+
+ # Create model
+ apriel_config = Apriel2Config(**config)
+ model = Apriel2ForCausalLM(apriel_config)
+ model.eval()
+
+ # Verify model structure has correct biases
+ block = model.model.decoder.blocks[0]
+
+ # attention sub-mixer should have QKV bias, no O bias
+ attn = block.mixer.mixers["attention"]
+ assert attn.q_proj.bias is not None
+ assert attn.k_proj.bias is not None
+ assert attn.v_proj.bias is not None
+ assert attn.o_proj.bias is None
+
+ # sliding_window should also inherit bias settings
+ sw = block.mixer.mixers["sliding_window"]
+ assert sw.q_proj.bias is not None
+ assert sw.o_proj.bias is None
+
+ # full_bias_attn inherits per-layer bias from source (even with add_linear_biases=True,
+ # per-layer settings take precedence in same-type inheritance)
+ full_bias = block.mixer.mixers["full_bias_attn"]
+ assert full_bias.q_proj.bias is not None
+ # O bias is disabled because inherited per-layer dense_layer.bias.enabled=False
+ # takes precedence over add_linear_biases=True
+ assert full_bias.o_proj.bias is None
+
+ # MLP should have layer_1 bias, no layer_2 bias
+ assert block.mlp.up_proj.bias is not None
+ assert block.mlp.down_proj.bias is None
+
+ # Forward pass should work
+ input_ids = torch.randint(0, config["vocab_size"], (1, 10))
+ with torch.no_grad():
+ outputs = model(input_ids, use_cache=False)
+ assert outputs.logits.shape == (1, 10, config["vocab_size"])
diff --git a/fast_llm_external_models/tests/test_apriel2/test_convert_from_llava.py b/fast_llm_external_models/tests/test_apriel2/test_convert_from_llava.py
index a437f920d..f96f5ac40 100644
--- a/fast_llm_external_models/tests/test_apriel2/test_convert_from_llava.py
+++ b/fast_llm_external_models/tests/test_apriel2/test_convert_from_llava.py
@@ -14,23 +14,15 @@
"""
import json
-from pathlib import Path
-import pytest
import torch
from safetensors import safe_open
-from safetensors.torch import save_file
from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
-from fast_llm_external_models.apriel2.conversion import (
- convert_llava_config as convert_config,
- execute,
- plan_llava_to_apriel2,
- plan_surgery,
-)
+from fast_llm_external_models.apriel2.conversion import convert_llava_config as convert_config
+from fast_llm_external_models.apriel2.conversion import execute, plan_llava_to_apriel2, plan_surgery
from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForConditionalGeneration
-
# =============================================================================
# Config Conversion Tests
# =============================================================================
@@ -330,9 +322,9 @@ def test_plan_keys_match_model_state_dict(self, llava_pixtral_checkpoint):
extra_in_plan = plan_keys - model_keys
# Filter out expected missing keys (caches, positions, etc.)
- missing_in_plan = {k for k in missing_in_plan if not any(
- skip in k.lower() for skip in ["cache", "position", "mask"]
- )}
+ missing_in_plan = {
+ k for k in missing_in_plan if not any(skip in k.lower() for skip in ["cache", "position", "mask"])
+ }
assert not missing_in_plan, f"Model keys not in plan: {sorted(missing_in_plan)[:10]}"
assert not extra_in_plan, f"Plan keys not in model: {sorted(extra_in_plan)[:10]}"
diff --git a/fast_llm_external_models/tests/test_apriel2/test_equivalence.py b/fast_llm_external_models/tests/test_apriel2/test_equivalence.py
index c59ed2000..9b3eb4efe 100644
--- a/fast_llm_external_models/tests/test_apriel2/test_equivalence.py
+++ b/fast_llm_external_models/tests/test_apriel2/test_equivalence.py
@@ -23,9 +23,6 @@
import torch
from transformers import LlavaForConditionalGeneration
-from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForConditionalGeneration
-
-
# =============================================================================
# Input Configuration
# =============================================================================
@@ -487,8 +484,10 @@ def test_batch_processing_behavior(self, model_pair):
batch_tgt = target.get_image_features(pixel_values).view(-1, batch_src.shape[-1])
# Sequential processing
- singles_src = [get_pixtral_vision_features(source, pixel_values[i:i+1]) for i in range(3)]
- singles_tgt = [target.get_image_features(pixel_values[i:i+1]).view(-1, batch_src.shape[-1]) for i in range(3)]
+ singles_src = [get_pixtral_vision_features(source, pixel_values[i : i + 1]) for i in range(3)]
+ singles_tgt = [
+ target.get_image_features(pixel_values[i : i + 1]).view(-1, batch_src.shape[-1]) for i in range(3)
+ ]
single_concat_src = torch.cat(singles_src, dim=0)
single_concat_tgt = torch.cat(singles_tgt, dim=0)
@@ -500,9 +499,9 @@ def test_batch_processing_behavior(self, model_pair):
print(f"Apriel2 batch vs sequential: {tgt_diff:.6f}")
# Both should have the same behavior (within FP tolerance)
- assert abs(src_diff - tgt_diff) < 1e-6, (
- f"Batch processing behavior differs: src={src_diff:.6f}, tgt={tgt_diff:.6f}"
- )
+ assert (
+ abs(src_diff - tgt_diff) < 1e-6
+ ), f"Batch processing behavior differs: src={src_diff:.6f}, tgt={tgt_diff:.6f}"
if __name__ == "__main__":
diff --git a/fast_llm_external_models/tests/test_apriel2/test_expr_plan.py b/fast_llm_external_models/tests/test_apriel2/test_expr_plan.py
index c487ab3a3..2dccac5ad 100644
--- a/fast_llm_external_models/tests/test_apriel2/test_expr_plan.py
+++ b/fast_llm_external_models/tests/test_apriel2/test_expr_plan.py
@@ -1,15 +1,13 @@
"""Tests for the expression-based plan system."""
import json
+
import pytest
import torch
-from fast_llm_external_models.tests.test_apriel2.conftest import requires_cuda
-
from fast_llm_external_models.apriel2.conversion import (
Concat,
EvalKwargs,
- Expr,
ExprAdapter,
ExprPlan,
Init,
@@ -18,10 +16,9 @@
Slice,
StreamingExecutor,
W,
- compose,
execute,
- fuse,
full_slice,
+ fuse,
make_slice,
plan_dil_attention_to_gdn,
plan_kil_attention_to_kda,
@@ -31,6 +28,7 @@
slice_spec,
substitute,
)
+from fast_llm_external_models.tests.test_apriel2.conftest import requires_cuda
def make_eval_kwargs(
@@ -219,10 +217,13 @@ def test_substitute_init_unchanged(self):
def test_substitute_complex(self):
"""Substitute handles complex nested expressions."""
# Concat of Slice(Ref) and Init
- expr = Concat(exprs=(
- Slice(expr=Ref(key=W("a")), slices=((0, 5, None),)),
- Init(shape=(5,), init_type="zeros"),
- ), dim=0)
+ expr = Concat(
+ exprs=(
+ Slice(expr=Ref(key=W("a")), slices=((0, 5, None),)),
+ Init(shape=(5,), init_type="zeros"),
+ ),
+ dim=0,
+ )
bindings = {W("a"): Ref(key=W("source"))}
result = substitute(expr, bindings)
@@ -238,7 +239,13 @@ class TestFuse:
def test_fuse_flatten_concat(self):
"""Fuse flattens nested Concat with same dim."""
inner = Concat(exprs=(Ref(key=W("a")), Ref(key=W("b"))), dim=0)
- outer = Concat(exprs=(inner, Ref(key=W("c")),), dim=0)
+ outer = Concat(
+ exprs=(
+ inner,
+ Ref(key=W("c")),
+ ),
+ dim=0,
+ )
result = fuse(outer)
assert isinstance(result, Concat)
@@ -250,7 +257,13 @@ def test_fuse_flatten_concat(self):
def test_fuse_no_flatten_different_dim(self):
"""Fuse doesn't flatten Concat with different dim."""
inner = Concat(exprs=(Ref(key=W("a")), Ref(key=W("b"))), dim=1)
- outer = Concat(exprs=(inner, Ref(key=W("c")),), dim=0)
+ outer = Concat(
+ exprs=(
+ inner,
+ Ref(key=W("c")),
+ ),
+ dim=0,
+ )
result = fuse(outer)
assert isinstance(result, Concat)
@@ -340,28 +353,34 @@ class TestExprPlan:
def test_plan_define_and_access(self):
"""Plan stores and retrieves expressions."""
- plan = ExprPlan(mappings={
- W("target"): Ref(key=W("source")),
- })
+ plan = ExprPlan(
+ mappings={
+ W("target"): Ref(key=W("source")),
+ }
+ )
assert W("target") in plan
assert isinstance(plan[W("target")], Ref)
def test_plan_source_keys(self):
"""Plan identifies all source references."""
- plan = ExprPlan(mappings={
- W("a"): Ref(key=W("x")),
- W("b"): Concat(exprs=(Ref(key=W("y")), Ref(key=W("z"))), dim=0),
- W("c"): Init(shape=(10,), init_type="zeros"),
- })
+ plan = ExprPlan(
+ mappings={
+ W("a"): Ref(key=W("x")),
+ W("b"): Concat(exprs=(Ref(key=W("y")), Ref(key=W("z"))), dim=0),
+ W("c"): Init(shape=(10,), init_type="zeros"),
+ }
+ )
assert plan.source_keys() == {W("x"), W("y"), W("z")}
def test_plan_target_keys(self):
"""Plan identifies all target keys."""
- plan = ExprPlan(mappings={
- W("a"): Ref(key=W("x")),
- W("b"): Ref(key=W("y")),
- })
+ plan = ExprPlan(
+ mappings={
+ W("a"): Ref(key=W("x")),
+ W("b"): Ref(key=W("y")),
+ }
+ )
assert plan.target_keys() == {W("a"), W("b")}
@@ -386,9 +405,17 @@ def test_plan_summary(self):
def test_plan_fuse(self):
"""Plan fuse applies optimizations."""
inner = Concat(exprs=(Ref(key=W("a")), Ref(key=W("b"))), dim=0)
- plan = ExprPlan(mappings={
- W("out"): Concat(exprs=(inner, Ref(key=W("c")),), dim=0),
- })
+ plan = ExprPlan(
+ mappings={
+ W("out"): Concat(
+ exprs=(
+ inner,
+ Ref(key=W("c")),
+ ),
+ dim=0,
+ ),
+ }
+ )
fused = plan.fuse()
assert isinstance(fused[W("out")], Concat)
@@ -532,9 +559,11 @@ class TestStreamingExecution:
def test_execute_simple(self):
"""Execute simple plan."""
- plan = ExprPlan(mappings={
- W("out"): Ref(key=W("in")),
- })
+ plan = ExprPlan(
+ mappings={
+ W("out"): Ref(key=W("in")),
+ }
+ )
sources = {W("in"): torch.tensor([1.0, 2.0, 3.0])}
result = execute(plan, sources, seed=42)
@@ -544,9 +573,11 @@ def test_execute_simple(self):
def test_execute_concat(self):
"""Execute plan with Concat."""
- plan = ExprPlan(mappings={
- W("combined"): Concat(exprs=(Ref(key=W("a")), Ref(key=W("b"))), dim=0),
- })
+ plan = ExprPlan(
+ mappings={
+ W("combined"): Concat(exprs=(Ref(key=W("a")), Ref(key=W("b"))), dim=0),
+ }
+ )
sources = {
W("a"): torch.ones(2, 3),
@@ -559,14 +590,19 @@ def test_execute_concat(self):
def test_execute_mil_like(self):
"""Execute MIL-like Concat of Slices and Init."""
# Simulated MIL: in_proj = [z, x, B, C]
- plan = ExprPlan(mappings={
- W("in_proj"): Concat(exprs=(
- Init(shape=(4, 8), init_type="zeros"), # z
- Slice(expr=Ref(key=W("v")), slices=((0, 2, None), (None, None, None))), # x
- Slice(expr=Ref(key=W("k")), slices=((0, 2, None), (None, None, None))), # B
- Slice(expr=Ref(key=W("q")), slices=((0, 4, None), (None, None, None))), # C
- ), dim=0),
- })
+ plan = ExprPlan(
+ mappings={
+ W("in_proj"): Concat(
+ exprs=(
+ Init(shape=(4, 8), init_type="zeros"), # z
+ Slice(expr=Ref(key=W("v")), slices=((0, 2, None), (None, None, None))), # x
+ Slice(expr=Ref(key=W("k")), slices=((0, 2, None), (None, None, None))), # B
+ Slice(expr=Ref(key=W("q")), slices=((0, 4, None), (None, None, None))), # C
+ ),
+ dim=0,
+ ),
+ }
+ )
sources = {
W("q"): torch.ones(4, 8),
@@ -583,11 +619,13 @@ def test_execute_mil_like(self):
def test_streaming_execution(self):
"""Streaming executor processes all targets."""
- plan = ExprPlan(mappings={
- W("out1"): Ref(key=W("shared")),
- W("out2"): Ref(key=W("shared")),
- W("out3"): Ref(key=W("unique")),
- })
+ plan = ExprPlan(
+ mappings={
+ W("out1"): Ref(key=W("shared")),
+ W("out2"): Ref(key=W("shared")),
+ W("out3"): Ref(key=W("unique")),
+ }
+ )
load_calls = []
@@ -858,25 +896,23 @@ def test_plan_dil_execution(self):
key_dim = 64
value_dim = 64
- head_k_dim = 16
- head_v_dim = 16
conv_dim = 2 * key_dim + value_dim # 192
# Create attention weights with per-head distinctive values
# Q: each head gets value (head_idx + 1)
q_weight = torch.zeros(64, 64)
for h in range(4):
- q_weight[h*16:(h+1)*16, :] = float(h + 1)
+ q_weight[h * 16 : (h + 1) * 16, :] = float(h + 1)
# K: each head gets value (head_idx + 1) * 10
k_weight = torch.zeros(64, 64)
for h in range(4):
- k_weight[h*16:(h+1)*16, :] = float((h + 1) * 10)
+ k_weight[h * 16 : (h + 1) * 16, :] = float((h + 1) * 10)
# V: each head gets value (head_idx + 1) * 100
v_weight = torch.zeros(64, 64)
for h in range(4):
- v_weight[h*16:(h+1)*16, :] = float((h + 1) * 100)
+ v_weight[h * 16 : (h + 1) * 16, :] = float((h + 1) * 100)
sources = {
W("attn.q_proj.weight"): q_weight,
@@ -894,30 +930,23 @@ def test_plan_dil_execution(self):
# Q_all (rows 0-63): heads 0,1,2,3 concatenated
for h in range(4):
- assert torch.allclose(
- in_proj_qkvz[h*16:(h+1)*16],
- torch.full((16, 64), float(h + 1))
- )
+ assert torch.allclose(in_proj_qkvz[h * 16 : (h + 1) * 16], torch.full((16, 64), float(h + 1)))
# K_all (rows 64-127): heads 0,1,2,3 concatenated
for h in range(4):
assert torch.allclose(
- in_proj_qkvz[key_dim + h*16:key_dim + (h+1)*16],
- torch.full((16, 64), float((h + 1) * 10))
+ in_proj_qkvz[key_dim + h * 16 : key_dim + (h + 1) * 16], torch.full((16, 64), float((h + 1) * 10))
)
# V_all (rows 128-191): heads 0,1,2,3 concatenated
for h in range(4):
assert torch.allclose(
- in_proj_qkvz[2*key_dim + h*16:2*key_dim + (h+1)*16],
- torch.full((16, 64), float((h + 1) * 100))
+ in_proj_qkvz[2 * key_dim + h * 16 : 2 * key_dim + (h + 1) * 16],
+ torch.full((16, 64), float((h + 1) * 100)),
)
# Z_all (rows 192-255): zeros
- assert torch.allclose(
- in_proj_qkvz[2*key_dim + value_dim:],
- torch.zeros(value_dim, 64)
- )
+ assert torch.allclose(in_proj_qkvz[2 * key_dim + value_dim :], torch.zeros(value_dim, 64))
# in_proj_ba should be zeros
in_proj_ba = result[W("in_proj_ba.weight")]
@@ -971,17 +1000,17 @@ def test_plan_dil_execution_gqa(self):
# Q: 4 heads, each with value (head_idx + 1)
q_weight = torch.zeros(64, 64)
for h in range(4):
- q_weight[h*16:(h+1)*16, :] = float(h + 1)
+ q_weight[h * 16 : (h + 1) * 16, :] = float(h + 1)
# K: 2 kv_heads, each with value (head_idx + 1) * 10
k_weight = torch.zeros(32, 64)
for h in range(2):
- k_weight[h*16:(h+1)*16, :] = float((h + 1) * 10)
+ k_weight[h * 16 : (h + 1) * 16, :] = float((h + 1) * 10)
# V: 2 kv_heads, each with value (head_idx + 1) * 100
v_weight = torch.zeros(32, 64)
for h in range(2):
- v_weight[h*16:(h+1)*16, :] = float((h + 1) * 100)
+ v_weight[h * 16 : (h + 1) * 16, :] = float((h + 1) * 100)
sources = {
W("attn.q_proj.weight"): q_weight,
@@ -1007,22 +1036,22 @@ def test_plan_dil_execution_gqa(self):
# K_all (rows 32-63): k_heads 0,1 (maps to source K heads 0,1 via modulo)
        # k_head 0 → source K head 0 (value 10)
- assert torch.allclose(in_proj_qkvz[key_dim:key_dim+16], torch.full((16, 64), 10.0))
+ assert torch.allclose(in_proj_qkvz[key_dim : key_dim + 16], torch.full((16, 64), 10.0))
        # k_head 1 → source K head 1 (value 20)
- assert torch.allclose(in_proj_qkvz[key_dim+16:key_dim+32], torch.full((16, 64), 20.0))
+ assert torch.allclose(in_proj_qkvz[key_dim + 16 : key_dim + 32], torch.full((16, 64), 20.0))
# V_all (rows 64-127): 4 v_heads, tiled from 2 source KV heads via modulo
        # v_head 0 → src_v_head 0 (value 100)
- assert torch.allclose(in_proj_qkvz[2*key_dim:2*key_dim+16], torch.full((16, 64), 100.0))
+ assert torch.allclose(in_proj_qkvz[2 * key_dim : 2 * key_dim + 16], torch.full((16, 64), 100.0))
        # v_head 1 → src_v_head 1 (value 200)
- assert torch.allclose(in_proj_qkvz[2*key_dim+16:2*key_dim+32], torch.full((16, 64), 200.0))
+ assert torch.allclose(in_proj_qkvz[2 * key_dim + 16 : 2 * key_dim + 32], torch.full((16, 64), 200.0))
        # v_head 2 → src_v_head 0 (value 100, tiled)
- assert torch.allclose(in_proj_qkvz[2*key_dim+32:2*key_dim+48], torch.full((16, 64), 100.0))
+ assert torch.allclose(in_proj_qkvz[2 * key_dim + 32 : 2 * key_dim + 48], torch.full((16, 64), 100.0))
        # v_head 3 → src_v_head 1 (value 200, tiled)
- assert torch.allclose(in_proj_qkvz[2*key_dim+48:2*key_dim+64], torch.full((16, 64), 200.0))
+ assert torch.allclose(in_proj_qkvz[2 * key_dim + 48 : 2 * key_dim + 64], torch.full((16, 64), 200.0))
# Z_all (rows 128-191): zeros
- assert torch.allclose(in_proj_qkvz[2*key_dim+value_dim:], torch.zeros(value_dim, 64))
+ assert torch.allclose(in_proj_qkvz[2 * key_dim + value_dim :], torch.zeros(value_dim, 64))
def test_plan_kil_attention_to_kda(self):
"""AIK plan produces correct structure for attention โ KDA conversion."""
@@ -1188,6 +1217,7 @@ def test_compose_llava_to_mamba(self, llava_pixtral_config, apriel2_config_stoch
# Build surgery plan (need intermediate config)
from fast_llm_external_models.apriel2.conversion.llava import convert_config
+
intermediate_config = convert_config(llava_pixtral_config)
target_config = apriel2_config_stochastic.to_dict()
surgery_plan = plan_surgery(intermediate_config, target_config)
@@ -1210,6 +1240,7 @@ def test_execute_composed_pipeline(self, llava_pixtral_checkpoint):
"""
import json
from pathlib import Path
+
from safetensors.torch import load_file
# Load config
@@ -1448,10 +1479,9 @@ def test_comprehensive_conversion_all_mixer_types(self, llava_pixtral_checkpoint
the conversion produced correct keys and shapes.
"""
import json
- from pathlib import Path
from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
- from fast_llm_external_models.apriel2.convert import build_plan, convert
+ from fast_llm_external_models.apriel2.convert import convert
from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForConditionalGeneration
# Load LLaVA config
@@ -1477,11 +1507,11 @@ def test_comprehensive_conversion_all_mixer_types(self, llava_pixtral_checkpoint
"type": "pattern",
"num_blocks": 5,
"pattern": [
-                    "attn",       # 0: attention → attention (passthrough)
-                    "mamba",      # 1: attention → mamba (MIL)
-                    "gdn",        # 2: attention → gated_delta_net (DIL)
-                    "stoch_am",   # 3: attention → stochastic(attention + mamba)
-                    "stoch_sg",   # 4: attention → stochastic(swa + gdn)
+                    "attn",  # 0: attention → attention (passthrough)
+                    "mamba",  # 1: attention → mamba (MIL)
+                    "gdn",  # 2: attention → gated_delta_net (DIL)
+                    "stoch_am",  # 3: attention → stochastic(attention + mamba)
+                    "stoch_sg",  # 4: attention → stochastic(swa + gdn)
],
"blocks": {
# Pure attention (passthrough from source)
@@ -1609,7 +1639,8 @@ def test_comprehensive_conversion_all_mixer_types(self, llava_pixtral_checkpoint
"type": "attention",
"heads": llava_config["vision_config"]["num_attention_heads"],
"head_groups": llava_config["vision_config"]["num_attention_heads"],
- "head_size": llava_config["vision_config"]["hidden_size"] // llava_config["vision_config"]["num_attention_heads"],
+ "head_size": llava_config["vision_config"]["hidden_size"]
+ // llava_config["vision_config"]["num_attention_heads"],
"add_linear_biases": False,
"causal": False,
"rotary": {
@@ -1688,7 +1719,6 @@ def test_conversion_plan_targets_match_model_state_dict(self, llava_pixtral_conf
This test validates the plan WITHOUT executing it, by comparing
plan target keys against what the model expects.
"""
- import json
from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
from fast_llm_external_models.apriel2.convert import build_plan
@@ -1703,7 +1733,7 @@ def test_conversion_plan_targets_match_model_state_dict(self, llava_pixtral_conf
expected_keys = set(model.state_dict().keys())
# Get plan target keys
- plan_target_keys = set(str(k) for k in plan.target_keys())
+ plan_target_keys = {str(k) for k in plan.target_keys()}
# Compare
missing_from_plan = expected_keys - plan_target_keys
@@ -1711,3 +1741,214 @@ def test_conversion_plan_targets_match_model_state_dict(self, llava_pixtral_conf
assert not missing_from_plan, f"Plan missing keys that model expects: {sorted(missing_from_plan)}"
assert not extra_in_plan, f"Plan has extra keys model doesn't expect: {sorted(extra_in_plan)}"
+
+
+class TestBiasPlanGeneration:
+ """Test that surgery plans correctly handle per-layer bias configurations.
+
+ These tests verify that plan_surgery correctly includes/excludes bias
+ weight mappings based on the per-layer bias settings:
+ - query_layer.bias.enabled, key_layer.bias.enabled, etc. for attention
+ - layer_1.bias.enabled, layer_2.bias.enabled for MLP
+ """
+
+ @pytest.fixture
+ def source_config_with_bias(self):
+ """Source config with Qwen-style bias (QKV enabled, O disabled)."""
+ return {
+ "model_type": "apriel2",
+ "hidden_size": 256,
+ "vocab_size": 1000,
+ "decoder": {
+ "type": "fixed",
+ "num_blocks": 2,
+ "block": {
+ "mixer": {
+ "type": "attention",
+ "heads": 8,
+ "head_groups": 4,
+ "head_size": 32,
+ "rotary": {"type": "mistral_1d", "theta": 10000.0},
+ # Qwen-style: QKV bias enabled, O bias disabled
+ "query_layer": {"bias": {"enabled": True}},
+ "key_layer": {"bias": {"enabled": True}},
+ "value_layer": {"bias": {"enabled": True}},
+ "dense_layer": {"bias": {"enabled": False}},
+ },
+ "mlp": {
+ "type": "mlp",
+ "intermediate_size": 512,
+ "gated": False,
+ # Per-layer MLP bias: layer_1 enabled, layer_2 disabled
+ "layer_1": {"bias": {"enabled": True}},
+ "layer_2": {"bias": {"enabled": False}},
+ },
+ "normalization": {"type": "rms_norm", "epsilon": 1e-5},
+ },
+ },
+ }
+
+ def test_plan_includes_enabled_attention_biases(self, source_config_with_bias):
+ """Surgery plan includes bias mappings for enabled attention biases."""
+ from fast_llm_external_models.apriel2.conversion.config import compose_configs
+ from fast_llm_external_models.apriel2.conversion.converters import plan_surgery
+
+ target_config = compose_configs(
+ source_config_with_bias,
+ {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "type": "stochastic",
+ "main_mixer_name": "attention",
+ "mixers": {
+ "attention": {"init": "transfer"},
+ },
+ },
+ "mlp": {"init": "transfer"},
+ },
+ },
+ },
+ )
+
+ plan = plan_surgery(source_config_with_bias, target_config)
+ mapping_strs = [str(k) for k in plan.mappings.keys()]
+
+ # Should have q_proj.bias, k_proj.bias, v_proj.bias mappings
+ q_bias = [m for m in mapping_strs if "q_proj.bias" in m]
+ k_bias = [m for m in mapping_strs if "k_proj.bias" in m]
+ v_bias = [m for m in mapping_strs if "v_proj.bias" in m]
+
+ assert len(q_bias) > 0, "Should have q_proj.bias mappings"
+ assert len(k_bias) > 0, "Should have k_proj.bias mappings"
+ assert len(v_bias) > 0, "Should have v_proj.bias mappings"
+
+ def test_plan_excludes_disabled_attention_biases(self, source_config_with_bias):
+ """Surgery plan excludes bias mappings for disabled attention biases."""
+ from fast_llm_external_models.apriel2.conversion.config import compose_configs
+ from fast_llm_external_models.apriel2.conversion.converters import plan_surgery
+
+ target_config = compose_configs(
+ source_config_with_bias,
+ {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "type": "stochastic",
+ "main_mixer_name": "attention",
+ "mixers": {
+ "attention": {"init": "transfer"},
+ },
+ },
+ "mlp": {"init": "transfer"},
+ },
+ },
+ },
+ )
+
+ plan = plan_surgery(source_config_with_bias, target_config)
+ mapping_strs = [str(k) for k in plan.mappings.keys()]
+
+ # Should NOT have o_proj.bias mappings (disabled)
+ o_bias = [m for m in mapping_strs if "o_proj.bias" in m]
+ assert len(o_bias) == 0, f"Should not have o_proj.bias mappings, found: {o_bias}"
+
+ def test_plan_includes_enabled_mlp_biases(self, source_config_with_bias):
+ """Surgery plan includes bias mappings for enabled MLP biases."""
+ from fast_llm_external_models.apriel2.conversion.config import compose_configs
+ from fast_llm_external_models.apriel2.conversion.converters import plan_surgery
+
+ target_config = compose_configs(
+ source_config_with_bias,
+ {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "type": "stochastic",
+ "main_mixer_name": "attention",
+ "mixers": {
+ "attention": {"init": "transfer"},
+ },
+ },
+ "mlp": {"init": "transfer"},
+ },
+ },
+ },
+ )
+
+ plan = plan_surgery(source_config_with_bias, target_config)
+ mapping_strs = [str(k) for k in plan.mappings.keys()]
+
+ # Should have up_proj.bias (layer_1) mappings
+ up_bias = [m for m in mapping_strs if "up_proj.bias" in m]
+ assert len(up_bias) > 0, "Should have up_proj.bias mappings"
+
+ def test_plan_excludes_disabled_mlp_biases(self, source_config_with_bias):
+ """Surgery plan excludes bias mappings for disabled MLP biases."""
+ from fast_llm_external_models.apriel2.conversion.config import compose_configs
+ from fast_llm_external_models.apriel2.conversion.converters import plan_surgery
+
+ target_config = compose_configs(
+ source_config_with_bias,
+ {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "type": "stochastic",
+ "main_mixer_name": "attention",
+ "mixers": {
+ "attention": {"init": "transfer"},
+ },
+ },
+ "mlp": {"init": "transfer"},
+ },
+ },
+ },
+ )
+
+ plan = plan_surgery(source_config_with_bias, target_config)
+ mapping_strs = [str(k) for k in plan.mappings.keys()]
+
+ # Should NOT have down_proj.bias (layer_2) mappings
+ down_bias = [m for m in mapping_strs if "down_proj.bias" in m]
+ assert len(down_bias) == 0, f"Should not have down_proj.bias mappings, found: {down_bias}"
+
+ def test_plan_random_init_creates_init_expressions_for_bias(self, source_config_with_bias):
+ """Random init creates Init expressions for bias weights."""
+ from fast_llm_external_models.apriel2.conversion.converters import plan_surgery
+
+ # Surgery spec - pass directly to plan_surgery (NOT composed, to preserve init)
+ surgery = {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "type": "stochastic",
+ "main_mixer_name": "attention",
+ "mixers": {
+ "attention": {"init": "transfer"},
+ "new_attention": {
+ "type": "attention",
+ "init": "random", # This triggers random init
+ "heads": 8,
+ "head_groups": 4,
+ "head_size": 32,
+ "rotary": {"type": "mistral_1d", "theta": 10000.0},
+ "add_linear_biases": True, # All biases enabled
+ },
+ },
+ },
+ },
+ },
+ }
+
+ # Pass surgery spec directly - init fields are preserved
+ plan = plan_surgery(source_config_with_bias, surgery)
+
+ # Check that new_attention biases use Init expressions
+ new_mixer_bias_keys = [k for k in plan.mappings.keys() if "new_attention" in str(k) and "bias" in str(k)]
+
+ assert len(new_mixer_bias_keys) > 0, "Should have bias mappings for new_attention"
+
+ for key in new_mixer_bias_keys:
+ expr = plan.mappings[key]
+ assert isinstance(expr, Init), f"{key} should be Init, got {type(expr)}"
diff --git a/fast_llm_external_models/tests/test_apriel2/test_integration.py b/fast_llm_external_models/tests/test_apriel2/test_integration.py
new file mode 100644
index 000000000..e84fa06ef
--- /dev/null
+++ b/fast_llm_external_models/tests/test_apriel2/test_integration.py
@@ -0,0 +1,330 @@
+"""Integration tests for Qwen2 -> Apriel2 -> Fast-LLM conversion pipeline.
+
+Tests verify the full conversion chain:
+1. Qwen2 -> Apriel2 (external module conversion)
+2. Apriel2 + Surgery -> Supernet (stochastic mixer creation)
+3. Supernet -> Fast-LLM -> Supernet (roundtrip through training format)
+
+Test Strategy:
+- Use real HuggingFace model (Qwen2.5-0.5B) for meaningful validation
+- Separate config preservation tests from numerical equivalence tests
+- Parameterize both conversion stages AND input variations
+- Single test implementation applied across all stages
+"""
+
+import json
+import tempfile
+from pathlib import Path
+
+import pytest
+import torch
+
+from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
+from fast_llm_external_models.apriel2.conversion import compose, compose_configs, execute, plan_surgery
+from fast_llm_external_models.apriel2.conversion.expr import W
+from fast_llm_external_models.apriel2.conversion.qwen2.config import convert_config as convert_qwen2_config
+from fast_llm_external_models.apriel2.conversion.qwen2.plan import plan_qwen2_to_apriel2
+from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM
+
+from .conftest import requires_fastllm
+
+# =============================================================================
+# Test Input Variations
+# =============================================================================
+
+TEST_INPUTS = pytest.mark.parametrize(
+ "prompts,max_new_tokens",
+ [
+ pytest.param(["Hello world"], 10, id="single_short"),
+ pytest.param(["Hi", "The quick brown fox jumps over the lazy dog"], 20, id="batch_varied"),
+ pytest.param(["Once upon a time"], 50, id="long_generation"),
+ ],
+)
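+# Note: max_new_tokens is consumed only by the generation test; the logits test ignores it.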
+
+
+# =============================================================================
+# Conversion Fixtures
+# =============================================================================
+
+
+@pytest.fixture(scope="module")
+def qwen2_source():
+ """Load Qwen2.5-0.5B as the source/reference model."""
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+
+ model_name = "Qwen/Qwen2.5-0.5B"
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32, trust_remote_code=True)
+ config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+ model.eval()
+
+ if tokenizer.pad_token is None:
+ tokenizer.pad_token = tokenizer.eos_token
+ tokenizer.padding_side = "left"
+
+ return {
+ "model": model,
+ "tokenizer": tokenizer,
+ "config_dict": config.to_dict(),
+ "state_dict": model.state_dict(),
+ }
+
+
+@pytest.fixture(scope="module")
+def apriel2_converted(qwen2_source):
+ """Stage 1: Qwen2 -> Apriel2."""
+ config_dict = convert_qwen2_config(qwen2_source["config_dict"])
+ plan = plan_qwen2_to_apriel2(qwen2_source["config_dict"])
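+    # Source state-dict keys are wrapped in W (the conversion plan's weight-key type) before
+    # execution; the resulting weights are keyed by W and mapped back to plain strings below.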
+ weights = execute(plan, {W(k): v for k, v in qwen2_source["state_dict"].items()}, seed=42)
+
+ config = Apriel2Config(**config_dict)
+ model = Apriel2ForCausalLM(config)
+ model.load_state_dict({str(k): v for k, v in weights.items()}, strict=False)
+ model.eval()
+
+ return {"model": model, "config_dict": config_dict, "plan": plan, "name": "Apriel2"}
+
+
+@pytest.fixture(scope="module")
+def supernet_converted(qwen2_source, apriel2_converted):
+ """Stage 2: Apriel2 + Surgery -> Supernet."""
+ surgery_spec = {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "type": "stochastic",
+ "main_mixer_name": "attention",
+ "mixers": {
+ "attention": {"type": "attention", "init": "transfer"},
+ "sliding_window": {
+ "type": "attention",
+ "init": "transfer",
+ "window_size": 4096,
+ },
+ },
+ },
+ },
+ },
+ }
+
+ apriel_config = apriel2_converted["config_dict"]
+ supernet_config = compose_configs(apriel_config, surgery_spec)
+
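+    # Compose the stage-1 (Qwen2 -> Apriel2) plan with the surgery plan so the supernet weights
+    # can be produced from the original Qwen2 state dict in a single execute() call.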
+ full_plan = compose(
+ apriel2_converted["plan"],
+ plan_surgery(apriel_config, supernet_config),
+ )
+
+ weights = execute(full_plan, {W(k): v for k, v in qwen2_source["state_dict"].items()}, seed=42)
+
+ config = Apriel2Config(**supernet_config)
+ model = Apriel2ForCausalLM(config)
+ model.load_state_dict({str(k): v for k, v in weights.items()}, strict=False)
+ model.eval()
+
+ return {"model": model, "config_dict": supernet_config, "name": "Supernet"}
+
+
+@pytest.fixture(scope="module")
+def roundtrip_converted(supernet_converted, qwen2_source):
+ """Stage 3: Supernet -> Fast-LLM -> Supernet."""
+ if not torch.cuda.is_available():
+ pytest.skip("Roundtrip conversion requires CUDA (integration tests need realistic hardware)")
+
+ from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, CheckpointSaveConfig, FastLLMCheckpointFormat
+ from fast_llm.engine.checkpoint.convert import ConvertConfig
+ from fast_llm.models.gpt.config import GPTModelConfig
+ from fast_llm.models.gpt.conversion.config import Apriel2TextCheckpointFormat
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ tmpdir = Path(tmpdir)
+ supernet_path = tmpdir / "supernet"
+ fastllm_path = tmpdir / "fastllm"
+ roundtrip_path = tmpdir / "roundtrip"
+
+ supernet_converted["model"].save_pretrained(supernet_path)
+ qwen2_source["tokenizer"].save_pretrained(supernet_path)
+
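+        # First hop: Apriel2 (HF) checkpoint -> Fast-LLM checkpoint format.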
+ ConvertConfig(
+ model=GPTModelConfig,
+ input=CheckpointLoadConfig(path=supernet_path, format=Apriel2TextCheckpointFormat),
+ output=CheckpointSaveConfig(path=fastllm_path, format=FastLLMCheckpointFormat),
+ ).run()
+
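+        # Second hop: Fast-LLM checkpoint -> back to the Apriel2 (HF) checkpoint format.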
+ ConvertConfig(
+ model=GPTModelConfig,
+ input=CheckpointLoadConfig(path=fastllm_path, format=FastLLMCheckpointFormat),
+ output=CheckpointSaveConfig(path=roundtrip_path, format=Apriel2TextCheckpointFormat),
+ ).run()
+
+ model = Apriel2ForCausalLM.from_pretrained(roundtrip_path)
+ model.eval()
+
+ with open(roundtrip_path / "config.json") as f:
+ config_dict = json.load(f)
+
+ yield {"model": model, "config_dict": config_dict, "name": "Roundtrip"}
+
+
+# =============================================================================
+# Parameterized Fixture: All Conversion Stages
+# =============================================================================
+
+
+@pytest.fixture(params=["apriel2", "supernet", "roundtrip"])
+def converted_model(request, apriel2_converted, supernet_converted):
+ """Parameterized fixture providing each conversion stage for testing.
+
+ This allows a single test to run against all stages automatically.
+ """
+ if request.param == "roundtrip":
+ pytest.importorskip("fast_llm")
+ if not torch.cuda.is_available():
+ pytest.skip("Roundtrip tests require CUDA (integration tests need realistic hardware)")
+ # Lazy-load to avoid fixture evaluation when CUDA unavailable
+ roundtrip_converted = request.getfixturevalue("roundtrip_converted")
+ return roundtrip_converted
+
+ return {
+ "apriel2": apriel2_converted,
+ "supernet": supernet_converted,
+ }[request.param]
+
+
+# =============================================================================
+# Config Preservation Tests
+# =============================================================================
+
+
+@pytest.mark.slow
+class TestConfigPreservation:
+ """Verify configs are correctly preserved through the conversion chain."""
+
+ def test_apriel2_structure(self, qwen2_source, apriel2_converted):
+ """Qwen2 -> Apriel2 preserves model dimensions."""
+ qwen = qwen2_source["config_dict"]
+ apriel = apriel2_converted["config_dict"]
+
+ assert apriel["hidden_size"] == qwen["hidden_size"]
+ assert apriel["vocab_size"] == qwen["vocab_size"]
+ assert apriel["decoder"]["num_blocks"] == qwen["num_hidden_layers"]
+
+ def test_apriel2_bias_pattern(self, apriel2_converted):
+ """Qwen2 -> Apriel2 preserves Qwen-style bias (QKV yes, O no)."""
+ mixer = apriel2_converted["config_dict"]["decoder"]["block"]["mixer"]
+
+ assert mixer["query_layer"]["bias"]["enabled"] is True
+ assert mixer["key_layer"]["bias"]["enabled"] is True
+ assert mixer["value_layer"]["bias"]["enabled"] is True
+ assert mixer["dense_layer"]["bias"]["enabled"] is False
+
+ def test_supernet_structure(self, supernet_converted):
+ """Surgery creates correct stochastic mixer structure."""
+ mixer = supernet_converted["config_dict"]["decoder"]["block"]["mixer"]
+
+ assert mixer["type"] == "stochastic"
+ assert mixer["main_mixer_name"] == "attention"
+ assert set(mixer["mixers"].keys()) == {"attention", "sliding_window"}
+
+ def test_supernet_bias_inheritance(self, supernet_converted):
+ """Submixers inherit bias settings from source."""
+ mixer = supernet_converted["config_dict"]["decoder"]["block"]["mixer"]
+
+ for name in ["attention", "sliding_window"]:
+ assert mixer["mixers"][name]["query_layer"]["bias"]["enabled"] is True
+ assert mixer["mixers"][name]["dense_layer"]["bias"]["enabled"] is False
+
+ @requires_fastllm
+ def test_roundtrip_structure(self, roundtrip_converted):
+ """Fast-LLM roundtrip preserves stochastic mixer structure."""
+ mixer = roundtrip_converted["config_dict"]["decoder"]["block"]["mixer"]
+
+ assert mixer["type"] == "stochastic"
+ assert mixer["main_mixer_name"] == "attention"
+ assert set(mixer["mixers"].keys()) == {"attention", "sliding_window"}
+
+ @requires_fastllm
+ def test_roundtrip_bias_preservation(self, roundtrip_converted):
+ """Fast-LLM roundtrip preserves per-layer bias settings."""
+ mixer = roundtrip_converted["config_dict"]["decoder"]["block"]["mixer"]
+
+ for name in ["attention", "sliding_window"]:
+ assert mixer["mixers"][name]["query_layer"]["bias"]["enabled"] is True
+ assert mixer["mixers"][name]["dense_layer"]["bias"]["enabled"] is False
+
+
+# =============================================================================
+# Numerical Equivalence Tests
+# =============================================================================
+
+
+@pytest.mark.slow
+class TestNumericalEquivalence:
+ """Verify all conversion stages produce numerically identical outputs.
+
+ Uses parameterized fixtures to test all stages with all input variations,
+    giving us 3 stages × 3 inputs = 9 test cases from a single test function.
+ """
+
+ @TEST_INPUTS
+ def test_logits_match(self, qwen2_source, converted_model, prompts, max_new_tokens):
+ """Converted model produces identical logits to source."""
+ tokenizer = qwen2_source["tokenizer"]
+ ref_model = qwen2_source["model"]
+ test_model = converted_model["model"]
+ stage = converted_model["name"]
+
+ inputs = tokenizer(prompts, return_tensors="pt", padding=True)
+ ref_device = next(ref_model.parameters()).device
+ test_device = next(test_model.parameters()).device
+
+ with torch.no_grad():
+ ref_logits = ref_model(
+ input_ids=inputs.input_ids.to(ref_device),
+ attention_mask=inputs.attention_mask.to(ref_device),
+ ).logits.cpu()
+
+ test_logits = test_model(
+ input_ids=inputs.input_ids.to(test_device),
+ attention_mask=inputs.attention_mask.to(test_device),
+ ).logits.cpu()
+
+ max_diff = (ref_logits - test_logits).abs().max().item()
+ assert torch.allclose(
+ ref_logits, test_logits, rtol=1e-4, atol=1e-4
+ ), f"{stage} logits mismatch: max diff = {max_diff:.6f}"
+
+ @TEST_INPUTS
+ def test_generation_match(self, qwen2_source, converted_model, prompts, max_new_tokens):
+ """Converted model produces identical generation to source."""
+ tokenizer = qwen2_source["tokenizer"]
+ ref_model = qwen2_source["model"]
+ test_model = converted_model["model"]
+ stage = converted_model["name"]
+
+ inputs = tokenizer(prompts, return_tensors="pt", padding=True)
+ ref_device = next(ref_model.parameters()).device
+ test_device = next(test_model.parameters()).device
+
+ with torch.no_grad():
+ ref_gen = ref_model.generate(
+ input_ids=inputs.input_ids.to(ref_device),
+ attention_mask=inputs.attention_mask.to(ref_device),
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ pad_token_id=tokenizer.pad_token_id,
+ ).cpu()
+
+ test_gen = test_model.generate(
+ input_ids=inputs.input_ids.to(test_device),
+ attention_mask=inputs.attention_mask.to(test_device),
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ pad_token_id=tokenizer.pad_token_id,
+ ).cpu()
+
+ assert torch.equal(ref_gen, test_gen), (
+ f"{stage} generation mismatch:\n"
+ f" Reference: {tokenizer.batch_decode(ref_gen, skip_special_tokens=True)}\n"
+ f" Test: {tokenizer.batch_decode(test_gen, skip_special_tokens=True)}"
+ )
diff --git a/fast_llm_external_models/tests/test_apriel2/test_mixer_equivalence.py b/fast_llm_external_models/tests/test_apriel2/test_mixer_equivalence.py
index 1aa8a56d9..c6f3337e8 100644
--- a/fast_llm_external_models/tests/test_apriel2/test_mixer_equivalence.py
+++ b/fast_llm_external_models/tests/test_apriel2/test_mixer_equivalence.py
@@ -28,15 +28,7 @@
import torch
import torch.nn as nn
-from fast_llm_external_models.apriel2.conversion import (
- Concat,
- ExprPlan,
- Ref,
- Slice,
- W,
- execute,
-)
-
+from fast_llm_external_models.apriel2.conversion import Concat, ExprPlan, Ref, Slice, W, execute
# =============================================================================
# Shared Fixtures
@@ -69,10 +61,10 @@ def hidden_size(request):
@pytest.fixture(
params=[
- pytest.param((8, 8, 32), id="mha-8h-32d"), # MHA: 8 heads, 8 kv heads, 32 head_dim
- pytest.param((8, 4, 32), id="gqa-8h4kv-32d"), # GQA: 8 heads, 4 kv heads, 32 head_dim
- pytest.param((8, 2, 64), id="gqa-8h2kv-64d"), # GQA: 8 heads, 2 kv heads, 64 head_dim
- pytest.param((4, 1, 64), id="mqa-4h1kv-64d"), # MQA: 4 heads, 1 kv head, 64 head_dim
+ pytest.param((8, 8, 32), id="mha-8h-32d"), # MHA: 8 heads, 8 kv heads, 32 head_dim
+ pytest.param((8, 4, 32), id="gqa-8h4kv-32d"), # GQA: 8 heads, 4 kv heads, 32 head_dim
+ pytest.param((8, 2, 64), id="gqa-8h2kv-64d"), # GQA: 8 heads, 2 kv heads, 64 head_dim
+ pytest.param((4, 1, 64), id="mqa-4h1kv-64d"), # MQA: 4 heads, 1 kv head, 64 head_dim
]
)
def attention_config(request):
@@ -90,7 +82,7 @@ def attention_config(request):
params=[
pytest.param((8, 4, 32, 32), id="8v-4k-32d"), # 8 value heads, 4 key heads, symmetric dims
pytest.param((8, 2, 64, 64), id="8v-2k-64d"), # 8 value heads, 2 key heads, larger dims
- pytest.param((4, 2, 32, 64), id="4v-2k-asym"), # Asymmetric key/value dims
+ pytest.param((4, 2, 32, 64), id="4v-2k-asym"), # Asymmetric key/value dims
]
)
def gdn_config(request):
@@ -100,9 +92,9 @@ def gdn_config(request):
@pytest.fixture(
params=[
- pytest.param((4, 8), id="4h-8d"), # 4 heads, 8 head_dim (small)
- pytest.param((8, 16), id="8h-16d"), # 8 heads, 16 head_dim (medium)
- pytest.param((4, 32), id="4h-32d"), # 4 heads, 32 head_dim (large head_dim)
+ pytest.param((4, 8), id="4h-8d"), # 4 heads, 8 head_dim (small)
+ pytest.param((8, 16), id="8h-16d"), # 8 heads, 16 head_dim (medium)
+ pytest.param((4, 32), id="4h-32d"), # 4 heads, 32 head_dim (large head_dim)
]
)
def kda_config(request):
@@ -283,9 +275,21 @@ def plan_qwen3next_gdn_to_apriel2(
for g in range(num_k_heads):
base = g * group_size
q_slices.append(Slice(expr=qkvz_ref, slices=((base, base + head_k_dim, None), (None, None, None))))
- k_slices.append(Slice(expr=qkvz_ref, slices=((base + head_k_dim, base + 2 * head_k_dim, None), (None, None, None))))
- v_slices.append(Slice(expr=qkvz_ref, slices=((base + 2 * head_k_dim, base + 2 * head_k_dim + v_per_group, None), (None, None, None))))
- z_slices.append(Slice(expr=qkvz_ref, slices=((base + 2 * head_k_dim + v_per_group, base + group_size, None), (None, None, None))))
+ k_slices.append(
+ Slice(expr=qkvz_ref, slices=((base + head_k_dim, base + 2 * head_k_dim, None), (None, None, None)))
+ )
+ v_slices.append(
+ Slice(
+ expr=qkvz_ref,
+ slices=((base + 2 * head_k_dim, base + 2 * head_k_dim + v_per_group, None), (None, None, None)),
+ )
+ )
+ z_slices.append(
+ Slice(
+ expr=qkvz_ref,
+ slices=((base + 2 * head_k_dim + v_per_group, base + group_size, None), (None, None, None)),
+ )
+ )
in_proj_qkvz_expr = Concat(
exprs=(
@@ -304,8 +308,15 @@ def plan_qwen3next_gdn_to_apriel2(
b_slices, a_slices = [], []
for g in range(num_k_heads):
base = g * ba_per_group
- b_slices.append(Slice(expr=ba_ref, slices=((base, base + num_v_heads // num_k_heads, None), (None, None, None))))
- a_slices.append(Slice(expr=ba_ref, slices=((base + num_v_heads // num_k_heads, base + ba_per_group, None), (None, None, None))))
+ b_slices.append(
+ Slice(expr=ba_ref, slices=((base, base + num_v_heads // num_k_heads, None), (None, None, None)))
+ )
+ a_slices.append(
+ Slice(
+ expr=ba_ref,
+ slices=((base + num_v_heads // num_k_heads, base + ba_per_group, None), (None, None, None)),
+ )
+ )
in_proj_ba_expr = Concat(
exprs=(Concat(exprs=tuple(b_slices), dim=0), Concat(exprs=tuple(a_slices), dim=0)),
@@ -565,6 +576,7 @@ def test_causal_vs_mistral(
):
"""Verify Apriel2Attention (causal) matches MistralAttention output."""
from transformers.models.mistral.modeling_mistral import MistralAttention, MistralRotaryEmbedding
+
from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2Attention
mixer_config = apriel2_config.decoder["block"]["mixer"]
@@ -593,13 +605,20 @@ def test_causal_vs_mistral(
apriel2_attn.eval()
with torch.no_grad():
- mistral_out = mistral_attn(hidden_states, position_embeddings=position_embeddings, attention_mask=causal_mask)[0]
- apriel2_out = apriel2_attn(hidden_states, attention_mask=causal_mask, position_embeddings=position_embeddings)[0]
+ mistral_out = mistral_attn(
+ hidden_states, position_embeddings=position_embeddings, attention_mask=causal_mask
+ )[0]
+ apriel2_out = apriel2_attn(
+ hidden_states, attention_mask=causal_mask, position_embeddings=position_embeddings
+ )[0]
rtol, atol = tolerance
assert_close(
- apriel2_out, mistral_out, rtol=rtol, atol=atol,
- msg=f"Apriel2Attention vs MistralAttention (batch={batch_size}, seq={seq_len}, hidden={hidden_size})"
+ apriel2_out,
+ mistral_out,
+ rtol=rtol,
+ atol=atol,
+ msg=f"Apriel2Attention vs MistralAttention (batch={batch_size}, seq={seq_len}, hidden={hidden_size})",
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA")
@@ -613,8 +632,9 @@ def test_noncausal_vs_pixtral(
tolerance,
):
"""Verify Apriel2Attention (non-causal) matches PixtralAttention output."""
- from transformers.models.pixtral.modeling_pixtral import PixtralAttention, PixtralRotaryEmbedding
from transformers.models.pixtral.configuration_pixtral import PixtralVisionConfig
+ from transformers.models.pixtral.modeling_pixtral import PixtralAttention, PixtralRotaryEmbedding
+
from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2TextConfig
from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2Attention
@@ -689,8 +709,11 @@ def test_noncausal_vs_pixtral(
rtol, atol = tolerance
assert_close(
- apriel2_out, pixtral_out, rtol=rtol, atol=atol,
- msg=f"Apriel2Attention (non-causal) vs PixtralAttention (batch={batch_size}, seq={seq_len})"
+ apriel2_out,
+ pixtral_out,
+ rtol=rtol,
+ atol=atol,
+ msg=f"Apriel2Attention (non-causal) vs PixtralAttention (batch={batch_size}, seq={seq_len})",
)
@@ -737,6 +760,7 @@ def test_vs_qwen3next(
):
"""Verify Apriel2GatedDeltaNet matches Qwen3NextGatedDeltaNet output."""
from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextGatedDeltaNet
+
from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2GatedDeltaNet
value_heads, key_heads, key_head_dim, value_head_dim = gdn_config
@@ -758,8 +782,10 @@ def test_vs_qwen3next(
# Transfer weights
plan = plan_qwen3next_gdn_to_apriel2(
- num_k_heads=key_heads, num_v_heads=value_heads,
- head_k_dim=key_head_dim, head_v_dim=value_head_dim,
+ num_k_heads=key_heads,
+ num_v_heads=value_heads,
+ head_k_dim=key_head_dim,
+ head_v_dim=value_head_dim,
)
source_weights = extract_module_weights(qwen_gdn)
target_weights = execute(plan, source_weights, seed=seed)
@@ -778,8 +804,11 @@ def test_vs_qwen3next(
rtol, atol = tolerance
assert_close(
- apriel2_out, qwen_out, rtol=rtol, atol=atol,
- msg=f"Apriel2GatedDeltaNet vs Qwen3NextGatedDeltaNet (batch={batch_size}, seq={seq_len})"
+ apriel2_out,
+ qwen_out,
+ rtol=rtol,
+ atol=atol,
+ msg=f"Apriel2GatedDeltaNet vs Qwen3NextGatedDeltaNet (batch={batch_size}, seq={seq_len})",
)
@@ -803,6 +832,7 @@ def test_vs_fla(
):
"""Verify Apriel2 KimiDeltaAttention matches FLA KimiDeltaAttention output."""
from fla.layers.kda import KimiDeltaAttention as FLA_KDA
+
from fast_llm_external_models.apriel2.modeling_apriel2 import KimiDeltaAttention as Apriel2_KDA
num_heads, head_dim = kda_config
@@ -853,8 +883,11 @@ def test_vs_fla(
rtol, atol = tolerance
assert_close(
- apriel2_out, fla_out, rtol=rtol, atol=atol,
- msg=f"Apriel2 KDA vs FLA KDA (batch={batch_size}, seq={seq_len}, hidden={hidden_size})"
+ apriel2_out,
+ fla_out,
+ rtol=rtol,
+ atol=atol,
+ msg=f"Apriel2 KDA vs FLA KDA (batch={batch_size}, seq={seq_len}, hidden={hidden_size})",
)
@@ -913,7 +946,4 @@ def test_gdn_fast_vs_slow(self, gdn_config, batch_size):
slow_out = model(hidden_states)[0].clone()
# Looser tolerance for kernel vs reference comparison
- assert_close(
- fast_out, slow_out, rtol=1e-3, atol=1e-3,
- msg="GDN fast path (CUDA) vs slow path (PyTorch)"
- )
+ assert_close(fast_out, slow_out, rtol=1e-3, atol=1e-3, msg="GDN fast path (CUDA) vs slow path (PyTorch)")
diff --git a/fast_llm_external_models/tests/test_apriel2/test_model_structure.py b/fast_llm_external_models/tests/test_apriel2/test_model_structure.py
index 23856be30..56d2bc6a6 100644
--- a/fast_llm_external_models/tests/test_apriel2/test_model_structure.py
+++ b/fast_llm_external_models/tests/test_apriel2/test_model_structure.py
@@ -1,9 +1,9 @@
"""Tests for Apriel2 model structure and architecture validation."""
-import pytest
import torch
-from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM
+
from fast_llm_external_models.apriel2.cache import Apriel2Cache, _AttentionCache, _SSMCache
+from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM
class TestStochasticMixerStructure:
@@ -14,20 +14,27 @@ def test_all_submixers_present(self, apriel2_config_all_mixers):
model = Apriel2ForCausalLM(apriel2_config_all_mixers)
stochastic_layer = model.model.decoder.blocks[1] # Layer 1 is the "all_mixers" layer
- assert hasattr(stochastic_layer.mixer, 'mixers'), "Stochastic mixer should have 'mixers' attribute"
+ assert hasattr(stochastic_layer.mixer, "mixers"), "Stochastic mixer should have 'mixers' attribute"
assert set(stochastic_layer.mixer.mixers.keys()) == {
- 'attention', 'swa', 'mamba', 'gdn'
+ "attention",
+ "swa",
+ "mamba",
+ "gdn",
}, "Stochastic mixer should contain all 4 configured mixer types"
# Verify each mixer is the correct type
from fast_llm_external_models.apriel2.modeling_apriel2 import (
- Apriel2Attention, Apriel2Mamba, Apriel2GatedDeltaNet
+ Apriel2Attention,
+ Apriel2GatedDeltaNet,
+ Apriel2Mamba,
)
- assert isinstance(stochastic_layer.mixer.mixers['attention'], Apriel2Attention)
- assert isinstance(stochastic_layer.mixer.mixers['swa'], Apriel2Attention) # SWA is Apriel2Attention with sliding_window
- assert isinstance(stochastic_layer.mixer.mixers['mamba'], Apriel2Mamba)
- assert isinstance(stochastic_layer.mixer.mixers['gdn'], Apriel2GatedDeltaNet)
+ assert isinstance(stochastic_layer.mixer.mixers["attention"], Apriel2Attention)
+ assert isinstance(
+ stochastic_layer.mixer.mixers["swa"], Apriel2Attention
+ ) # SWA is Apriel2Attention with sliding_window
+ assert isinstance(stochastic_layer.mixer.mixers["mamba"], Apriel2Mamba)
+ assert isinstance(stochastic_layer.mixer.mixers["gdn"], Apriel2GatedDeltaNet)
def test_main_mixer_is_configured(self, apriel2_config_all_mixers):
"""Verify main_mixer_name is set correctly."""
@@ -44,7 +51,10 @@ def test_cache_has_all_submixer_slots(self, apriel2_config_all_mixers):
assert isinstance(layer_cache, dict), "Stochastic layer cache should be a dict"
assert set(layer_cache.keys()) == {
- 'attention', 'swa', 'mamba', 'gdn'
+ "attention",
+ "swa",
+ "mamba",
+ "gdn",
}, "Cache should have slots for all 4 mixers"
def test_attention_mixers_use_attention_cache(self, apriel2_config_all_mixers):
@@ -53,12 +63,12 @@ def test_attention_mixers_use_attention_cache(self, apriel2_config_all_mixers):
layer_cache = cache.layers[1]
# Attention-based mixers use AttentionCache
- assert isinstance(layer_cache['attention'], _AttentionCache)
- assert isinstance(layer_cache['swa'], _AttentionCache)
+ assert isinstance(layer_cache["attention"], _AttentionCache)
+ assert isinstance(layer_cache["swa"], _AttentionCache)
# SSM-based mixers use SSMCache
- assert isinstance(layer_cache['mamba'], _SSMCache)
- assert isinstance(layer_cache['gdn'], _SSMCache)
+ assert isinstance(layer_cache["mamba"], _SSMCache)
+ assert isinstance(layer_cache["gdn"], _SSMCache)
def test_parameter_counts_differ_by_config(self):
"""Different configs create models with different parameter counts."""
@@ -74,8 +84,10 @@ def test_parameter_counts_differ_by_config(self):
}
config_tiny = Apriel2Config(
- vocab_size=100, hidden_size=64,
- num_attention_heads=4, num_key_value_heads=2,
+ vocab_size=100,
+ hidden_size=64,
+ num_attention_heads=4,
+ num_key_value_heads=2,
decoder={
"type": "fixed",
"num_blocks": 2,
@@ -88,8 +100,10 @@ def test_parameter_counts_differ_by_config(self):
)
config_stochastic = Apriel2Config(
- vocab_size=100, hidden_size=64,
- num_attention_heads=4, num_key_value_heads=2,
+ vocab_size=100,
+ hidden_size=64,
+ num_attention_heads=4,
+ num_key_value_heads=2,
decoder={
"type": "pattern",
"num_blocks": 2,
@@ -106,14 +120,14 @@ def test_parameter_counts_differ_by_config(self):
"main_mixer_name": "attention",
"mixers": {
"attention": attn_config,
- "mamba": {"type": "mamba", "conv_bias": True, "dt_proj_bias": True}
- }
+ "mamba": {"type": "mamba", "conv_bias": True, "dt_proj_bias": True},
+ },
},
"mlp": {"type": "mlp", "intermediate_size": 256, "gated": True},
"normalization": {"type": "rms_norm"},
- }
- }
- }
+ },
+ },
+ },
)
model_tiny = Apriel2ForCausalLM(config_tiny)
@@ -122,8 +136,9 @@ def test_parameter_counts_differ_by_config(self):
params_tiny = sum(p.numel() for p in model_tiny.parameters())
params_stochastic = sum(p.numel() for p in model_stochastic.parameters())
- assert params_stochastic > params_tiny, \
- "Stochastic mixer should have more parameters (has both attention and mamba)"
+ assert (
+ params_stochastic > params_tiny
+ ), "Stochastic mixer should have more parameters (has both attention and mamba)"
def test_weights_are_initialized(self, apriel2_config_all_mixers):
"""Verify model weights are initialized (not all zeros/constant)."""
@@ -136,9 +151,7 @@ def test_weights_are_initialized(self, apriel2_config_all_mixers):
# Basic sanity: at least some parameters should be non-zero
non_zero_params = sum(
- not torch.all(p == 0)
- for mixer in stochastic_layer.mixer.mixers.values()
- for p in mixer.parameters()
+ not torch.all(p == 0) for mixer in stochastic_layer.mixer.mixers.values() for p in mixer.parameters()
)
assert non_zero_params > 0, "At least some mixer parameters should be non-zero"
diff --git a/fast_llm_external_models/tests/test_apriel2/test_modeling.py b/fast_llm_external_models/tests/test_apriel2/test_modeling.py
index 5dbd36159..8e2f610bb 100644
--- a/fast_llm_external_models/tests/test_apriel2/test_modeling.py
+++ b/fast_llm_external_models/tests/test_apriel2/test_modeling.py
@@ -2,18 +2,23 @@
import pytest
import torch
+
from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM
class TestApriel2Modeling:
"""End-to-end tests for Apriel2 model with different configurations."""
- @pytest.mark.parametrize("config_name", [
- "apriel2_config_tiny",
- "apriel2_config_stochastic",
- "apriel2_config_multi_mixer",
- "apriel2_config_all_mixers" # Tests all 4 mixer types
- ])
+ @pytest.mark.parametrize(
+ "config_name",
+ [
+ "apriel2_config_tiny",
+ "apriel2_config_stochastic",
+ "apriel2_config_multi_mixer",
+ "apriel2_config_all_mixers", # Tests all 4 mixer types
+ "apriel2_config_with_bias", # Tests per-layer bias and non-gated MLP
+ ],
+ )
def test_model_end_to_end(self, config_name, request):
"""Test instantiation, forward pass, cache correctness, and generation.
@@ -42,7 +47,7 @@ def test_model_end_to_end(self, config_name, request):
# 2. Forward pass - basic shape validation
outputs = model(input_ids, use_cache=False)
assert outputs.logits.shape == (2, seq_len, config.vocab_size)
- assert hasattr(outputs, 'logits')
+ assert hasattr(outputs, "logits")
# 3. Verify cache is actually being used (not dormant)
split_pos = 30
@@ -52,28 +57,23 @@ def test_model_end_to_end(self, config_name, request):
assert outputs_part1.past_key_values is not None
outputs_correct_cache = model(
- input_ids[:, split_pos:split_pos+1],
- past_key_values=outputs_part1.past_key_values,
- use_cache=True
+ input_ids[:, split_pos : split_pos + 1], past_key_values=outputs_part1.past_key_values, use_cache=True
)
# Test 1: Empty cache should give different results than filled cache
# This verifies cache is being used at all
from fast_llm_external_models.apriel2.cache import Apriel2Cache, _AttentionCache
+
empty_cache = Apriel2Cache(config)
outputs_empty_cache = model(
- input_ids[:, split_pos:split_pos+1],
- past_key_values=empty_cache,
- use_cache=True
+ input_ids[:, split_pos : split_pos + 1], past_key_values=empty_cache, use_cache=True
)
- cache_affects_output = not torch.allclose(
- outputs_correct_cache.logits,
- outputs_empty_cache.logits,
- atol=1e-3
- )
- assert cache_affects_output, f"Cache appears dormant for {config_name} - empty cache gives same results as filled cache"
+ cache_affects_output = not torch.allclose(outputs_correct_cache.logits, outputs_empty_cache.logits, atol=1e-3)
+ assert (
+ cache_affects_output
+ ), f"Cache appears dormant for {config_name} - empty cache gives same results as filled cache"
# Test 2: Corrupted cache (zeros) should give different results than correct cache
# This verifies the actual cache VALUES are being used
@@ -98,17 +98,15 @@ def test_model_end_to_end(self, config_name, request):
corrupted_layer[name].value = torch.zeros_like(correct_sub.value)
outputs_corrupted_cache = model(
- input_ids[:, split_pos:split_pos+1],
- past_key_values=corrupted_cache,
- use_cache=True
+ input_ids[:, split_pos : split_pos + 1], past_key_values=corrupted_cache, use_cache=True
)
cache_values_matter = not torch.allclose(
- outputs_correct_cache.logits,
- outputs_corrupted_cache.logits,
- atol=1e-3
+ outputs_correct_cache.logits, outputs_corrupted_cache.logits, atol=1e-3
)
- assert cache_values_matter, f"Cache values not used for {config_name} - zeroed cache gives same results as correct cache"
+ assert (
+ cache_values_matter
+ ), f"Cache values not used for {config_name} - zeroed cache gives same results as correct cache"
# 4. Cache correctness - validate cache produces same results as no-cache
# Compute full sequence without cache
@@ -117,18 +115,14 @@ def test_model_end_to_end(self, config_name, request):
# Compute in two steps with cache
outputs_part1 = model(input_ids[:, :split_pos], use_cache=True)
outputs_part2 = model(
- input_ids[:, split_pos:split_pos+1],
- past_key_values=outputs_part1.past_key_values,
- use_cache=True
+ input_ids[:, split_pos : split_pos + 1], past_key_values=outputs_part1.past_key_values, use_cache=True
)
# Logits should match between cached and non-cached
# Note: GPU execution with bfloat16/float16 has lower precision than CPU float32,
# so we use a looser tolerance here.
assert torch.allclose(
- outputs_full.logits[:, split_pos, :],
- outputs_part2.logits[:, 0, :],
- atol=1e-3
+ outputs_full.logits[:, split_pos, :], outputs_part2.logits[:, 0, :], atol=1e-3
), f"Cache correctness failed for {config_name}: cached and non-cached logits differ"
# 5. Generation - end-to-end validation
diff --git a/fast_llm_external_models/tests/test_apriel2/test_plan_execution.py b/fast_llm_external_models/tests/test_apriel2/test_plan_execution.py
new file mode 100644
index 000000000..ca0c8739f
--- /dev/null
+++ b/fast_llm_external_models/tests/test_apriel2/test_plan_execution.py
@@ -0,0 +1,598 @@
+"""test_plan_execution.py - Plan execution and algebraic composition laws.
+
+This module provides rigorous, parameterized tests for the mathematical properties
+that the conversion system must satisfy. Each test class corresponds to one
+algebraic structure, and each test method verifies one specific law.
+
+Conceptual Types
+================
+
+The conversion system operates on three conceptual types (all ``dict`` at runtime):
+
+- **S (State)**: Complete config without ``init`` fields
+- **P (Partial Surgery)**: Incomplete config, may have ``init`` fields
+- **T (Transition Spec)**: Complete config WITH ``init`` fields
+
+Algebraic Structures
+====================
+
+1. **Partial Surgeries (P)** form a **Monoid** under deep merge::
+
+ compose_configs : P × P → P
+ Identity: {}
+ Associativity: (p1 ∘ p2) ∘ p3 = p1 ∘ (p2 ∘ p3)
+
+2. **Surgeries act on States** to produce Transition Specs::
+
+ compose_configs : S × P → T
+ compose_configs : T × P → T
+
+ Action law (additive surgeries): (s · p1) · p2 = s · (p1 ∘ p2)
+
+3. **Plans** form a **Category** with composition::
+
+ compose : Plan(A→B) × Plan(B→C) → Plan(A→C)
+ Associativity: (P1 ∘ P2) ∘ P3 = P1 ∘ (P2 ∘ P3)
+
+4. **plan_surgery is a Functor** from config pairs to plans::
+
+ plan_surgery : S × T → Plan
+ Functoriality: compose(plan(S,T1), plan(T1,T2)) ≡ plan(S,T2)
+
+ This is semantic equivalence: both produce identical weights when executed.
+
+Important Behaviors Tested
+==========================
+
+- **init stripping**: Between surgery iterations, T → S conversion via
+ ``strip_init_fields()`` ensures ``init: random`` applies only to the surgery
+ that introduces a component.
+
+- **Bias inheritance**: Per-layer bias settings propagate through surgery chains.
+
+- **Plan composition**: Composed plans produce identical weights to direct plans.
+
+Design Principles
+=================
+
+- Each law gets ONE parameterized test, not multiple similar tests
+- Fixtures provide diverse configs (with/without biases)
+- Corner cases are covered via parameterization, not test proliferation
+- Tests document the laws they verify in their docstrings
+"""
+
+from functools import reduce
+
+import pytest
+import torch
+
+from fast_llm_external_models.apriel2.conversion import (
+ Concat,
+ ExprPlan,
+ Init,
+ Ref,
+ Slice,
+ W,
+ compose,
+ compose_configs,
+ execute,
+ plan_surgery,
+)
+
+# Import shared helper from conftest
+from fast_llm_external_models.tests.test_apriel2.conftest import make_weights_for_config
+
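+# Illustrative sketch (assumed example values; not part of the test suite): per
+# the laws above, `compose_configs` deep-merges partial surgeries and `compose`
+# substitutes Ref expressions through plans, roughly:
+#
+#   p1 = {"decoder": {"block": {"mixer": {"type": "stochastic"}}}}
+#   p2 = {"decoder": {"block": {"mixer": {"mixers": {"gdn": {"type": "gdn"}}}}}}
+#   compose_configs(p1, p2)
+#   # -> {"decoder": {"block": {"mixer": {"type": "stochastic",
+#   #                                     "mixers": {"gdn": {"type": "gdn"}}}}}}
+#
+#   # Composing a Ref chain resolves the intermediate key:
+#   compose(ExprPlan(mappings={W("b"): Ref(key=W("a"))}),
+#           ExprPlan(mappings={W("c"): Ref(key=W("b"))}))
+#   # -> mappings equivalent to {W("c"): Ref(key=W("a"))}
+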
+# =============================================================================
+# Fixtures: Use shared fixtures from conftest.py where possible
+# =============================================================================
+# - base_config_dict: Complete config without biases (Llama-style)
+# - base_config_with_bias_dict: Complete config with QKV biases
+# - additive_surgery_chain: [wrap_stochastic, add_sliding_window, add_gdn]
+# =============================================================================
+
+
+# =============================================================================
+# Test: Plan Composition Associativity
+# =============================================================================
+
+
+class TestPlanCompositionAssociativity:
+ """
+ LAW: Plan composition is associative.
+
+ (P₁ ∘ P₂) ∘ P₃ = P₁ ∘ (P₂ ∘ P₃)
+
+ where ∘ denotes compose(P1, P2).
+
+ This must hold for the AST structure, not just semantic equivalence.
+ """
+
+ @pytest.mark.parametrize("expr_type", ["ref_chain", "with_concat", "with_slice", "with_init"])
+ def test_associativity(self, expr_type):
+ """Plan composition is associative for various expression types."""
+ # Build three plans that can be composed
+ if expr_type == "ref_chain":
+ p1 = ExprPlan(mappings={W("b"): Ref(key=W("a"))})
+ p2 = ExprPlan(mappings={W("c"): Ref(key=W("b"))})
+ p3 = ExprPlan(mappings={W("d"): Ref(key=W("c"))})
+ elif expr_type == "with_concat":
+ p1 = ExprPlan(mappings={W("x"): Ref(key=W("a")), W("y"): Ref(key=W("b"))})
+ p2 = ExprPlan(mappings={W("xy"): Concat(exprs=(Ref(key=W("x")), Ref(key=W("y"))), dim=0)})
+ p3 = ExprPlan(mappings={W("final"): Ref(key=W("xy"))})
+ elif expr_type == "with_slice":
+ p1 = ExprPlan(mappings={W("full"): Ref(key=W("src"))})
+ p2 = ExprPlan(mappings={W("part"): Slice(expr=Ref(key=W("full")), slices=((0, 5, None),))})
+ p3 = ExprPlan(mappings={W("out"): Ref(key=W("part"))})
+ elif expr_type == "with_init":
+ p1 = ExprPlan(mappings={W("x"): Ref(key=W("a"))})
+ p2 = ExprPlan(
+ mappings={W("y"): Concat(exprs=(Ref(key=W("x")), Init(shape=(5,), init_type="zeros")), dim=0)}
+ )
+ p3 = ExprPlan(mappings={W("z"): Ref(key=W("y"))})
+
+ left = compose(compose(p1, p2), p3)
+ right = compose(p1, compose(p2, p3))
+
+ assert left.mappings == right.mappings, f"Associativity failed for {expr_type}"
+
+
+# =============================================================================
+# Test: Functoriality of plan_surgery (THE CRITICAL PROPERTY)
+# =============================================================================
+
+
+class TestPlanSurgeryFunctoriality:
+ """
+ LAW: plan_surgery is functorial with respect to config composition.
+
+ For a surgery chain P₁, P₂, ..., Pₙ applied to base state S₀::
+
+ T₁ = compose_configs(S₀, P₁) # S × P → T
+ T₂ = compose_configs(T₁, P₂) # T × P → T (no stripping!)
+ ...
+ Tₙ = compose_configs(Tₙ₋₁, Pₙ)
+
+ Plan functoriality says::
+
+ compose(plan(S₀,T₁), plan(T₁,T₂), ...) ≡ plan(S₀, Tₙ)
+
+ where ≡ denotes semantic equivalence (identical weights when executed).
+
+ NOTE: This tests T × P composition WITHOUT stripping between steps.
+ This differs from build_plan, which strips (T → S) between iterations.
+ Both patterns are valid:
+
+ - Without stripping: init fields accumulate, testing plan composition purity
+ - With stripping: init consumed per-step, testing real usage (see
+ test_build_plan_strips_init_between_iterations)
+
+ The functoriality law holds in both cases because plan composition
+ correctly substitutes Ref expressions with their definitions.
+ """
+
+ @pytest.mark.parametrize("chain_length", [1, 2, 3])
+ @pytest.mark.parametrize("use_bias", [True, False])
+ def test_functoriality(
+ self,
+ chain_length,
+ use_bias,
+ base_config_dict,
+ base_config_with_bias_dict,
+ additive_surgery_chain,
+ ):
+ """
+ Composed incremental plans produce same weights as direct plan.
+
+ Parameterized over:
+ - chain_length: Number of surgeries (1, 2, or 3)
+ - use_bias: Whether base config has biases
+ """
+ base_config = base_config_with_bias_dict if use_bias else base_config_dict
+ surgeries = additive_surgery_chain[:chain_length]
+
+ # Build config chain: C₀ → C₁ → ... → Cₙ
+ configs = [base_config]
+ for s in surgeries:
+ configs.append(compose_configs(configs[-1], s))
+
+ # Build incremental plans: Pᵢ = plan_surgery(Cᵢ₋₁, Cᵢ)
+ plans = [plan_surgery(configs[i], configs[i + 1]) for i in range(len(surgeries))]
+
+ # Compose all incremental plans
+ composed_plan = reduce(compose, plans)
+
+ # Build direct plan: plan_surgery(C₀, Cₙ)
+ direct_plan = plan_surgery(configs[0], configs[-1])
+
+ # Execute both on same weights
+ weights = make_weights_for_config(base_config)
+ composed_weights = execute(composed_plan, weights, seed=42)
+ direct_weights = execute(direct_plan, weights, seed=42)
+
+ # Verify semantic equivalence
+ assert set(composed_weights.keys()) == set(
+ direct_weights.keys()
+ ), f"Key sets differ for chain_length={chain_length}, use_bias={use_bias}"
+
+ for key in composed_weights:
+ assert torch.allclose(
+ composed_weights[key], direct_weights[key], atol=1e-6
+ ), f"Weight mismatch for {key} with chain_length={chain_length}, use_bias={use_bias}"
+
+ @pytest.mark.parametrize("split_point", [1, 2])
+ def test_arbitrary_grouping(
+ self,
+ split_point,
+ base_config_with_bias_dict,
+ additive_surgery_chain,
+ ):
+ """
+ Any grouping of surgery chain produces same result.
+
+ For surgeries [S₁, S₂, S₃], tests that:
+ - compose(P₁, compose(P₂, P₃))
+ - compose(compose(P₁, P₂), P₃)
+ - plan_surgery(C₀, C₃)
+
+ all produce identical weights.
+ """
+ surgeries = additive_surgery_chain
+
+ # Build config chain
+ configs = [base_config_with_bias_dict]
+ for s in surgeries:
+ configs.append(compose_configs(configs[-1], s))
+
+ # Build incremental plans
+ plans = [plan_surgery(configs[i], configs[i + 1]) for i in range(3)]
+
+ # Different groupings
+ left_grouped = compose(compose(plans[0], plans[1]), plans[2])
+ right_grouped = compose(plans[0], compose(plans[1], plans[2]))
+ direct = plan_surgery(configs[0], configs[-1])
+
+ # Execute all
+ weights = make_weights_for_config(base_config_with_bias_dict)
+ results = {
+ "left": execute(left_grouped, weights, seed=42),
+ "right": execute(right_grouped, weights, seed=42),
+ "direct": execute(direct, weights, seed=42),
+ }
+
+ # All must match
+ keys = set(results["left"].keys())
+ assert keys == set(results["right"].keys()) == set(results["direct"].keys())
+
+ for key in keys:
+ assert torch.allclose(results["left"][key], results["right"][key], atol=1e-6)
+ assert torch.allclose(results["left"][key], results["direct"][key], atol=1e-6)
+
+
+# =============================================================================
+# Test: Bias Inheritance Preservation (Regression for the specific bug)
+# =============================================================================
+
+
+class TestBiasInheritancePreservation:
+ """
+ PROPERTY: Per-layer bias settings must be preserved through surgery chains.
+
+ When a surgery spec does not mention bias settings, they must be inherited
+ from the source config. This is the specific failure mode of the build_plan
+ bug: passing partial surgery specs to plan_surgery lost inherited fields.
+
+ This test verifies the SYMPTOM (missing biases) rather than the LAW
+ (functoriality). It's kept as a focused regression test.
+ """
+
+ @pytest.mark.parametrize("num_surgeries", [1, 2, 3])
+ def test_qkv_biases_preserved_through_chain(
+ self,
+ num_surgeries,
+ base_config_with_bias_dict,
+ additive_surgery_chain,
+ ):
+ """QKV biases (enabled in source) appear in plan after N surgeries."""
+ surgeries = additive_surgery_chain[:num_surgeries]
+
+ # Build config and plan chain
+ configs = [base_config_with_bias_dict]
+ for s in surgeries:
+ configs.append(compose_configs(configs[-1], s))
+
+ plans = [plan_surgery(configs[i], configs[i + 1]) for i in range(num_surgeries)]
+ final_plan = reduce(compose, plans) if len(plans) > 1 else plans[0]
+
+ # Check bias keys present
+ target_keys = {str(k) for k in final_plan.target_keys()}
+
+ assert any("q_proj.bias" in k for k in target_keys), f"q_proj.bias missing after {num_surgeries} surgeries"
+ assert any("k_proj.bias" in k for k in target_keys), f"k_proj.bias missing after {num_surgeries} surgeries"
+ assert any("v_proj.bias" in k for k in target_keys), f"v_proj.bias missing after {num_surgeries} surgeries"
+ # O bias should NOT be present (disabled in source)
+ assert not any(
+ "o_proj.bias" in k for k in target_keys
+ ), f"o_proj.bias should not be present (disabled in source)"
+
+ def test_bias_values_preserved(
+ self,
+ base_config_with_bias_dict,
+ additive_surgery_chain,
+ ):
+ """Bias tensor values are correctly transferred, not just keys."""
+ surgery = additive_surgery_chain[0] # wrap_stochastic
+ c1 = compose_configs(base_config_with_bias_dict, surgery)
+ plan = plan_surgery(base_config_with_bias_dict, c1)
+
+ weights = make_weights_for_config(base_config_with_bias_dict)
+ result = execute(plan, weights, seed=42)
+
+ # Verify values match (not just that keys exist)
+ for i in range(base_config_with_bias_dict["decoder"]["num_blocks"]):
+ src_key = W(f"model.decoder.blocks.{i}.mixer.q_proj.bias")
+ dst_key = W(f"model.decoder.blocks.{i}.mixer.mixers.attention.q_proj.bias")
+
+ assert dst_key in result, f"Missing {dst_key}"
+ assert torch.allclose(weights[src_key], result[dst_key]), f"Bias values differ for block {i}"
+
+
+# =============================================================================
+# Test: build_plan Integration (Regression test for convert.py)
+# =============================================================================
+
+
+class TestBuildPlanIntegration:
+ """
+ REGRESSION: build_plan must compose configs before calling plan_surgery.
+
+ The bug was:
+ plan_surgery(current_config, surgery_config) # WRONG: partial
+
+ Should be:
+ target = compose_configs(current_config, surgery_config)
+ plan_surgery(current_config, target) # CORRECT: complete
+
+ This test verifies the fix in convert.py's build_plan function.
+ """
+
+ @pytest.mark.parametrize("num_surgeries", [1, 2])
+ def test_build_plan_preserves_inherited_fields(
+ self,
+ num_surgeries,
+ base_config_with_bias_dict,
+ additive_surgery_chain,
+ ):
+ """build_plan produces plans with inherited bias mappings."""
+ from fast_llm_external_models.apriel2.convert import build_plan
+
+ surgeries = additive_surgery_chain[:num_surgeries]
+
+ plan, final_config = build_plan(
+ base_config_with_bias_dict,
+ surgeries,
+ source_format="apriel2",
+ )
+
+ # Verify inherited biases in config
+ if num_surgeries >= 1:
+ attn = final_config["decoder"]["block"]["mixer"]["mixers"]["attention"]
+ assert attn.get("query_layer", {}).get("bias", {}).get("enabled") is True
+
+ # Verify bias mappings in plan
+ target_keys = {str(k) for k in plan.target_keys()}
+ assert any(
+ "q_proj.bias" in k for k in target_keys
+ ), f"build_plan with {num_surgeries} surgeries missing q_proj.bias"
+
+
+# =============================================================================
+# Test: init Field Preservation (Critical for random initialization)
+# =============================================================================
+
+
+class TestInitFieldPreservation:
+ """
+ PROPERTY: The `init` field must be visible to plan_surgery.
+
+ The `init` field controls weight initialization mode:
+ - `init: transfer` → use weight transfer/conversion
+ - `init: random` → use random initialization
+
+ compose_configs must preserve `init` so plan_surgery can see it.
+ Stripping happens only at final output (when saving to disk).
+ """
+
+ def test_init_random_produces_init_expression(self, base_config_with_bias_dict):
+ """Surgery with init: random produces Init expressions in plan."""
+ surgery = {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "type": "stochastic",
+ "main_mixer_name": "attention",
+ "mixers": {
+ "attention": {"init": "transfer"},
+ "gdn": {"type": "gdn", "init": "random", "convolution_layer": {"kernel_size": 4}},
+ },
+ },
+ },
+ },
+ }
+
+ target = compose_configs(base_config_with_bias_dict, surgery)
+ plan = plan_surgery(base_config_with_bias_dict, target)
+
+ # Check that GDN weights use Init expressions (random init)
+ target_keys = {str(k) for k in plan.target_keys()}
+ gdn_keys = [k for k in target_keys if "gdn" in k.lower()]
+
+ assert len(gdn_keys) > 0, "No GDN keys in plan"
+
+ # Verify at least one GDN weight uses Init (random initialization)
+ has_init_expr = False
+ for key in plan.target_keys():
+ if "gdn" in str(key).lower():
+ expr = plan.mappings[key]
+ if isinstance(expr, Init):
+ has_init_expr = True
+ break
+ # Also check inside Concat/other composite expressions
+ if hasattr(expr, "exprs"):
+ for sub in expr.exprs:
+ if isinstance(sub, Init):
+ has_init_expr = True
+ break
+
+ assert has_init_expr, "init: random should produce Init expressions for GDN weights"
+
+ def test_init_transfer_produces_ref_expression(self, base_config_with_bias_dict):
+ """Surgery with init: transfer produces Ref expressions (weight transfer)."""
+ surgery = {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "type": "stochastic",
+ "main_mixer_name": "attention",
+ "mixers": {
+ "attention": {"init": "transfer"},
+ },
+ },
+ },
+ },
+ }
+
+ target = compose_configs(base_config_with_bias_dict, surgery)
+ plan = plan_surgery(base_config_with_bias_dict, target)
+
+ # Check that attention weights use Ref expressions (transfer)
+ has_ref = False
+ for key in plan.target_keys():
+ if "attention" in str(key) and "q_proj.weight" in str(key):
+ expr = plan.mappings[key]
+ if isinstance(expr, Ref):
+ has_ref = True
+ break
+
+ assert has_ref, "init: transfer should produce Ref expressions for attention weights"
+
+ def test_build_plan_respects_init_random(self, base_config_with_bias_dict):
+ """build_plan correctly uses init: random for weight initialization."""
+ from fast_llm_external_models.apriel2.convert import build_plan
+
+ # Mamba requires many config fields for random init
+ surgery = {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "type": "stochastic",
+ "main_mixer_name": "attention",
+ "mixers": {
+ "attention": {"init": "transfer"},
+ "mamba": {
+ "type": "mamba",
+ "init": "random",
+ "d_inner": 512,
+ "d_state": 16,
+ "dt_rank": 16,
+ "d_xb": 64,
+ "d_conv": 4,
+ "repeat_kv_before_conv": False,
+ "conv_bias": True,
+ "dt_proj_bias": True,
+ "dt_min": 0.001,
+ "dt_max": 0.1,
+ "dt_init_floor": 1e-4,
+ },
+ },
+ },
+ },
+ },
+ }
+
+ plan, final_config = build_plan(
+ base_config_with_bias_dict,
+ [surgery],
+ source_format="apriel2",
+ )
+
+ # Verify mamba weights use Init (random init)
+ has_mamba_init = False
+ for key in plan.target_keys():
+ key_str = str(key)
+ if "mamba" in key_str:
+ expr = plan.mappings[key]
+ if isinstance(expr, Init):
+ has_mamba_init = True
+ break
+
+ assert has_mamba_init, "build_plan should use Init for init: random components"
+
+ def test_build_plan_strips_init_between_iterations(self, base_config_with_bias_dict):
+ """build_plan strips init between iterations (T โ S conversion).
+
+ This tests that the intermediate state between surgeries has no init fields.
+ The composed plan will show Init expressions because plan composition
+ substitutes Ref → Init, but the semantics are correct: GDN is initialized
+ once (in surgery 1), not re-randomized in surgery 2.
+ """
+ from fast_llm_external_models.apriel2.conversion import compose_configs, plan_surgery, strip_init_fields
+
+ # Surgery 1: Add GDN with random init
+ surgery1 = {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "type": "stochastic",
+ "main_mixer_name": "attention",
+ "mixers": {
+ "attention": {"init": "transfer"},
+ "gdn": {
+ "type": "gdn",
+ "init": "random",
+ "convolution_layer": {"kernel_size": 4},
+ },
+ },
+ },
+ },
+ },
+ }
+
+ # Surgery 2: Add sliding window (doesn't mention GDN)
+ surgery2 = {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "mixers": {
+ "sliding_window": {"init": "transfer", "window_size": 512},
+ },
+ },
+ },
+ },
+ }
+
+ # Simulate build_plan's iteration loop
+ s0 = base_config_with_bias_dict
+
+ # Iteration 1
+ t1 = compose_configs(s0, surgery1)
+ assert t1["decoder"]["block"]["mixer"]["mixers"]["gdn"].get("init") == "random"
+ s1 = strip_init_fields(t1)
+ assert s1["decoder"]["block"]["mixer"]["mixers"]["gdn"].get("init") is None
+
+ # Iteration 2: s1 has no init for GDN
+ t2 = compose_configs(s1, surgery2)
+ assert (
+ t2["decoder"]["block"]["mixer"]["mixers"]["gdn"].get("init") is None
+ ), "GDN should have no init in T2 (wasn't in surgery2, stripped from s1)"
+
+ # plan_surgery(s1, t2) should use Ref for GDN (transfer, not random)
+ plan2 = plan_surgery(s1, t2)
+ gdn_uses_ref = False
+ for key in plan2.target_keys():
+ if "gdn" in str(key):
+ expr = plan2.mappings[key]
+ if isinstance(expr, Ref):
+ gdn_uses_ref = True
+ break
+
+ assert gdn_uses_ref, "plan_surgery(s1, t2) should use Ref for GDN (transfer from s1)"
diff --git a/setup.py b/setup.py
index b273e077e..5c4d0def6 100644
--- a/setup.py
+++ b/setup.py
@@ -1,6 +1,6 @@
-import sys
-import re
import pathlib
+import re
+import sys
try:
import pybind11
@@ -18,6 +18,7 @@
print(f"Error: setuptools version {_SETUPTOOLS_MIN_VERSION} " "or greater is required")
sys.exit(1)
+
def get_version():
"""Read version from fast_llm/__init__.py"""
init_file = pathlib.Path(__file__).parent.joinpath("fast_llm", "__init__.py").read_text()
@@ -26,6 +27,7 @@ def get_version():
return version_match.group(1)
raise RuntimeError("Unable to find version string in fast_llm/__init__.py")
+
cpp_extension = setuptools.Extension(
"fast_llm.csrc.data",
sources=["fast_llm/csrc/data.cpp"],
diff --git a/tests/data/test_tokenizer.py b/tests/data/test_tokenizer.py
index c7fdef9ca..4e9e2fdd5 100644
--- a/tests/data/test_tokenizer.py
+++ b/tests/data/test_tokenizer.py
@@ -40,3 +40,263 @@ def test_tokenize_with_spans(common_tokenizer, spans, expected_token_spans, expe
expected_token_spans = [(begin + 1, end + 1) for begin, end in expected_token_spans]
Assert.eq(tokens.tolist(), expected_tokens)
Assert.eq(token_spans, expected_token_spans)
+
+
+def test_validate_chat_template_no_template(common_tokenizer):
+ """Tokenizer without chat template raises."""
+ with pytest.raises(ValueError, match="does not have a chat template"):
+ common_tokenizer.validate_chat_template()
+
+
+def test_validate_chat_template_no_markers(common_tokenizer):
+ """Tokenizer with chat template but no markers raises."""
+ common_tokenizer.tokenizer.chat_template = "{{ messages }}"
+ with pytest.raises(ValueError, match="does not contain.*generation"):
+ common_tokenizer.validate_chat_template()
+
+
+def test_validate_chat_template_with_markers(common_tokenizer):
+ """Tokenizer with generation markers validates."""
+ common_tokenizer.tokenizer.chat_template = "{% generation %}{{ m }}{% endgeneration %}"
+ common_tokenizer.validate_chat_template()
+
+
+# Realistic chat template following HF conventions (e.g., SmolLM3):
+# The generation block includes the full assistant turn: opening tag, content, and closing tag.
+# This ensures the model learns to emit the closing tag.
+CHAT_TEMPLATE = (
+ "{% for message in messages %}"
+ "{% if message.role == 'assistant' %}"
+ "{% generation %}{{ message.content }}{% endgeneration %}"
+ "{% else %}"
+ "<{{ message.role }}>{{ message.content }}{{ message.role }}>"
+ "{% endif %}"
+ "{% endfor %}"
+)
+
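+# Presumed mechanism behind tokenize_chat (a sketch under assumptions, not
+# necessarily the actual implementation): the {% generation %} markers drive an
+# assistant-token mask via HF's apply_chat_template, which is then folded into
+# loss-masking spans over the non-trainable positions, roughly:
+#
+#   out = tokenizer.apply_chat_template(
+#       messages, chat_template=CHAT_TEMPLATE, tokenize=True,
+#       return_dict=True, return_assistant_tokens_mask=True,
+#   )
+#   train_mask = [bool(m) for m in out["assistant_masks"]]
+#   loss_masking_spans = _train_mask_to_loss_spans(train_mask)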
+
+@pytest.mark.parametrize(
+ ("messages", "expected_tokens", "expected_loss_masking_spans"),
+ (
+ # Single turn: full assistant turn (Hello) is trainable
+ # 15 tokens, trainable indices 7-13, loss mask spans cover 0-6 and 14
+ (
+ [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello"}],
+ [49152, 27, 789, 29, 16946, 750, 789, 2293, 17822, 29, 7371, 750, 17822, 29, 49152],
+ [(0, 7), (14, 15)],
+ ),
+ # Multi-turn: both assistant turns are fully trainable
+ # 27 tokens, trainable indices 7-13 and 19-25
+ (
+ [
+ {"role": "user", "content": "A"},
+ {"role": "assistant", "content": "B"},
+ {"role": "user", "content": "C"},
+ {"role": "assistant", "content": "D"},
+ ],
+ [
+ 49152,
+ 27,
+ 789,
+ 29,
+ 32,
+ 750,
+ 789,
+ 2293,
+ 17822,
+ 29,
+ 33,
+ 750,
+ 17822,
+ 2293,
+ 789,
+ 29,
+ 34,
+ 750,
+ 789,
+ 2293,
+ 17822,
+ 29,
+ 35,
+ 750,
+ 17822,
+ 29,
+ 49152,
+ ],
+ [(0, 7), (14, 19), (26, 27)],
+ ),
+ # System + user + assistant: full assistant turn trainable
+ # 23 tokens, trainable indices 15-21
+ (
+ [
+ {"role": "system", "content": "You are helpful."},
+ {"role": "user", "content": "Hi"},
+ {"role": "assistant", "content": "Hello"},
+ ],
+ [
+ 49152,
+ 27,
+ 3144,
+ 29,
+ 5815,
+ 1139,
+ 44569,
+ 6928,
+ 3144,
+ 2293,
+ 789,
+ 29,
+ 16946,
+ 750,
+ 789,
+ 2293,
+ 17822,
+ 29,
+ 7371,
+ 750,
+ 17822,
+ 29,
+ 49152,
+ ],
+ [(0, 15), (22, 23)],
+ ),
+ # User only: no trainable tokens
+ # 9 tokens, no trainable indices
+ (
+ [{"role": "user", "content": "Hi"}],
+ [49152, 27, 789, 29, 16946, 750, 789, 29, 49152],
+ [(0, 9)],
+ ),
+ # Long multi-turn (85 tokens, 3 assistant responses with tags, tests span machinery)
+ # Trainable: indices 27-40, 49-62, 70-83
+ (
+ [
+ {"role": "system", "content": "You are a helpful assistant that answers questions."},
+ {"role": "user", "content": "What is the capital of France?"},
+ {"role": "assistant", "content": "The capital of France is Paris."},
+ {"role": "user", "content": "What about Germany?"},
+ {"role": "assistant", "content": "The capital of Germany is Berlin."},
+ {"role": "user", "content": "And Italy?"},
+ {"role": "assistant", "content": "The capital of Italy is Rome."},
+ ],
+ [
+ 49152,
+ 27,
+ 3144,
+ 29,
+ 5815,
+ 1139,
+ 373,
+ 44569,
+ 2424,
+ 11886,
+ 954,
+ 15737,
+ 14516,
+ 6928,
+ 3144,
+ 2293,
+ 789,
+ 29,
+ 13938,
+ 438,
+ 331,
+ 25016,
+ 457,
+ 12409,
+ 562,
+ 35838,
+ 789,
+ 2293,
+ 17822,
+ 29,
+ 2111,
+ 25016,
+ 457,
+ 12409,
+ 562,
+ 438,
+ 4235,
+ 280,
+ 6928,
+ 17822,
+ 2293,
+ 789,
+ 29,
+ 13938,
+ 5028,
+ 759,
+ 42226,
+ 35838,
+ 789,
+ 2293,
+ 17822,
+ 29,
+ 2111,
+ 25016,
+ 457,
+ 759,
+ 42226,
+ 438,
+ 29784,
+ 3556,
+ 6928,
+ 17822,
+ 2293,
+ 789,
+ 29,
+ 1996,
+ 4413,
+ 3326,
+ 35838,
+ 789,
+ 2293,
+ 17822,
+ 29,
+ 2111,
+ 25016,
+ 457,
+ 4413,
+ 3326,
+ 438,
+ 613,
+ 1361,
+ 6928,
+ 17822,
+ 29,
+ 49152,
+ ],
+ [(0, 27), (41, 49), (63, 70), (84, 85)],
+ ),
+ ),
+)
+def test_tokenize_chat(common_tokenizer, messages, expected_tokens, expected_loss_masking_spans):
+ common_tokenizer.tokenizer.chat_template = CHAT_TEMPLATE
+ tokens, loss_masking_spans = common_tokenizer.tokenize_chat(messages)
+ Assert.eq(tokens.tolist(), expected_tokens)
+ Assert.eq(loss_masking_spans, expected_loss_masking_spans)
+
+
+@pytest.mark.parametrize(
+ ("train_mask", "expected_loss_spans"),
+ (
+ # All masked (no trainable tokens)
+ ([False, False, False], [(0, 3)]),
+ # All trainable (no spans)
+ ([True, True, True], []),
+ # Single trainable at start
+ ([True, False, False], [(1, 3)]),
+ # Single trainable at end
+ ([False, False, True], [(0, 2)]),
+ # Single trainable in middle
+ ([False, True, False], [(0, 1), (2, 3)]),
+ # Multiple trainable regions (simulates multi-turn conversation)
+ ([False, False, True, True, False, False, True, True, True, False], [(0, 2), (4, 6), (9, 10)]),
+ # Alternating
+ ([False, True, False, True, False], [(0, 1), (2, 3), (4, 5)]),
+ ),
+)
+def test_train_mask_to_loss_spans(train_mask, expected_loss_spans):
+ from fast_llm.data.preprocessing.tokenizer import _train_mask_to_loss_spans
+
+ Assert.eq(_train_mask_to_loss_spans(train_mask), expected_loss_spans)
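+
+
+# Reference sketch of the span computation exercised above (an assumed
+# re-implementation for illustration, not the library code): loss spans are the
+# maximal runs of non-trainable positions, as half-open (begin, end) ranges.
+#
+#   def loss_spans(train_mask):
+#       spans, start = [], None
+#       for i, trainable in enumerate(train_mask):
+#           if not trainable and start is None:
+#               start = i
+#           elif trainable and start is not None:
+#               spans.append((start, i))
+#               start = None
+#       if start is not None:
+#           spans.append((start, len(train_mask)))
+#       return spans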