diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md index 50c5a2c1c..a09f78c6c 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.md +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -8,26 +8,26 @@ assignees: '' --- # ๐ŸŽฏ **Goal (What & Why)** -> **Clearly state the purpose of this feature.** +> **Clearly state the purpose of this feature.** > _(Example: Add FP8 support using torchao to improve training throughput by 1.5x.)_ # ๐Ÿš€ **Execution Plan** -> _(This section may start as an incomplete draft but must be defined before implementation begins.)_ +> _(This section may start as an incomplete draft but must be defined before implementation begins.)_ ### **Step 1: What is the smallest working version?** -> _(Describe the simplest way to implement this feature with minimal effort.)_ +> _(Describe the simplest way to implement this feature with minimal effort.)_ -### **Step 2: What additional optimizations are possible (but optional)?** -> _(List potential refinements that can be added in later PRs if needed.)_ +### **Step 2: What additional optimizations are possible (but optional)?** +> _(List potential refinements that can be added in later PRs if needed.)_ # ๐Ÿ“Œ **Acceptance Criteria** (Must-Haves for Completion) -* The feature must be **functional and tested**. -* The implementation must be **documented in practical terms**. -* The PR must include a **performance/impact summary**. -* **No refactors unless directly necessary** for feature completion. +* The feature must be **functional and tested**. +* The implementation must be **documented in practical terms**. +* The PR must include a **performance/impact summary**. +* **No refactors unless directly necessary** for feature completion. # ๐Ÿ› ๏ธ **Project Management** - [ ] **Assign the project to the Fast-LLM project.** - [ ] **Set the `Estimate` field (in days) in the GitHub project.** - [ ] **Use the `Size` field to categorize the PR size (Small/Medium/Large).** -- [ ] **Assign an owner when opening the issue.** +- [ ] **Assign an owner when opening the issue.** diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 41a2fe7ff..0ed4696da 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -65,7 +65,7 @@ def build(self, preprocessing: PreprocessingConfig) -> SamplableDataset[SampleTy def _load_config(self) -> SampledDatasetConfig[SampleType]: assert self.path.is_file(), f"File {self.path} does not exist." config = yaml.safe_load(self.path.open("r")) - Assert.eq(config.keys(), {"config", "metadata"}) + # TODO: Assert.eq(config.keys(), {"config", "metadata"}) # Disabled for backward compat if config.keys() == {"config", "metadata"}: # Newer format with metadata config = config["config"] diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 503b400c3..a1aadf40a 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -15,30 +15,81 @@ from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator -@config_class() +@config_class(registry=True) class LanguageModelSourceConfig(Config): """ - A schema holding the name of each relevant column in the dataset. - Setting optional entries will enable the associated feature. + Abstract base class for data source schemas. + + Use `type: document` (default) for documents with text, optional span annotations, and optional images. 
+ Use `type: conversation` for structured chat/conversation datasets. + """ + + @classmethod + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: + if cls is LanguageModelSourceConfig and cls.get_subclass(default.get("type")) is None: + # Default to DocumentSourceConfig when type is not specified + return DocumentSourceConfig._from_dict(default, strict) + return super()._from_dict(default, strict=strict) + + @functools.cached_property + def columns(self) -> list[str]: + """Columns to read from the dataset.""" + raise NotImplementedError + + @functools.cached_property + def has_loss_masking_span(self) -> bool: + return False + + @functools.cached_property + def has_preference_spans(self) -> bool: + return False + + @functools.cached_property + def has_images(self) -> bool: + return False + + +@config_class(dynamic_type={LanguageModelSourceConfig: "document"}) +class DocumentSourceConfig(LanguageModelSourceConfig): + """ + Source schema for document datasets with text, optional span annotations, and optional images. + + The dataset should have a text column containing the document text. + Optionally, it can have additional columns for: + - Loss masking spans: character ranges to mask from loss computation + - Preference spans: chosen/rejected text for DPO training + - Images: image data with character positions for multimodal training """ text: str = Field( default="text", - desc="Field of the dataset to use.", + desc="Field containing the document text.", + hint=FieldHint.optional, + ) + loss_masking_spans: str | None = Field( + default=None, + desc="Field containing character spans to mask for loss computation.", hint=FieldHint.optional, ) - loss_masking_spans: None | str = Field( - default=None, desc="Field containing character spans to mask for loss computation", hint=FieldHint.optional + chosen_span: str | None = Field( + default=None, + desc="Field containing chosen text for preference optimization.", + hint=FieldHint.optional, ) - chosen_span: None | str = Field( - default=None, desc="Field containing chosen text for preference optimization", hint=FieldHint.optional + rejected_span: str | None = Field( + default=None, + desc="Field containing rejected text for preference optimization.", + hint=FieldHint.optional, ) - rejected_span: None | str = Field( - default=None, desc="Field containing rejected text for preference optimization", hint=FieldHint.optional + images: str | None = Field( + default=None, + desc="Field containing images.", + hint=FieldHint.optional, ) - images: None | str = Field(default=None, desc="Field containing images", hint=FieldHint.optional) - image_positions: None | str = Field( - default=None, desc="Field containing image positions in the text.", hint=FieldHint.optional + image_positions: str | None = Field( + default=None, + desc="Field containing image positions in the text.", + hint=FieldHint.optional, ) @functools.cached_property @@ -48,6 +99,8 @@ def columns(self) -> list[str]: columns.append(self.loss_masking_spans) if self.has_preference_spans: columns.extend([self.chosen_span, self.rejected_span]) + if self.has_images: + columns.extend([self.images, self.image_positions]) return columns @functools.cached_property @@ -67,7 +120,50 @@ def has_images(self) -> bool: def _validate(self): super()._validate() if self.has_preference_spans and self.has_loss_masking_span: - raise ValueError(f"Can not enable both loss masking and preference spans.") + raise ValueError("Cannot enable both loss masking and preference spans.") + + 
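+# Illustrative sketch of the dynamic-type dispatch defined above. This is a hedged example, not
+# documented API: it calls the private `_from_dict` helper shown in this module, and the public
+# entry point used by the preparator config is outside this diff.
+#
+#   LanguageModelSourceConfig._from_dict({"type": "conversation"})  # -> ConversationSourceConfig
+#   LanguageModelSourceConfig._from_dict({})                        # -> DocumentSourceConfig (default)
+
+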
+@config_class(dynamic_type={LanguageModelSourceConfig: "conversation"}) +class ConversationSourceConfig(LanguageModelSourceConfig): + """ + Source schema for chat/conversation datasets (e.g., Tulu 3, ShareGPT, OpenAI format). + + The dataset should have a messages column containing a list of message dicts, + where each message has 'role' and 'content' keys. Common roles include: + - 'system': System prompt + - 'user': User input + - 'assistant': Model response (trained on by default) + - 'tool': Tool/function results + - 'ipython': Code execution results + + The conversation is formatted using the tokenizer's chat template, which must + contain {% generation %}...{% endgeneration %} markers to define which content + to train on. Loss masking spans are automatically computed from these markers. + + Example dataset format: + { + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello!"}, + {"role": "assistant", "content": "Hi there!"}, + ] + } + """ + + messages: str = Field( + default="messages", + desc="Field containing the conversation messages list. Each message should have 'role' and 'content' keys.", + hint=FieldHint.core, + ) + + @functools.cached_property + def columns(self) -> list[str]: + return [self.messages] + + @functools.cached_property + def has_loss_masking_span(self) -> bool: + # Conversation format always generates loss masking spans from chat template markers + return True @config_class() diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index e0f5f02fc..325d33c43 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -28,7 +28,12 @@ ) from fast_llm.data.dataset.memmap import MemmapDataset from fast_llm.data.preparator.config import DatasetPreparator -from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig, LanguageModelSourceConfig +from fast_llm.data.preparator.gpt_memmap.config import ( + ConversationSourceConfig, + DocumentSourceConfig, + GPTMemmapDatasetPreparatorConfig, + LanguageModelSourceConfig, +) from fast_llm.data.preprocessing.abstract import NullPreprocessingConfig from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.preprocessing.tokenizer import Tokenizer @@ -132,6 +137,10 @@ def run(self) -> None: # Load tokenizer self._tokenizer = self._config.tokenizer.get_tokenizer() + # Validate chat template for conversation format + if isinstance(self._source_schema, ConversationSourceConfig): + self._tokenizer.validate_chat_template() + # Decide the datatype based on the tokenizer vocabulary size self._data_type = ( get_unsigned_integer_type(self._tokenizer.vocab_size) @@ -216,92 +225,112 @@ def _preprocessing_config(self) -> LanguageModelPreprocessingConfig: ) def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample: - text = sample[self._source_schema.text] - all_spans = [] - if self._source_schema.has_loss_masking_span: - # Spans are typically stored in the (begin, last) format. We convert to (begin, end) range format. 
- loss_masking_spans = _sort_spans( - (SpanType.loss_masking, (begin, last + 1)) - for begin, last in np.array(sample[self._source_schema.loss_masking_spans], dtype=np.int32) - .reshape(-1, 2) - .tolist() + token_spans_by_type = collections.defaultdict(list) + image_patches = image_token_maps = image_position_ids = patch_counts = None + + if isinstance(self._source_schema, ConversationSourceConfig): + # Conversation format: tokenize messages and get loss masking spans from chat template + tokens, loss_masking_spans = self._tokenizer.tokenize_chat( + sample[self._source_schema.messages], + True, + True, + data_type=self._data_type, ) - all_spans.extend(loss_masking_spans) - - if self._source_schema.has_preference_spans: - full_chosen_text = text + sample[self._source_schema.chosen_span] + self._tokenizer.tokenizer.eos_token - full_rejected_text = self._tokenizer.tokenizer.bos_token + text + sample[self._source_schema.rejected_span] - # compute chosen span - chosen_spans = [(SpanType.chosen, (len(text), len(full_chosen_text)))] - - # compute rejected span - rejected_span = [ - ( - SpanType.rejected, - ( - len(full_chosen_text) + len(self._tokenizer.tokenizer.bos_token) + len(text), - len(full_chosen_text) + len(full_rejected_text), - ), + token_spans_by_type[SpanType.loss_masking] = loss_masking_spans + elif isinstance(self._source_schema, DocumentSourceConfig): + # Document format: use the text-spans pipeline + text = sample[self._source_schema.text] + all_spans = [] + + if self._source_schema.has_loss_masking_span: + # Spans are typically stored in the (begin, last) format. We convert to (begin, end) range format. + loss_masking_spans = _sort_spans( + (SpanType.loss_masking, (begin, last + 1)) + for begin, last in np.array(sample[self._source_schema.loss_masking_spans], dtype=np.int32) + .reshape(-1, 2) + .tolist() + ) + all_spans.extend(loss_masking_spans) + + if self._source_schema.has_preference_spans: + full_chosen_text = text + sample[self._source_schema.chosen_span] + self._tokenizer.tokenizer.eos_token + full_rejected_text = ( + self._tokenizer.tokenizer.bos_token + text + sample[self._source_schema.rejected_span] ) - ] - # pack texts - text = full_chosen_text + full_rejected_text - all_spans.extend(chosen_spans + rejected_span) - - if self._source_schema.has_images: - # Get the images and positions, sorted by position. - images, image_positions = ( - zip( - *sorted( - zip( - sample[self._source_schema.images], - sample[self._source_schema.image_positions], - strict=True, + # compute chosen span + chosen_spans = [(SpanType.chosen, (len(text), len(full_chosen_text)))] + + # compute rejected span + rejected_span = [ + ( + SpanType.rejected, + ( + len(full_chosen_text) + len(self._tokenizer.tokenizer.bos_token) + len(text), + len(full_chosen_text) + len(full_rejected_text), ), - key=lambda x: x[1], ) + ] + # pack texts + text = full_chosen_text + full_rejected_text + all_spans.extend(chosen_spans + rejected_span) + + if self._source_schema.has_images: + # Get the images and positions, sorted by position. + images, image_positions = ( + zip( + *sorted( + zip( + sample[self._source_schema.images], + sample[self._source_schema.image_positions], + strict=True, + ), + key=lambda x: x[1], + ) + ) + if len(sample[self._source_schema.images]) > 0 + else ([], []) ) - if len(sample[self._source_schema.images]) > 0 - else ([], []) - ) - # Get the image patches and associated data. 
- image_patches, image_position_ids, image_token_maps, image_token_ids, patch_counts = ( - self._config.image_patches.get_patches_from_images(images, self._data_type) + # Get the image patches and associated data. + image_patches, image_position_ids, image_token_maps, image_token_ids, patch_counts = ( + self._config.image_patches.get_patches_from_images(images, self._data_type) + ) + patch_count_cumsum = padded_cumsum(patch_counts).tolist() + # Add an empty "span" at each image position so we know where to insert them in the tokenized sequence. + all_spans.extend([(SpanType.image, (position, position)) for position in image_positions]) + + # Sort the spans by location (begin), keeping track of their type. + # Note: overlapping spans are not supported (explicit assertion in the tokenizer). + span_types, spans = zip(*_sort_spans(all_spans)) if all_spans else ([], []) + # Tokenize the text, and determine the span locations in the tokenized text. + tokens, token_spans = self._tokenizer.tokenize_with_spans( + text, True, True, text_spans=spans, data_type=self._data_type ) - patch_count_cumsum = padded_cumsum(patch_counts).tolist() - # Add an empty "span" at each image position so we know where to insert them in the tokenized sequence. - all_spans.extend([(SpanType.image, (position, position)) for position in image_positions]) - - # Sort the spans by location (begin), keeping track of their type. - # Note: overlapping spans are not supported (explicit assertion in the tokenizer). - span_types, spans = zip(*_sort_spans(all_spans)) if all_spans else ([], []) - # Tokenize the text, and determine the span locations in the tokenized text. - tokens, token_spans = self._tokenizer.tokenize_with_spans( - text, True, True, text_spans=spans, data_type=self._data_type - ) - # Gather token spans by type. - token_spans_by_type = collections.defaultdict(list) - if self._source_schema.has_images: - # Insert the image token ids in the token sequence and shift the spans accordingly. - tokens_shift = 0 - image_index = 0 - for span_type, (begin, end) in zip(span_types, token_spans, strict=True): - # Account for the tokens already inserted. - begin = begin + tokens_shift - end = end + tokens_shift - if span_type == SpanType.image: - # Shift the token map to the image location. - image_token_maps[patch_count_cumsum[image_index] : patch_count_cumsum[image_index + 1]] += begin - # Insert the placeholder and image break tokens. - tokens = torch.cat([tokens[:begin], image_token_ids[image_index], tokens[begin:]]) - tokens_shift += len(image_token_ids[image_index]) - image_index += 1 - else: - token_spans_by_type[span_type].append((begin, end)) + # Gather token spans by type. + if self._source_schema.has_images: + # Insert the image token ids in the token sequence and shift the spans accordingly. + tokens_shift = 0 + image_index = 0 + for span_type, (begin, end) in zip(span_types, token_spans, strict=True): + # Account for the tokens already inserted. + begin = begin + tokens_shift + end = end + tokens_shift + if span_type == SpanType.image: + # Shift the token map to the image location. + image_token_maps[ + patch_count_cumsum[image_index] : patch_count_cumsum[image_index + 1] + ] += begin + # Insert the placeholder and image break tokens. 
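+                        # Worked example with illustrative numbers: if begin == 5 and this image
+                        # expands to 4 placeholder/break tokens, the result below is
+                        # tokens[:5] + image_token_ids[image_index] + tokens[5:], and tokens_shift
+                        # grows by 4 so later spans keep pointing at the same tokens.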
+ tokens = torch.cat([tokens[:begin], image_token_ids[image_index], tokens[begin:]]) + tokens_shift += len(image_token_ids[image_index]) + image_index += 1 + else: + token_spans_by_type[span_type].append((begin, end)) + else: + for span_type, token_span in zip(span_types, token_spans, strict=True): + token_spans_by_type[span_type].append(token_span) else: - for span_type, token_span in zip(span_types, token_spans, strict=True): - token_spans_by_type[span_type].append(token_span) + raise NotImplementedError(f"Unsupported source schema type: {type(self._source_schema)}") sample_size = len(tokens) diff --git a/fast_llm/data/preprocessing/tokenizer.py b/fast_llm/data/preprocessing/tokenizer.py index abfb5b3d2..157744f51 100644 --- a/fast_llm/data/preprocessing/tokenizer.py +++ b/fast_llm/data/preprocessing/tokenizer.py @@ -213,3 +213,105 @@ def _remove_delimiters( @property def eod(self): return self.eod_id + + @staticmethod + def _has_generation_markers(template: str | None) -> bool: + """Check if a template has generation markers.""" + return template is not None and "{% generation %}" in template + + def validate_chat_template(self) -> None: + """ + Validate the tokenizer's chat template has generation markers. + + Raises: + ValueError: If the tokenizer lacks a chat template or generation markers. + """ + template = self.tokenizer.chat_template + + if template is None: + raise ValueError( + "Tokenizer does not have a chat template. " + "Conversation format requires a tokenizer with a built-in chat template " + "containing {% generation %}...{% endgeneration %} markers." + ) + + if not self._has_generation_markers(template): + raise ValueError( + "Tokenizer's chat template does not contain {% generation %}...{% endgeneration %} markers. " + "These markers are required to determine which tokens to train on. " + "Please use a tokenizer with generation markers in its chat template." + ) + + def tokenize_chat( + self, + messages: list[dict[str, str]], + begin: bool = True, + end: bool = True, + data_type: DataType = DataType.int64, + ) -> tuple["torch.Tensor", list[tuple[int, int]]]: + """ + Apply chat template and return (tokens, loss_masking_spans). + + The loss_masking_spans mark token ranges to EXCLUDE from training (where the model + should not learn). These are derived from the chat template's generation markers - + tokens outside {% generation %}...{% endgeneration %} blocks are masked. + """ + import torch + + result = self.tokenizer.apply_chat_template( + messages, + tokenize=True, + return_assistant_tokens_mask=True, + return_dict=True, + add_generation_prompt=False, + ) + tokens = result["input_ids"] + train_mask = result["assistant_masks"] + + # Prepend BOS / append EOS if not already present anywhere in the sequence. + # We check anywhere (not just first/last) because some chat templates add trailing + # whitespace after the final EOS token, e.g. "<|im_end|>\n". 
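+        # The boolean flags below act as 0/1 list multipliers: for example, with
+        # prepend_bos=True and append_eos=False they expand to [self.bod_id] + list(tokens) + [].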
+ prepend_bos = begin and self.bod_id not in tokens + append_eos = end and self.eod_id not in tokens + tokens = [self.bod_id] * prepend_bos + list(tokens) + [self.eod_id] * append_eos + train_mask = [False] * prepend_bos + [bool(m) for m in train_mask] + [False] * append_eos + + # Convert boolean train mask to loss masking spans (spans where train_mask[i] == False) + loss_masking_spans = _train_mask_to_loss_spans(train_mask) + + if self._config.max_vocab_size is not None: + tokens = ( + torch.tensor( + tokens, + dtype=torch.int64 if len(self.tokenizer) > torch.iinfo(data_type.torch).max else data_type.torch, + ) + % self._config.max_vocab_size + ).to(data_type.torch) + else: + tokens = torch.tensor(tokens, dtype=data_type.torch) + return tokens, loss_masking_spans + + +def _train_mask_to_loss_spans(train_mask: list[bool]) -> list[tuple[int, int]]: + """ + Convert a boolean train mask to loss masking spans. + + Args: + train_mask: Boolean list where True = train on this token, False = don't train + + Returns: + List of (begin, end) spans marking token ranges to EXCLUDE from training + (i.e., where train_mask[i] == False). + """ + spans = [] + start = None + for i, should_train in enumerate(train_mask): + if not should_train: + if start is None: + start = i + elif start is not None: + spans.append((start, i)) + start = None + if start is not None: + spans.append((start, len(train_mask))) + return spans diff --git a/fast_llm/engine/evaluation/config.py b/fast_llm/engine/evaluation/config.py index df7ab0f51..90881cdc1 100644 --- a/fast_llm/engine/evaluation/config.py +++ b/fast_llm/engine/evaluation/config.py @@ -8,6 +8,7 @@ if typing.TYPE_CHECKING: from fast_llm.engine.evaluation.evaluator import Evaluator, EvaluatorLmEval, LossEvaluator + from fast_llm.engine.evaluation.forward_kl.evaluator import ForwardKLEvaluator @config_class() @@ -119,3 +120,58 @@ def get_evaluator( from fast_llm.engine.evaluation.lm_eval.evaluator import LmEvalEvaluator return LmEvalEvaluator(name, self, batch_config, data_load_num_proc, train_iters) + + +@config_class(dynamic_type={EvaluatorConfig: "forward_kl"}) +class ForwardKLEvaluatorConfig(EvaluatorConfig): + _abstract: typing.ClassVar[bool] = False + + dataset_path: str | None = Field( + default=None, + desc="HuggingFace dataset path containing teacher traces.", + hint=FieldHint.core, + ) + split: str = Field( + default="validation", + desc="Dataset split to evaluate on. Use 'train+validation' syntax to combine multiple splits.", + hint=FieldHint.optional, + ) + seed: int = Field( + default=42, + desc="Random seed for shuffling traces. Ensures reproducible evaluation across runs.", + hint=FieldHint.optional, + ) + num_samples: int | None = Field( + default=None, + desc="Maximum number of traces to evaluate (after shuffling). None for all.", + hint=FieldHint.optional, + valid=skip_valid_if_none(check_field(Assert.gt, 0)), + ) + batch_size: int = Field( + default=8, + desc="Batch size for forward passes.", + hint=FieldHint.performance, + valid=check_field(Assert.gt, 0), + ) + trust_remote_code: bool = Field( + default=False, + desc="Trust remote code when loading dataset.", + hint=FieldHint.optional, + ) + inference_mixer: str | None = Field( + default=None, + desc="Name of the mixer to use during evaluation (for StochasticMixer models). 
" + "If None, uses the model's default main_mixer_name.", + hint=FieldHint.optional, + ) + + def get_evaluator( + self, + name: str, + batch_config: BatchConfig, + data_load_num_proc: int, + train_iters: int | None = None, + ) -> "ForwardKLEvaluator": + from fast_llm.engine.evaluation.forward_kl.evaluator import ForwardKLEvaluator + + return ForwardKLEvaluator(name, self, batch_config, data_load_num_proc, train_iters) diff --git a/fast_llm/engine/evaluation/forward_kl/__init__.py b/fast_llm/engine/evaluation/forward_kl/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/engine/evaluation/forward_kl/evaluator.py b/fast_llm/engine/evaluation/forward_kl/evaluator.py new file mode 100644 index 000000000..5e69862d2 --- /dev/null +++ b/fast_llm/engine/evaluation/forward_kl/evaluator.py @@ -0,0 +1,451 @@ +import dataclasses +import gc +import hashlib +import logging +import math + +import torch +import torch.nn.functional as F +import tqdm + +from fast_llm.config import NoAutoValidate +from fast_llm.core.distributed import ReduceOp, allreduce_scalar, safe_barrier +from fast_llm.data.data.abstract import Data +from fast_llm.data.sample.language_model import LanguageModelBatch, LanguageModelSample +from fast_llm.data.sample.token import TokenSample +from fast_llm.engine.config_utils.run import Run +from fast_llm.engine.distributed.config import PhaseType +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.engine.evaluation.config import ForwardKLEvaluatorConfig +from fast_llm.engine.evaluation.evaluator import ( + EvaluationMetrics, + Evaluator, + EvaluatorSamplingParameters, + TrainingProgress, +) +from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel +from fast_llm.engine.schedule.runner import ScheduleRunner +from fast_llm.layers.attention.config import AttentionKwargs +from fast_llm.models.gpt.config import GPTBatchConfig +from fast_llm.models.gpt.model import GPTInferenceRunner + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class TraceTensors: + tokens: torch.Tensor # (num_traces, sequence_length) + prompt_lens: torch.Tensor # (num_traces,) + completion_lens: torch.Tensor # (num_traces,) + problem_indices: torch.Tensor # (num_traces,) + teacher_log_probs: torch.Tensor # (num_traces,) + corrects: torch.Tensor # (num_traces,) + num_problems: int + num_skipped: int + + def __len__(self) -> int: + return self.tokens.shape[0] + + @classmethod + def empty(cls, sequence_length: int, device: torch.device, num_skipped: int = 0) -> "TraceTensors": + return cls( + tokens=torch.empty((0, sequence_length), dtype=torch.int64, device=device), + prompt_lens=torch.empty(0, dtype=torch.int64, device=device), + completion_lens=torch.empty(0, dtype=torch.int64, device=device), + problem_indices=torch.empty(0, dtype=torch.int64, device=device), + teacher_log_probs=torch.empty(0, dtype=torch.float64, device=device), + corrects=torch.empty(0, dtype=torch.bool, device=device), + num_problems=0, + num_skipped=num_skipped, + ) + + @classmethod + def from_traces( + cls, + traces: list[dict], + sequence_length: int, + device: torch.device, + ) -> "TraceTensors": + pid_to_idx: dict[str, int] = {} + valid_traces: list[tuple[list[int], list[int], str, float, bool]] = [] + num_skipped = 0 + + for t in traces: + prompt, completion = t["prompt_tokens"], t["completion_tokens"] + if len(prompt) + len(completion) > sequence_length: + num_skipped += 1 + continue + valid_traces.append((prompt, completion, t["problem_id"], 
t["teacher_log_prob"], t["correct"])) + + if not valid_traces: + return cls.empty(sequence_length, device, num_skipped) + + n = len(valid_traces) + tokens = torch.zeros((n, sequence_length), dtype=torch.int64, device=device) + prompt_lens = torch.empty(n, dtype=torch.int64, device=device) + completion_lens = torch.empty(n, dtype=torch.int64, device=device) + problem_indices = torch.empty(n, dtype=torch.int64, device=device) + teacher_log_probs = torch.empty(n, dtype=torch.float64, device=device) + corrects = torch.empty(n, dtype=torch.bool, device=device) + + for i, (prompt, completion, pid, teacher_lp, correct) in enumerate(valid_traces): + seq = prompt + completion + tokens[i, : len(seq)] = torch.tensor(seq, dtype=torch.int64, device=device) + prompt_lens[i] = len(prompt) + completion_lens[i] = len(completion) + + if pid not in pid_to_idx: + pid_to_idx[pid] = len(pid_to_idx) + problem_indices[i] = pid_to_idx[pid] + teacher_log_probs[i] = teacher_lp + corrects[i] = correct + + return cls( + tokens=tokens, + prompt_lens=prompt_lens, + completion_lens=completion_lens, + problem_indices=problem_indices, + teacher_log_probs=teacher_log_probs, + corrects=corrects, + num_problems=len(pid_to_idx), + num_skipped=num_skipped, + ) + + +class ForwardKLEvaluator[ConfigType: ForwardKLEvaluatorConfig](Evaluator[ConfigType]): + """Shard by PROBLEM (not trace) so each rank gets complete problems. + + This allows computing per-problem IS metrics locally, then reducing scalars. + """ + + _inference_runner: GPTInferenceRunner + _sequence_length: int + _micro_sequence_length: int + + def setup( + self, + distributed: Distributed, + run: Run, + multi_stage: FastLLMModel, + runner: ScheduleRunner, + data: Data, + phase: PhaseType, + ) -> None: + super().setup(distributed, run, multi_stage, runner, data, phase) + self._inference_runner = GPTInferenceRunner(self._multi_stage, runner=self._runner) + self._inference_runner.setup() + self._sequence_length = self._batch_config.sequence_length + self._micro_sequence_length = self._batch_config.micro_sequence_length + self._is_setup = True + + def get_sampling_parameters(self) -> EvaluatorSamplingParameters | None: + return None + + def run( + self, + training_progress: TrainingProgress | None = None, + run_index: int | None = None, + ) -> EvaluationMetrics: + assert self._is_setup + if self._config.dataset_path is None: + return EvaluationMetrics() + + safe_barrier(self._distributed.world_group, f"forward_kl_{self._name} begin") + metrics = self._evaluate() + safe_barrier(self._distributed.world_group, f"forward_kl_{self._name} end") + + if metrics["num_traces"] == 0: + return EvaluationMetrics() + + formatted = ( + f"IS Eval ({self._name}): " + f"acc={metrics['is_accuracy']:.4f}, " + f"ESS={metrics['mean_ess']:.2f}/{metrics['samples_per_problem']:.1f}, " + f"({metrics['num_problems']} problems, {metrics['num_traces']} traces)" + ) + if metrics["num_skipped"] > 0: + formatted += f" [{metrics['num_skipped']} skipped]" + + return EvaluationMetrics( + {f"validation.{self._name}": {k: v for k, v in metrics.items() if k != "num_skipped"}}, + formatted, + ) + + @torch.inference_mode() + def _evaluate(self) -> dict[str, float]: + device = self._distributed.device + data = self._load_traces(device) + + # Switch to eval mode so StochasticMixer uses the main mixer + # instead of randomly sampling. 
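+        # (`StochasticMixer._sample_mixer_name` returns the main/override mixer whenever
+        # `self.training` is False; the original training mode is restored in the `finally`
+        # block below.)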
+ was_training = self._multi_stage._training + self._multi_stage.train(False) + + # Optionally override the inference mixer for StochasticMixer layers + stochastic_mixers: list = [] + if self._config.inference_mixer is not None: + from fast_llm.layers.decoder.stochastic_mixer import StochasticMixer + + for name, module in self._multi_stage.base_model.named_modules(): + if isinstance(module, StochasticMixer): + stochastic_mixers.append(module) + module._inference_mixer_override = self._config.inference_mixer + logger.info(f"ForwardKL: Set {name} inference mixer to '{self._config.inference_mixer}'") + + try: + batch_size = self._config.batch_size + student_log_probs_batches: list[torch.Tensor] = [] + local_num_batches = math.ceil(len(data) / batch_size) if len(data) > 0 else 0 + + # Synchronize batch count across all world ranks. + # All ranks must execute the same number of forward passes because the forward + # pass involves collective operations (e.g., ZeRO all-gather) that require + # participation from all ranks in the process group. + max_num_batches = int( + allreduce_scalar(local_num_batches, torch.int64, self._distributed.world_group, ReduceOp.MAX) + ) + + if max_num_batches == 0: + return self._reduce_metrics(0.0, 0.0, 0, 0, data.num_skipped) + + # Create dummy data for ranks that have no data or finish early. + # These ranks still need to participate in collective operations. + dummy_tokens = torch.zeros((batch_size, self._sequence_length), dtype=torch.int64, device=device) + dummy_prompt_lens = torch.ones(batch_size, dtype=torch.int64, device=device) + dummy_completion_lens = torch.ones(batch_size, dtype=torch.int64, device=device) + + # Only show progress bar on rank 0 + batch_iter = range(max_num_batches) + if self._distributed.config.rank == 0: + batch_iter = tqdm.tqdm( + batch_iter, + total=max_num_batches, + desc=f"ForwardKL ({self._name})", + unit="batch", + ) + + for batch_idx in batch_iter: + i = batch_idx * batch_size + if i < len(data): + # This rank has real data for this batch + batch_log_probs = self._compute_batch_log_probs( + data.tokens[i : i + batch_size], + data.prompt_lens[i : i + batch_size], + data.completion_lens[i : i + batch_size], + ) + if batch_log_probs is not None: + student_log_probs_batches.append(batch_log_probs) + else: + # This rank has no more data but must still participate in collectives. + # Run a dummy forward pass and discard the result. 
+ self._compute_batch_log_probs(dummy_tokens, dummy_prompt_lens, dummy_completion_lens) + + if not student_log_probs_batches: # non-last PP rank or no local data + return self._reduce_metrics(0.0, 0.0, 0, 0, data.num_skipped) + finally: + # Clear inference mixer override for StochasticMixer layers + for module in stochastic_mixers: + module._inference_mixer_override = None + + # Restore original training mode + if was_training: + self._multi_stage.train(True) + + student_log_probs = torch.cat(student_log_probs_batches) + log_w = student_log_probs - data.teacher_log_probs + + # Diagnostic logging with percentiles + pcts = torch.tensor([0.01, 0.05, 0.10, 0.25, 0.50, 0.75, 0.90, 0.95, 0.99], device=log_w.device) + pct_labels = ["1%", "5%", "10%", "25%", "50%", "75%", "90%", "95%", "99%"] + + def fmt_percentiles(t: torch.Tensor) -> str: + q = torch.quantile(t.float(), pcts) + return ", ".join(f"{l}={v:.1f}" for l, v in zip(pct_labels, q.tolist())) + + logger.info(f"student_log_probs: [{fmt_percentiles(student_log_probs)}]") + logger.info(f"teacher_log_probs: [{fmt_percentiles(data.teacher_log_probs)}]") + logger.info(f"log_w: [{fmt_percentiles(log_w)}]") + + log_sum_all = self._scatter_logsumexp(log_w, data.problem_indices, data.num_problems) + log_w_correct = log_w.masked_fill(~data.corrects, float("-inf")) + log_sum_correct = self._scatter_logsumexp(log_w_correct, data.problem_indices, data.num_problems) + + # IS accuracy; nan_to_num handles -inf - -inf + accuracy = (log_sum_correct - log_sum_all).exp().nan_to_num(0.0) + + # ESS = exp(2*logsumexp(log_w) - logsumexp(2*log_w)) + log_sum_sq = self._scatter_logsumexp(2 * log_w, data.problem_indices, data.num_problems) + ess = (2 * log_sum_all - log_sum_sq).exp().clamp(min=0.0) + + # ESS diagnostics with percentiles + traces_per_problem = torch.bincount(data.problem_indices, minlength=data.num_problems) + multi_trace_mask = traces_per_problem > 1 + if multi_trace_mask.any(): + multi_ess = ess[multi_trace_mask] + multi_traces = traces_per_problem[multi_trace_mask] + logger.info(f"ESS ({multi_trace_mask.sum().item()} multi-trace problems): [{fmt_percentiles(multi_ess)}]") + logger.info(f"traces/problem: [{fmt_percentiles(multi_traces.float())}]") + + return self._reduce_metrics( + accuracy.sum().item(), + ess.sum().item(), + data.num_problems, + len(data), + data.num_skipped, + ) + + def _load_traces(self, device: torch.device) -> TraceTensors: + import datasets + + ds = datasets.load_dataset( + self._config.dataset_path, + split=self._config.split, + trust_remote_code=self._config.trust_remote_code, + ) + + # Shuffle needed because traces are sorted by problem + if self._config.num_samples and len(ds) > self._config.num_samples: + ds = ds.shuffle(seed=self._config.seed).select(range(self._config.num_samples)) + + dp_rank = self._distributed.config.data_rank + dp_size = self._distributed.config.data_parallel + + def belongs_to_shard(example: dict) -> bool: + h = hashlib.md5(example["problem_id"].encode(), usedforsecurity=False).digest() + return int.from_bytes(h[:4], "little") % dp_size == dp_rank + + ds = ds.filter(belongs_to_shard) + traces = list(ds) + + del ds + gc.collect() + + return TraceTensors.from_traces(traces, self._sequence_length, device) + + def _compute_batch_log_probs( + self, + tokens: torch.Tensor, + prompt_lens: torch.Tensor, + completion_lens: torch.Tensor, + ) -> torch.Tensor | None: + batch_size = tokens.shape[0] + lm_batch = self._prepare_batch(tokens, prompt_lens, completion_lens) + + with NoAutoValidate(): + batch_config = 
GPTBatchConfig( + micro_batch_size=batch_size, + sequence_length=self._sequence_length, + micro_sequence_length=self._micro_sequence_length, + truncate_documents=False, + ) + batch_config.setup(self._distributed.config) + batch_config.validate() + + preprocessed_meta = self._multi_stage.base_model.preprocess_meta(batch_config, PhaseType.inference) + preprocessed = self._multi_stage.base_model.preprocess_batch( + lm_batch, preprocessed_meta, phase=PhaseType.inference, iteration=0 + ) + + # Loop runs through micro-sequences; final kwargs has the logits + for input_, kwargs in preprocessed: + kwargs["global_logits"] = True + self._inference_runner.forward(input_, kwargs) + + if "logits" not in kwargs: # non-last PP stage + return None + + logits = kwargs["logits"] + if kwargs.get(AttentionKwargs.sequence_first, False): + logits = logits.transpose(0, 1) + + device = logits.device + seq_len = logits.shape[1] + + pred_logits = logits[:, :-1, :].contiguous() + targets = tokens[:, 1:].contiguous().to(device) + + # Mask: completion predictions are at [prompt_len-1, prompt_len+completion_len-1) + mask = self._create_completion_mask(prompt_lens, completion_lens, seq_len - 1) + + ce_loss = F.cross_entropy( + pred_logits.view(-1, pred_logits.size(-1)), + targets.view(-1), + reduction="none", + ).view(batch_size, seq_len - 1) + + results = -(ce_loss * mask).sum(dim=1) + + del logits, kwargs, preprocessed, lm_batch + + return results.to(torch.float64) + + def _prepare_batch( + self, + tokens: torch.Tensor, + prompt_lens: torch.Tensor, + completion_lens: torch.Tensor, + ) -> LanguageModelBatch: + samples = [] + for i in range(tokens.shape[0]): + seq_len = int(prompt_lens[i].item()) + int(completion_lens[i].item()) + sample = LanguageModelSample(TokenSample(tokens[i, :seq_len].cpu())) + + pad_len = self._sequence_length - seq_len + if pad_len > 0: + sample = LanguageModelSample.from_documents([sample, sample.get_padding(pad_len)]) + + samples.append(sample) + + return LanguageModelBatch.from_samples(samples) + + def _create_completion_mask( + self, + prompt_lens: torch.Tensor, + completion_lens: torch.Tensor, + seq_len: int, + ) -> torch.Tensor: + device = prompt_lens.device + positions = torch.arange(seq_len, device=device) + start = (prompt_lens - 1).unsqueeze(1) + end = (prompt_lens + completion_lens - 1).unsqueeze(1) + return (positions >= start) & (positions < end) + + def _reduce_metrics( + self, sum_accuracy: float, sum_ess: float, num_problems: int, num_traces: int, num_skipped: int + ) -> dict[str, float]: + group = self._distributed.world_group + sum_accuracy = allreduce_scalar(sum_accuracy, group=group) + sum_ess = allreduce_scalar(sum_ess, group=group) + num_problems = int(allreduce_scalar(num_problems, torch.int64, group=group)) + num_traces = int(allreduce_scalar(num_traces, torch.int64, group=group)) + num_skipped = int(allreduce_scalar(num_skipped, torch.int64, group=group)) + + if num_problems == 0: + return { + "is_accuracy": 0.0, + "mean_ess": 0.0, + "samples_per_problem": 0.0, + "num_traces": 0, + "num_problems": 0, + "num_skipped": num_skipped, + } + + return { + "is_accuracy": sum_accuracy / num_problems, + "mean_ess": sum_ess / num_problems, + "samples_per_problem": num_traces / num_problems, + "num_traces": num_traces, + "num_problems": num_problems, + "num_skipped": num_skipped, + } + + def _scatter_logsumexp(self, src: torch.Tensor, index: torch.Tensor, num_groups: int) -> torch.Tensor: + # Max per group for numerical stability + max_vals = torch.full((num_groups,), float("-inf"), 
device=src.device, dtype=src.dtype) + max_vals.scatter_reduce_(0, index, src, reduce="amax") + + src_shifted = (src - max_vals[index]).exp() + sum_exp = torch.zeros(num_groups, device=src.device, dtype=src.dtype) + sum_exp.scatter_add_(0, index, src_shifted) + + return max_vals + sum_exp.log() diff --git a/fast_llm/layers/decoder/stochastic_mixer.py b/fast_llm/layers/decoder/stochastic_mixer.py index 984f34b80..76b261a4e 100644 --- a/fast_llm/layers/decoder/stochastic_mixer.py +++ b/fast_llm/layers/decoder/stochastic_mixer.py @@ -106,7 +106,8 @@ def setup(self, distributed: Distributed) -> None: def _sample_mixer_name(self, kwargs: dict[str, typing.Any]) -> str: if not self.training: - return self._config.main_mixer_name + # Allow runtime override of the inference mixer (e.g., for evaluation) + return getattr(self, "_inference_mixer_override", None) or self._config.main_mixer_name generator = kwargs[StochasticMixerKwargs.generator] mixer_idx = torch.multinomial(self._sampling_probs, num_samples=1, generator=generator).item() diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index b1d0c2acd..94ddbded9 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -250,16 +250,10 @@ def _logits_cross_entropy_forward_backward_split( input_, targets, weight, grad_output, kwargs, losses ) if targets is None: - # TODO: Make a proper way of returning the model output. - loss = loss.detach() - if kwargs.get("global_logits"): - if self._vocab_parallel: - loss = gather_op(loss, self._parallel_dim.group, 2) - elif self._sequence_parallel_logits: - loss = gather_op( - loss, self._parallel_dim.group, 0 if kwargs[LanguageModelKwargs.sequence_first] else 1 - ) - kwargs["logits" if self._prediction_distance == 0 else f"logits_{self._prediction_distance}"] = loss + # global_logits: raw logits already stored and gathered in inner function + # non-global_logits: store scaled logits for distillation backwards compat + if not kwargs.get("global_logits"): + kwargs["logits" if self._prediction_distance == 0 else f"logits_{self._prediction_distance}"] = loss.detach() return None, None else: loss = None @@ -342,6 +336,17 @@ def _logits_cross_entropy_forward_backward( dims = None self._debug(logits, "logits", dims, kwargs, scale=self._config.logits_scale_factor) + if kwargs.get("global_logits"): + logits_for_storage = logits.detach() + if self._vocab_parallel: + logits_for_storage = gather_op(logits_for_storage, self._parallel_dim.group, 2) + elif self._sequence_parallel_logits: + logits_for_storage = gather_op( + logits_for_storage, self._parallel_dim.group, 0 if kwargs[LanguageModelKwargs.sequence_first] else 1 + ) + logits_key = "logits" if self._prediction_distance == 0 else f"logits_{self._prediction_distance}" + kwargs[logits_key] = logits_for_storage + if targets is None: return logits * self._config.logits_scale_factor, None dpo_target, lm_target, distillation_target, loss_mask = targets diff --git a/fast_llm/models/gpt/conversion/apriel2.py b/fast_llm/models/gpt/conversion/apriel2.py index 4ed588ed5..91e3be508 100644 --- a/fast_llm/models/gpt/conversion/apriel2.py +++ b/fast_llm/models/gpt/conversion/apriel2.py @@ -39,8 +39,20 @@ def import_config(cls, config: dict) -> dict: "head_groups": config["head_groups"], "head_size": config["head_size"], "rotary": rotary, - "add_linear_biases": config["add_linear_biases"], } + # Per-layer bias configuration mirroring Fast-LLM structure + # If per-layer configs exist, use them; otherwise 
fall back to add_linear_biases + if "query_layer" in config: + result["query_layer"] = config["query_layer"] + if "key_layer" in config: + result["key_layer"] = config["key_layer"] + if "value_layer" in config: + result["value_layer"] = config["value_layer"] + if "dense_layer" in config: + result["dense_layer"] = config["dense_layer"] + # add_linear_biases serves as default for layers without explicit config + if "add_linear_biases" in config: + result["add_linear_biases"] = config["add_linear_biases"] if "window_size" in config: result["window_size"] = config["window_size"] return result @@ -58,18 +70,37 @@ def export_config(cls, config: AttentionConfig) -> dict: else: raise NotImplementedError(f"Unsupported rotary type: {type(config.rotary).__name__}") - return { + result = { "type": "attention", "heads": config.heads, "head_groups": config.head_groups, "head_size": config.head_size, - "add_linear_biases": config.add_linear_biases, "rotary": { "type": rotary_type, "theta": config.rotary.theta, }, "window_size": config.window_size, } + # Export per-layer bias configuration + # Only include if explicitly set (not None) + if config.query_layer.bias.enabled is not None: + result["query_layer"] = {"bias": {"enabled": config.query_layer.bias.enabled}} + if config.key_layer.bias.enabled is not None: + result["key_layer"] = {"bias": {"enabled": config.key_layer.bias.enabled}} + if config.value_layer.bias.enabled is not None: + result["value_layer"] = {"bias": {"enabled": config.value_layer.bias.enabled}} + if config.dense_layer.bias.enabled is not None: + result["dense_layer"] = {"bias": {"enabled": config.dense_layer.bias.enabled}} + # add_linear_biases as fallback default + result["add_linear_biases"] = config.add_linear_biases + return result + + @classmethod + def _get_effective_bias(cls, layer_config, default: bool) -> bool: + """Get effective bias setting: use layer-specific if set, else default.""" + if layer_config.bias.enabled is not None: + return layer_config.bias.enabled + return default @classmethod def get_converters( @@ -79,11 +110,20 @@ def get_converters( hf_prefix: str, drop_on_export: bool = False, ) -> list[WeightConverter]: + # Determine effective bias for each projection + q_bias = cls._get_effective_bias(config.query_layer, config.add_linear_biases) + k_bias = cls._get_effective_bias(config.key_layer, config.add_linear_biases) + v_bias = cls._get_effective_bias(config.value_layer, config.add_linear_biases) + o_bias = cls._get_effective_bias(config.dense_layer, config.add_linear_biases) + # For key_value, both k and v must have same bias setting + # (they're combined in Fast-LLM's key_value layer) + kv_bias = k_bias and v_bias + return [ *get_weight_and_bias_converters( f"{fast_llm_prefix}.query", f"{hf_prefix}.q_proj", - config.add_linear_biases, + q_bias, QueryWeightConverter, config, drop_on_export=drop_on_export, @@ -91,7 +131,7 @@ def get_converters( *get_weight_and_bias_converters( f"{fast_llm_prefix}.key_value", (f"{hf_prefix}.k_proj", f"{hf_prefix}.v_proj"), - config.add_linear_biases, + kv_bias, KeyValueWeightConverter, config, drop_on_export=drop_on_export, @@ -99,7 +139,7 @@ def get_converters( *get_weight_and_bias_converters( f"{fast_llm_prefix}.dense", f"{hf_prefix}.o_proj", - config.add_linear_biases, + o_bias, drop_on_export=drop_on_export, ), ] @@ -524,6 +564,12 @@ def import_config(cls, config: dict, block_config: dict) -> dict: "gated": mlp_config["gated"], "add_linear_biases": mlp_config["add_linear_biases"], } + # Import per-layer MLP bias settings 
(layer_1, layer_2) + for layer_name in ("layer_1", "layer_2"): + if layer_name in mlp_config: + layer_cfg = mlp_config[layer_name] + if "bias" in layer_cfg: + mlp[layer_name] = {"bias": layer_cfg["bias"]} normalization = block_config["normalization"] @@ -578,6 +624,11 @@ def export_config(cls, config: DecoderBlockConfig) -> dict: "gated": config.mlp.gated, "add_linear_biases": config.mlp.add_linear_biases, } + # Export per-layer MLP bias settings (layer_1, layer_2) + if config.mlp.layer_1.bias.enabled is not None: + mlp["layer_1"] = {"bias": {"enabled": config.mlp.layer_1.bias.enabled}} + if config.mlp.layer_2.bias.enabled is not None: + mlp["layer_2"] = {"bias": {"enabled": config.mlp.layer_2.bias.enabled}} normalization = {"type": norm_type_str, "epsilon": config.normalization.epsilon} @@ -624,24 +675,56 @@ def get_converters( ) ) - converters.extend( - [ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_1", - (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), - config.mlp.add_linear_biases, - SplitWeightConverter, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_2", - f"{hf_prefix}.mlp.down_proj", - config.mlp.add_linear_biases, - MLPLayer2Converter, - drop_on_export=drop_on_export, - ), - ] - ) + # Per-layer MLP bias: use layer-specific setting if set, else default + def get_mlp_layer_bias(layer_config, default: bool) -> bool: + if layer_config.bias.enabled is not None: + return layer_config.bias.enabled + return default + + layer_1_bias = get_mlp_layer_bias(config.mlp.layer_1, config.mlp.add_linear_biases) + layer_2_bias = get_mlp_layer_bias(config.mlp.layer_2, config.mlp.add_linear_biases) + + if config.mlp.gated: + # Gated MLP: gate_proj + up_proj -> layer_1 (split), down_proj -> layer_2 + converters.extend( + [ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_1", + (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), + layer_1_bias, + SplitWeightConverter, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_2", + f"{hf_prefix}.mlp.down_proj", + layer_2_bias, + MLPLayer2Converter, + drop_on_export=drop_on_export, + ), + ] + ) + else: + # Non-gated MLP: up_proj -> layer_1, down_proj -> layer_2 + # Note: layer_2 still needs MLPLayer2Converter for the transpose + converters.extend( + [ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_1", + f"{hf_prefix}.mlp.up_proj", + layer_1_bias, + WeightConverter, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_2", + f"{hf_prefix}.mlp.down_proj", + layer_2_bias, + MLPLayer2Converter, + drop_on_export=drop_on_export, + ), + ] + ) converters.extend( [ diff --git a/fast_llm/models/gpt/conversion/qwen2.py b/fast_llm/models/gpt/conversion/qwen2.py index a8bc33454..4ebf18c3a 100644 --- a/fast_llm/models/gpt/conversion/qwen2.py +++ b/fast_llm/models/gpt/conversion/qwen2.py @@ -1,15 +1,21 @@ import typing from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm.engine.checkpoint.external import WeightConverter from fast_llm.layers.attention.config import AttentionConfig +from fast_llm.layers.decoder.mlp.config import MLPConfig from fast_llm.models.gpt.conversion.config import Qwen2CheckpointFormat from fast_llm.models.gpt.conversion.llama import ( + KeyValueWeightConverter, LlamaAttentionConverter, LlamaBaseModelConverter, LlamaBlockConverter, LlamaDecoderConverter, LlamaHeadConverter, 
LlamaHuggingfaceCheckpointHandler, + LlamaMLPConverter, + QueryWeightConverter, + get_weight_and_bias_converters, ) from fast_llm.utils import Assert @@ -17,6 +23,22 @@ class Qwen2AttentionConverter(LlamaAttentionConverter): # TODO: Support sliding window with max_window_layers (need 2 kinds of block?) + @classmethod + def import_config(cls, config: dict) -> dict: + config["attention_bias"] = True + out = super().import_config(config) + out["query_layer"] = {"bias": {"enabled": True}} + out["key_layer"] = {"bias": {"enabled": True}} + out["value_layer"] = {"bias": {"enabled": True}} + out["dense_layer"] = {"bias": {"enabled": False}} + return out + + @classmethod + def export_config(cls, config: AttentionConfig) -> dict: + out = super().export_config(config) + del out["attention_bias"] + return out + @classmethod def _check_config(cls, config: AttentionConfig) -> None: Assert.is_(type(config), AttentionConfig) @@ -32,9 +54,56 @@ def _check_config(cls, config: AttentionConfig) -> None: Assert.is_(config.value_layer.bias.enabled, True) Assert.incl(config.dense_layer.bias.enabled, (None, False)) + @classmethod + def get_converters( + cls, + config: AttentionConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + return [ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.query", + f"{hf_prefix}.q_proj", + True, + QueryWeightConverter, + config, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.key_value", + (f"{hf_prefix}.k_proj", f"{hf_prefix}.v_proj"), + True, + KeyValueWeightConverter, + config, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.dense", + f"{hf_prefix}.o_proj", + False, + drop_on_export=drop_on_export, + ), + ] + + +class Qwen2MLPConverter(LlamaMLPConverter): + @classmethod + def import_config(cls, config: dict) -> dict: + config["mlp_bias"] = False + return super().import_config(config) + + @classmethod + def export_config(cls, config: MLPConfig) -> dict: + out = super().export_config(config) + del out["mlp_bias"] + return out + class Qwen2BlockConverter(LlamaBlockConverter): mixer_converter_class: typing.ClassVar[type[Qwen2AttentionConverter]] = Qwen2AttentionConverter + mlp_converter_class: typing.ClassVar[type[Qwen2MLPConverter]] = Qwen2MLPConverter class Qwen2DecoderConverter(LlamaDecoderConverter): diff --git a/fast_llm/models/multimodal/conversion/apriel2.py b/fast_llm/models/multimodal/conversion/apriel2.py index b4147a8bf..307a67c63 100644 --- a/fast_llm/models/multimodal/conversion/apriel2.py +++ b/fast_llm/models/multimodal/conversion/apriel2.py @@ -326,9 +326,7 @@ class Apriel2MultimodalBaseModelConverter: @classmethod def import_config(cls, config: dict) -> dict: text_config = Apriel2BaseModelConverter.import_config(config) - vision_config = ( - cls.vision_model_converter_class.import_config(config) if "vision_encoder" in config else None - ) + vision_config = cls.vision_model_converter_class.import_config(config) if "vision_encoder" in config else None result = safe_merge_dicts( text_config, @@ -388,10 +386,7 @@ def get_transformers_configuration_class(cls): @classmethod def get_model_files(cls) -> tuple[str, str, str | None]: - from fast_llm_external_models.apriel2 import ( - configuration_apriel2, - modeling_apriel2, - ) + from fast_llm_external_models.apriel2 import configuration_apriel2, modeling_apriel2 return configuration_apriel2.__file__, modeling_apriel2.__file__, None diff --git 
a/fast_llm_external_models/apriel2/cache.py b/fast_llm_external_models/apriel2/cache.py index 86c67a085..f83ae87d6 100644 --- a/fast_llm_external_models/apriel2/cache.py +++ b/fast_llm_external_models/apriel2/cache.py @@ -1,17 +1,22 @@ from __future__ import annotations + import torch from transformers.cache_utils import Cache class _AttentionCache: - __slots__ = ["key", "value", "window"] + __slots__ = ["key", "value", "window", "cumulative_length"] def __init__(self, window=None): self.key = None self.value = None self.window = window + self.cumulative_length = 0 def update(self, key, value): + new_tokens = key.shape[-2] + self.cumulative_length += new_tokens + if self.key is None: if self.window and key.shape[-2] > self.window: self.key = key[..., -self.window :, :].contiguous() @@ -35,6 +40,40 @@ def _window(self, cache, new): return cache return torch.cat([cache, new], -2)[..., -self.window :, :].contiguous() + def reset(self): + self.key = None + self.value = None + self.cumulative_length = 0 + + def reorder(self, beam_idx): + if self.key is not None: + self.key = self.key.index_select(0, beam_idx.to(self.key.device)) + self.value = self.value.index_select(0, beam_idx.to(self.value.device)) + + def crop(self, max_length): + if self.key is not None: + self.key = self.key[..., :max_length, :] + self.value = self.value[..., :max_length, :] + self.cumulative_length = self.key.shape[-2] + + def batch_repeat(self, repeats): + if self.key is not None: + self.key = self.key.repeat_interleave(repeats, dim=0) + self.value = self.value.repeat_interleave(repeats, dim=0) + + def batch_select(self, indices): + if self.key is not None: + self.key = self.key.index_select(0, indices.to(self.key.device)) + self.value = self.value.index_select(0, indices.to(self.value.device)) + + @property + def is_initialized(self): + return self.key is not None + + @property + def batch_size(self): + return self.key.shape[0] if self.key is not None else None + class _SSMCache: __slots__ = ["conv", "recurrent"] @@ -43,6 +82,52 @@ def __init__(self): self.conv = None self.recurrent = None + def reset(self): + self.conv = None + self.recurrent = None + + def reorder(self, beam_idx): + if self.conv is not None: + if isinstance(self.conv, tuple): + self.conv = tuple(c.index_select(0, beam_idx.to(c.device)) for c in self.conv) + else: + self.conv = self.conv.index_select(0, beam_idx.to(self.conv.device)) + if self.recurrent is not None: + self.recurrent = self.recurrent.index_select(0, beam_idx.to(self.recurrent.device)) + + def crop(self, max_length): + pass # SSM caches don't have sequence dimension to crop + + def batch_repeat(self, repeats): + if self.conv is not None: + if isinstance(self.conv, tuple): + self.conv = tuple(c.repeat_interleave(repeats, dim=0) for c in self.conv) + else: + self.conv = self.conv.repeat_interleave(repeats, dim=0) + if self.recurrent is not None: + self.recurrent = self.recurrent.repeat_interleave(repeats, dim=0) + + def batch_select(self, indices): + if self.conv is not None: + if isinstance(self.conv, tuple): + self.conv = tuple(c.index_select(0, indices.to(c.device)) for c in self.conv) + else: + self.conv = self.conv.index_select(0, indices.to(self.conv.device)) + if self.recurrent is not None: + self.recurrent = self.recurrent.index_select(0, indices.to(self.recurrent.device)) + + @property + def is_initialized(self): + return self.conv is not None + + @property + def batch_size(self): + if self.conv is None: + return None + if isinstance(self.conv, tuple): + return 
self.conv[0].shape[0] + return self.conv.shape[0] + class _DummyCacheLayer: pass @@ -93,14 +178,19 @@ def set_active_mixer(self, layer_idx, mixer_name): self.active_mixers[layer_idx] = mixer_name def get_seq_length(self, layer_idx=0): + """Returns the cumulative sequence length of tokens seen by the cache. + + For sliding window caches, this returns the total tokens seen (not just cached). + This matches HuggingFace's DynamicSlidingWindowLayer behavior. + """ layer = self.layers[layer_idx] if isinstance(layer, dict): mixer = self.active_mixers[layer_idx] if mixer and isinstance(layer[mixer], _AttentionCache): - return layer[mixer].key.shape[-2] if layer[mixer].key is not None else 0 + return layer[mixer].cumulative_length return 0 if isinstance(layer, _AttentionCache): - return layer.key.shape[-2] if layer.key is not None else 0 + return layer.cumulative_length return 0 def get_max_cache_shape(self, layer_idx=0): @@ -114,22 +204,61 @@ def get_max_cache_shape(self, layer_idx=0): return None def get_mask_sizes(self, cache_position, layer_idx): + """Return the length and offset of the cache, used to generate the attention mask. + + For standard (non-sliding) attention: + kv_offset = 0 (KV[0] corresponds to sequence position 0) + kv_length = cumulative_length + query_length + + For sliding window attention: + kv_offset = max(cumulative_length - window + 1, 0) + kv_length = min(cumulative_length, window - 1) + query_length + + For SSM/linear layers: + kv_offset = 0, kv_length = query_length (no KV cache to attend to) + """ query_length = cache_position.shape[0] - past_seen_tokens = self.get_seq_length(layer_idx) - kv_length = query_length + past_seen_tokens - kv_offset = past_seen_tokens - return kv_length, kv_offset + layer = self.layers[layer_idx] + + # Handle stochastic layers by getting the active mixer's cache + if isinstance(layer, dict): + mixer = self.active_mixers[layer_idx] + if mixer is None: + # No active mixer set, return defaults + return query_length, 0 + cache = layer[mixer] + else: + cache = layer + + # SSM layers don't have KV cache for attention mask purposes + if isinstance(cache, _SSMCache): + return query_length, 0 + + # Attention cache - check if sliding window + if isinstance(cache, _AttentionCache): + cumulative = cache.cumulative_length + window = cache.window + + if window is not None: + # Sliding window attention + kv_offset = max(cumulative - window + 1, 0) + if cumulative >= window: + kv_length = window - 1 + query_length + else: + kv_length = cumulative + query_length + else: + # Full attention + kv_offset = 0 + kv_length = cumulative + query_length + + return kv_length, kv_offset + + # Fallback + return query_length, 0 @property def has_previous_state(self): - for layer in self.layers: - if isinstance(layer, dict): - for cache in layer.values(): - if isinstance(cache, _SSMCache) and cache.conv is not None: - return True - elif isinstance(layer, _SSMCache) and layer.conv is not None: - return True - return False + return any(isinstance(cache, _SSMCache) and cache.conv is not None for cache in self._iter_caches()) @property def key_cache(self): @@ -147,101 +276,33 @@ def conv_states(self): def recurrent_states(self): return _LayerListAccessor(self, "recurrent") - def reorder_cache(self, beam_idx): - for i, layer in enumerate(self.layers): + def _iter_caches(self): + """Iterate over all leaf cache objects (flattening stochastic layer dicts).""" + for layer in self.layers: if isinstance(layer, dict): - for cache in layer.values(): - self._reorder_cache_obj(cache, 
beam_idx) + yield from layer.values() else: - self._reorder_cache_obj(layer, beam_idx) + yield layer - def _reorder_cache_obj(self, cache, beam_idx): - if isinstance(cache, _AttentionCache): - if cache.key is not None: - cache.key = cache.key.index_select(0, beam_idx.to(cache.key.device)) - cache.value = cache.value.index_select(0, beam_idx.to(cache.value.device)) - elif isinstance(cache, _SSMCache): - if cache.conv is not None: - # Handle both single tensor (GDN/Mamba) and tuple (KDA) conv states - if isinstance(cache.conv, tuple): - cache.conv = tuple(c.index_select(0, beam_idx.to(c.device)) for c in cache.conv) - else: - cache.conv = cache.conv.index_select(0, beam_idx.to(cache.conv.device)) - if cache.recurrent is not None: - cache.recurrent = cache.recurrent.index_select(0, beam_idx.to(cache.recurrent.device)) + def reorder_cache(self, beam_idx): + for cache in self._iter_caches(): + cache.reorder(beam_idx) def reset(self): - for layer in self.layers: - if isinstance(layer, dict): - for cache in layer.values(): - self._reset_cache_obj(cache) - else: - self._reset_cache_obj(layer) - - def _reset_cache_obj(self, cache): - if isinstance(cache, _AttentionCache): - cache.key = None - cache.value = None - elif isinstance(cache, _SSMCache): - cache.conv = None - cache.recurrent = None + for cache in self._iter_caches(): + cache.reset() def crop(self, max_length): - for layer in self.layers: - if isinstance(layer, dict): - for cache in layer.values(): - if isinstance(cache, _AttentionCache) and cache.key is not None: - cache.key = cache.key[..., :max_length, :] - cache.value = cache.value[..., :max_length, :] - elif isinstance(layer, _AttentionCache) and layer.key is not None: - layer.key = layer.key[..., :max_length, :] - layer.value = layer.value[..., :max_length, :] + for cache in self._iter_caches(): + cache.crop(max_length) def batch_repeat_interleave(self, repeats): - for layer in self.layers: - if isinstance(layer, dict): - for cache in layer.values(): - self._batch_repeat_cache_obj(cache, repeats) - else: - self._batch_repeat_cache_obj(layer, repeats) - - def _batch_repeat_cache_obj(self, cache, repeats): - if isinstance(cache, _AttentionCache): - if cache.key is not None: - cache.key = cache.key.repeat_interleave(repeats, dim=0) - cache.value = cache.value.repeat_interleave(repeats, dim=0) - elif isinstance(cache, _SSMCache): - if cache.conv is not None: - # Handle both single tensor (GDN/Mamba) and tuple (KDA) conv states - if isinstance(cache.conv, tuple): - cache.conv = tuple(c.repeat_interleave(repeats, dim=0) for c in cache.conv) - else: - cache.conv = cache.conv.repeat_interleave(repeats, dim=0) - if cache.recurrent is not None: - cache.recurrent = cache.recurrent.repeat_interleave(repeats, dim=0) + for cache in self._iter_caches(): + cache.batch_repeat(repeats) def batch_select_indices(self, indices): - for layer in self.layers: - if isinstance(layer, dict): - for cache in layer.values(): - self._batch_select_cache_obj(cache, indices) - else: - self._batch_select_cache_obj(layer, indices) - - def _batch_select_cache_obj(self, cache, indices): - if isinstance(cache, _AttentionCache): - if cache.key is not None: - cache.key = cache.key.index_select(0, indices.to(cache.key.device)) - cache.value = cache.value.index_select(0, indices.to(cache.value.device)) - elif isinstance(cache, _SSMCache): - if cache.conv is not None: - # Handle both single tensor (GDN/Mamba) and tuple (KDA) conv states - if isinstance(cache.conv, tuple): - cache.conv = tuple(c.index_select(0, 
indices.to(c.device)) for c in cache.conv) - else: - cache.conv = cache.conv.index_select(0, indices.to(cache.conv.device)) - if cache.recurrent is not None: - cache.recurrent = cache.recurrent.index_select(0, indices.to(cache.recurrent.device)) + for cache in self._iter_caches(): + cache.batch_select(indices) @property def is_compileable(self): @@ -249,19 +310,7 @@ def is_compileable(self): @property def is_initialized(self): - for layer in self.layers: - if isinstance(layer, dict): - for cache in layer.values(): - if isinstance(cache, _AttentionCache) and cache.key is not None: - return True - if isinstance(cache, _SSMCache) and cache.conv is not None: - return True - else: - if isinstance(layer, _AttentionCache) and layer.key is not None: - return True - if isinstance(layer, _SSMCache) and layer.conv is not None: - return True - return False + return any(cache.is_initialized for cache in self._iter_caches()) @property def is_sliding(self): @@ -280,39 +329,20 @@ def is_sliding(self): @property def max_batch_size(self): - for layer in self.layers: - if isinstance(layer, dict): - for cache in layer.values(): - if isinstance(cache, _AttentionCache) and cache.key is not None: - return cache.key.shape[0] - if isinstance(cache, _SSMCache) and cache.conv is not None: - # Handle both single tensor and tuple conv states - if isinstance(cache.conv, tuple): - return cache.conv[0].shape[0] - return cache.conv.shape[0] - else: - if isinstance(layer, _AttentionCache) and layer.key is not None: - return layer.key.shape[0] - if isinstance(layer, _SSMCache) and layer.conv is not None: - # Handle both single tensor and tuple conv states - if isinstance(layer.conv, tuple): - return layer.conv[0].shape[0] - return layer.conv.shape[0] + for cache in self._iter_caches(): + bs = cache.batch_size + if bs is not None: + return bs return None @property def max_cache_len(self): - max_len = None - for layer in self.layers: - if isinstance(layer, dict): - for cache in layer.values(): - if isinstance(cache, _AttentionCache): - if cache.window is not None: - max_len = cache.window if max_len is None else min(max_len, cache.window) - elif isinstance(layer, _AttentionCache): - if layer.window is not None: - max_len = layer.window if max_len is None else min(max_len, layer.window) - return max_len + windows = [ + cache.window + for cache in self._iter_caches() + if isinstance(cache, _AttentionCache) and cache.window is not None + ] + return min(windows) if windows else None def __len__(self): return len(self.layers) diff --git a/fast_llm_external_models/apriel2/conversion/__init__.py b/fast_llm_external_models/apriel2/conversion/__init__.py index 983a632e0..2c28d1e87 100644 --- a/fast_llm_external_models/apriel2/conversion/__init__.py +++ b/fast_llm_external_models/apriel2/conversion/__init__.py @@ -1,88 +1,138 @@ """Weight conversion system for Apriel2 models. -Architecture Overview -===================== +Overview +======== -This package implements a declarative weight transformation system with two -orthogonal concerns: +This package implements a declarative weight transformation system. The core +abstraction separates config composition (structural) from plan execution (weights). -1. **Config Composition** - Structural transformations of model configs -2. 
**Plan Building & Execution** - Weight transformations between configs +Conceptual Types +================ -These concerns are intentionally separated: -- Config composition determines WHAT the target architecture looks like -- Plan building determines HOW weights are transformed to match -- The `init` field bridges them: it's config metadata consumed by the plan builder +All configs are ``dict``, but we distinguish three conceptual types: -Key Design Decisions -==================== +**State (S)** - A complete model config without ``init`` fields. + What you load from disk or save after conversion. -**Declarative Plans** - Plans are DATA (JSON-serializable expressions), not functions. This enables: - - Inspection and debugging of transformations - - Serialization for distributed execution - - Composition via substitution rather than function composition - -**Separation of Config and Weights** - The `init` field in surgery specs controls weight handling (transfer vs random) - but does NOT affect config composition. Config composition is purely structural. - After composition, `init` fields are stripped from complete configs. - -**Composition Semantics** - Surgery specs use declarative (merge) composition, not operational (function) - composition. For "additive" surgeries (modifying existing structure), the - monoid action law holds. For "replacement" surgeries (defining complete new - structure), sequential application differs from composed application by design. - -**Cross-Type Derivation** - When converting between mixer types (e.g., attention โ†’ mamba), geometric - parameters are derived where possible: - - attention.heads โ†’ mamba dimensions (MIL conversion) - - attention.heads โ†’ gdn heads (DIL conversion) +**Partial Surgery (P)** - An incomplete config specifying changes. + May contain ``init`` fields (``transfer`` or ``random``). -Module Structure -================ +**Transition Spec (T)** - A complete config WITH ``init`` fields. + The result of applying surgery to a state. Describes both target + structure and weight initialization mode. + +Algebraic Structure +=================== + +**Monoid**: Partial surgeries compose via deep merge:: + + compose_configs : P ร— P โ†’ P + +**Action**: Surgeries act on states to produce transition specs:: + + compose_configs : S ร— P โ†’ T + compose_configs : T ร— P โ†’ T + +**Extraction**: Strip init to get a state:: + + strip_init_fields : T โ†’ S + +**Planning**: Build weight transformation from source state + transition spec:: + + plan_surgery : S ร— T โ†’ Plan -- `config.py` - Config composition (compose_configs, apply_surgery) -- `converters.py` - Plan builders (plan_surgery, plan_mil_attention_to_mamba, etc.) -- `expr.py` - Expression types and plan class (Ref, Slice, Concat, Init, ExprPlan) -- `executor.py` - Plan execution (StreamingExecutor, execute) -- `io.py` - Streaming I/O (SafetensorLoader, ShardedSafetensorWriter) -- `llava/` - Source-specific converter for Llava โ†’ Apriel2 +The ``init`` Field +================== -Example Usage +The ``init`` field in surgeries specifies weight initialization: + +- ``init: transfer`` โ†’ transfer/convert weights from source +- ``init: random`` โ†’ randomly initialize weights + +This field is preserved through ``compose_configs`` so ``plan_surgery`` can read it. +Use ``strip_init_fields`` before saving configs to disk. + +Typical Usage ============= +:: + from fast_llm_external_models.apriel2.conversion import ( compose_configs, plan_surgery, + strip_init_fields, execute, ) - # 1. 
Compose configs to get target architecture - target_config = compose_configs(source_config, surgery_spec) + # Load source state + source_state = load_config(...) # S - # 2. Build plan for weight transformation - plan = plan_surgery(source_config, surgery_spec) + # Apply surgery + surgery = {"decoder": {"block": {"mixer": {"type": "gdn", "init": "random"}}}} # P + transition = compose_configs(source_state, surgery) # T - # 3. Execute plan to transform weights - target_weights = execute(plan, source_weights, seed=42) + # Build and execute plan + plan = plan_surgery(source_state, transition) + weights = execute(plan, source_weights, seed=42) -For streaming I/O with large models: + # Save (strip init first) + target_state = strip_init_fields(transition) # S + save_config(target_state) - from fast_llm_external_models.apriel2.conversion import ( - StreamingExecutor, - SafetensorLoader, - ShardedSafetensorWriter, - ) +For chained surgeries:: + + current_state = source_state # S + current_plan = identity_plan + + for surgery in surgery_chain: # each P + transition = compose_configs(current_state, surgery) # T + plan = plan_surgery(current_state, transition) + current_plan = compose(current_plan, plan) + current_state = strip_init_fields(transition) # S <- IMPORTANT! + +**Note**: The ``strip_init_fields`` call is critical. It ensures that ``init: random`` +applies only to the surgery that introduces a component. Without stripping, subsequent +surgeries would re-randomize existing components. See ``config.py`` docstring for details. + +Key Design Decisions +==================== + +**Declarative Plans** + Plans are data (expressions), not functions. Enables inspection, + serialization, and composition via substitution. + +**Inheritance Semantics** + When S ร— P โ†’ T, unspecified fields inherit from source. + Cross-type derivation maps geometry (attention.heads โ†’ gdn.value_heads). - with SafetensorLoader(source_files) as loader: - executor = StreamingExecutor(plan, loader) - with ShardedSafetensorWriter(output_dir) as writer: - for key, tensor in executor.execute(seed=42): - writer.add(key, tensor) +**Additive vs Replacement Surgeries** + Additive surgeries (no ``type:`` declaration) satisfy the action law. + Replacement surgeries (explicit ``type:``) use last-write-wins. 
+ +Module Structure +================ + +- ``config.py`` - Config composition (compose_configs, strip_init_fields) +- ``converters.py`` - Plan builders (plan_surgery, plan_mil_attention_to_mamba) +- ``expr.py`` - Expression types (Ref, Slice, Concat, Init, ExprPlan) +- ``executor.py`` - Plan execution (StreamingExecutor, execute) +- ``io.py`` - Streaming I/O (SafetensorLoader, ShardedSafetensorWriter) """ +# Config composition +from fast_llm_external_models.apriel2.conversion.config import compose_configs, strip_init_fields + +# Plan builders (generic) +from fast_llm_external_models.apriel2.conversion.converters import ( + plan_dil_attention_to_gdn, + plan_kil_attention_to_kda, + plan_mil_attention_to_mamba, + plan_surgery, +) + +# Execution +from fast_llm_external_models.apriel2.conversion.executor import MAX_SEED, StreamingExecutor, execute + # Core types and plan operations from fast_llm_external_models.apriel2.conversion.expr import ( Concat, @@ -104,13 +154,6 @@ substitute, ) -# Execution -from fast_llm_external_models.apriel2.conversion.executor import ( - MAX_SEED, - StreamingExecutor, - execute, -) - # I/O utilities from fast_llm_external_models.apriel2.conversion.io import ( DEFAULT_MAX_SHARD_SIZE, @@ -118,22 +161,9 @@ ShardedSafetensorWriter, ) -# Plan builders (generic) -from fast_llm_external_models.apriel2.conversion.converters import ( - plan_mil_attention_to_mamba, - plan_dil_attention_to_gdn, - plan_kil_attention_to_kda, - plan_surgery, -) - -# Config composition -from fast_llm_external_models.apriel2.conversion.config import compose_configs - # Source-specific converters -from fast_llm_external_models.apriel2.conversion.llava import ( - convert_config as convert_llava_config, - plan_llava_to_apriel2, -) +from fast_llm_external_models.apriel2.conversion.llava import convert_config as convert_llava_config +from fast_llm_external_models.apriel2.conversion.llava import plan_llava_to_apriel2 # Rendering (optional, imported lazily by ExprPlan.render_tree) # from fast_llm_external_models.apriel2.conversion.render import render_tree @@ -175,6 +205,7 @@ "plan_kil_attention_to_kda", # Config composition "compose_configs", + "strip_init_fields", # Source-specific converters "convert_llava_config", "plan_llava_to_apriel2", diff --git a/fast_llm_external_models/apriel2/conversion/config.py b/fast_llm_external_models/apriel2/conversion/config.py index 48f8ff44b..3752688c1 100644 --- a/fast_llm_external_models/apriel2/conversion/config.py +++ b/fast_llm_external_models/apriel2/conversion/config.py @@ -1,56 +1,136 @@ """Config composition for Apriel2 architecture transformations. -This module handles STRUCTURAL composition of configs, independent of weight handling. -The `init` field in surgery specs is preserved as metadata for the plan builder but -does not affect how configs are composed. +Conceptual Types +================ -Composition Cases -================= +The system operates on three conceptual types, all represented as ``dict``: -compose_configs(base, overlay) handles four cases based on completeness: +**State (S)** + A complete structural description of a model. Has ``hidden_size`` and ``decoder``. + Does NOT contain ``init`` fields. Represents WHAT a model looks like. -1. **Complete + Partial** โ†’ Apply surgery semantics (inheritance, cross-type derivation) -2. **Partial + Partial** โ†’ Deep merge (monoid operation on surgery specs) -3. **Partial + Complete** โ†’ Overlay wins (complete config replaces partial) -4. 
**Complete + Complete** โ†’ Deep merge, then strip `init` fields + Example: A saved config.json, or a model you're about to transform. -A config is "complete" if it has `hidden_size` and `decoder` (i.e., it's a full model -config, not a surgery spec). +**Partial Surgery (P)** + An incomplete config specifying fields to change. Missing ``hidden_size`` or + ``decoder``. May contain ``init`` fields specifying weight initialization mode. -Surgery Semantics -================= + Example: ``{"decoder": {"block": {"mixer": {"type": "gdn", "init": "random"}}}}`` -When applying a surgery spec to a complete config: +**Transition Spec (T)** + A complete config WITH ``init`` fields. Describes both the target structure + AND how to initialize weights. This is the output of applying a surgery to + a state - it's a complete specification of the transformation. -**Inheritance** - Unspecified parameters inherit from the source config. New blocks inherit - from the "default" block (first block in pattern, or the single fixed block). + Example: The result of ``compose_configs(state, surgery)`` before stripping. -**Cross-Type Derivation** - When changing mixer types, geometric parameters are derived where possible: - - attention โ†’ sliding_window: preserve heads, head_groups, head_size - - attention โ†’ gdn: heads โ†’ value_heads, head_groups โ†’ key_heads - - attention โ†’ mamba: derive d_inner, d_xb, dt_rank from hidden_size - - attention โ†’ kda: preserve heads, head_size โ†’ head_dim +The distinction between S and T is semantic (presence of ``init``), not structural. +Both are "complete" in the sense of having ``hidden_size`` and ``decoder``. -**Stochastic Mixer Composition** - Two semantics based on whether surgery declares `type: stochastic`: - - Replacement: surgery declares type โ†’ only surgery's sub-mixers included - - Additive: surgery omits type โ†’ source sub-mixers preserved, surgery adds/modifies +Algebraic Structure +=================== - This distinction means the monoid action law holds for additive surgeries but - intentionally fails for replacement surgeries (they have "last-write-wins" semantics). +**Partial Surgeries form a Monoid (P, โˆ˜, {})**:: -The `init` Field -================ + compose_configs : P ร— P โ†’ P (deep merge, overlay wins) + + Identity: compose_configs(p, {}) = compose_configs({}, p) = p + Associativity: compose_configs(compose_configs(a, b), c) + = compose_configs(a, compose_configs(b, c)) + +**Surgeries act on States to produce Transition Specs**:: + + compose_configs : S ร— P โ†’ T (apply surgery with inheritance) + compose_configs : T ร— P โ†’ T (extend transition with more surgery) + +**Action Law (for additive surgeries)**:: + + compose_configs(compose_configs(s, pโ‚), pโ‚‚) = compose_configs(s, compose_configs(pโ‚, pโ‚‚)) + +This law holds when surgeries are "additive" (modifying existing structure without +declaring new types). For "replacement" surgeries (explicitly declaring ``type:``), +the action law intentionally fails - this is last-write-wins semantics. + +**State Extraction**:: + + strip_init_fields : T โ†’ S (remove init metadata for saving) + +Operations Summary +================== + +``compose_configs(base, overlay)`` dispatches based on completeness: + +1. **S ร— P โ†’ T** : Apply surgery to state (inheritance, cross-type derivation) +2. **T ร— P โ†’ T** : Extend transition spec with more surgery +3. **P ร— P โ†’ P** : Merge partial surgeries (monoid operation) +4. **S ร— S โ†’ S** : Merge states (deep merge, rare) +5. 
**P ร— S โ†’ S** : Overlay wins (complete replaces partial) + +``strip_init_fields(config)`` removes all ``init`` fields, converting T โ†’ S. + +Inheritance Semantics +===================== + +When applying a surgery (S ร— P โ†’ T): + +- Unspecified fields inherit from source state +- New decoder blocks inherit from the "default" block +- Cross-type derivation maps geometry (attention.heads โ†’ gdn.value_heads, etc.) +- Stochastic mixers: additive surgery preserves source mixers, replacement replaces + +The ``init`` Field +================== + +The ``init`` field specifies weight initialization mode for ``plan_surgery()``: + +- ``init: transfer`` โ†’ transfer weights from source (possibly with conversion) +- ``init: random`` โ†’ randomly initialize weights + +**Key invariant**: ``init`` is preserved through composition so ``plan_surgery()`` +can read it. Use ``strip_init_fields()`` to obtain a pure state for: + +- Saving to disk (config.json should not contain ``init``) +- Starting the next surgery iteration (current_state should be S, not T) + +Typical Usage Pattern +===================== + +:: + + current_state: S = load_config(...) + + for surgery: P in surgery_chain: + transition: T = compose_configs(current_state, surgery) # S ร— P โ†’ T + plan = plan_surgery(current_state, transition) # plan reads init from T + current_state: S = strip_init_fields(transition) # T โ†’ S for next iteration + + save_config(current_state) # S has no init fields + +Sequential vs Merged Surgery Application +======================================== + +**IMPORTANT**: Applying surgeries sequentially (with stripping) differs from merging +surgeries first then applying once. This affects ``init`` semantics: + +**Sequential** (recommended):: -The `init` field is metadata for the plan builder, NOT for config composition: -- `init: transfer` โ†’ plan builder creates weight transfer mappings -- `init: random` โ†’ plan builder creates random initialization + t1 = compose_configs(s, p1) # GDN gets init: random + s1 = strip_init_fields(t1) # GDN loses init + t2 = compose_configs(s1, p2) # GDN has init: None โ†’ transfer mode -After surgery is applied to produce a complete config, ALL `init` fields are stripped. -This ensures configs are purely structural and plan creation is Markovian (depends only -on current config + surgery, not on history). +**Merged**:: + + merged = compose_configs(p1, p2) # GDN keeps init: random from p1 + t = compose_configs(s, merged) # GDN has init: random โ†’ random mode + +The sequential approach means ``init: random`` applies **only to the surgery that +introduces a component**. Subsequent surgeries transfer existing weights by default. + +This is the intended behavior: if surgery 1 adds GDN with random init, and surgery 2 +adds sliding window (not mentioning GDN), GDN keeps its weights from surgery 1. + +The merged approach would re-randomize GDN in every execution, which is rarely desired. +Always use the sequential pattern shown in "Typical Usage Pattern" above. """ from __future__ import annotations @@ -65,14 +145,42 @@ def is_complete(config: dict) -> bool: def compose_configs(base: dict, overlay: dict | None) -> dict: - """Compose two configs. + """Compose configs. Dispatches based on completeness of arguments. 
+ + Type Signatures (see module docstring for S, P, T definitions):: + + S ร— P โ†’ T Apply surgery to state, get transition spec + T ร— P โ†’ T Extend transition spec with more surgery + P ร— P โ†’ P Merge partial surgeries (monoid operation) + S ร— S โ†’ S Merge states (deep merge) + P ร— S โ†’ S Overlay wins + + The ``init`` field is preserved in all cases. Use ``strip_init_fields()`` + to convert T โ†’ S for saving or iteration. Args: - base: Base config (complete or partial surgery spec). - overlay: Overlay config (complete or partial surgery spec). + base: State (S), transition spec (T), or partial surgery (P). + overlay: Partial surgery (P) or state (S). Returns: - Composed config. + Composed config. Type depends on inputs (see signatures above). + + Algebraic Properties: + Monoid: ``compose(compose(p1, p2), p3) == compose(p1, compose(p2, p3))`` + + Action law (additive surgeries): + ``compose(compose(s, p1), p2) == compose(s, compose(p1, p2))`` + + Example:: + + # S ร— P โ†’ T (apply surgery to state) + state = {"hidden_size": 256, "decoder": {...}} + surgery = {"decoder": {"block": {"mixer": {"init": "random"}}}} + transition = compose_configs(state, surgery) # T, has init + + # Build plan, then extract state + plan = plan_surgery(state, transition) + new_state = strip_init_fields(transition) # S, no init """ if not overlay: return copy.deepcopy(base) @@ -94,9 +202,8 @@ def compose_configs(base: dict, overlay: dict | None) -> dict: if not base_complete and overlay_complete: return copy.deepcopy(overlay) - # Case 4: Both complete -> deep merge + # Case 4: Both complete -> deep merge (init preserved for plan_surgery) result = _deep_merge(base, overlay) - _strip_keys(result, {"init"}) return result @@ -128,26 +235,53 @@ def _strip_keys(config: Any, keys_to_strip: set[str]) -> None: _strip_keys(item, keys_to_strip) +def strip_init_fields(config: dict) -> dict: + """Return a copy of config with all ``init`` fields stripped (T โ†’ S). + + Converts a transition spec (T) to a state (S) by removing ``init`` metadata. + Use this: + + 1. Before saving configs to disk (config.json should be purely structural) + 2. Between surgery iterations (so subsequent surgeries don't re-randomize) + + See module docstring section "Sequential vs Merged Surgery Application" for + why stripping between iterations is critical. + + Args: + config: Config dict (not modified). Typically a transition spec (T). + + Returns: + A deep copy with all ``init`` fields recursively removed (a state S). + """ + result = copy.deepcopy(config) + _strip_keys(result, {"init"}) + return result + + # ============================================================================= # Surgery application with full semantics # ============================================================================= def apply_surgery(source_config: dict, surgery_config: dict | None) -> dict: - """Apply surgery specification to a complete source config. + """Apply surgery spec to complete config (the monoid action). - This handles: - - Top-level scalar overrides - - Decoder composition (fixed vs pattern) - - Stochastic mixer sub-mixer inheritance - - Cross-type derivation (attention โ†’ gdn, attention โ†’ mamba) + This is the internal implementation of the monoid action: surgery specs + acting on complete configs. Called by compose_configs when base is complete + and overlay is partial. + + Implements inheritance semantics: + - Unspecified fields inherit from source + - Cross-type derivation maps geometry (attention โ†’ gdn, etc.) 
+ - Stochastic sub-mixers inherit from source's main mixer + - `init` fields are PRESERVED for plan_surgery() to see Args: - source_config: Complete Apriel2 config. - surgery_config: Partial surgery specification. + source_config: Complete Apriel2 config (the state being acted on). + surgery_config: Partial surgery spec (the monoid element acting). Returns: - Complete Apriel2 config with surgery applied. + Complete config with surgery applied. `init` fields preserved. """ if not surgery_config: return copy.deepcopy(source_config) @@ -189,8 +323,9 @@ def apply_surgery(source_config: dict, surgery_config: dict | None) -> dict: surgery_config["vision_encoder"], ) - # Strip init keys from final result - _strip_keys(result, {"init"}) + # NOTE: We do NOT strip init keys here. The `init` field is preserved through + # composition so that plan_surgery() can see it and decide between transfer + # vs random initialization. The caller (convert.py) strips init before saving. return result @@ -392,6 +527,12 @@ def _compose_single_mixer(source: dict, surgery: dict, hidden_size: int) -> dict result[key] = surgery[key] elif key in source: result[key] = source[key] + # Copy per-layer bias settings (query_layer, key_layer, value_layer, dense_layer) + for key in ["query_layer", "key_layer", "value_layer", "dense_layer", "add_linear_biases"]: + if key in surgery: + result[key] = surgery[key] + elif key in source: + result[key] = copy.deepcopy(source[key]) # Preserve init if "init" in surgery: result["init"] = surgery["init"] diff --git a/fast_llm_external_models/apriel2/conversion/converters.py b/fast_llm_external_models/apriel2/conversion/converters.py index 6d1350c54..9c9238bb0 100644 --- a/fast_llm_external_models/apriel2/conversion/converters.py +++ b/fast_llm_external_models/apriel2/conversion/converters.py @@ -61,16 +61,7 @@ from __future__ import annotations -from fast_llm_external_models.apriel2.conversion.expr import ( - Concat, - Expr, - ExprPlan, - Init, - Ref, - Slice, - W, -) - +from fast_llm_external_models.apriel2.conversion.expr import Concat, Expr, ExprPlan, Init, Ref, Slice, W # ============================================================================= # SECTION 1: Per-Mixer Plan Functions @@ -79,6 +70,21 @@ # This is the single source of truth for each mixer's weight schema. +def _get_attention_bias_enabled(config: dict, layer_name: str) -> bool: + """Get whether bias is enabled for an attention layer. + + Checks per-layer bias config (e.g., query_layer.bias.enabled). + Falls back to add_linear_biases if not set. + """ + layer_cfg = config.get(layer_name, {}) + bias_cfg = layer_cfg.get("bias", {}) + enabled = bias_cfg.get("enabled") + if enabled is not None: + return enabled + # Fall back to add_linear_biases + return config.get("add_linear_biases", False) + + def _plan_attention_mixer( *, prefix: W, @@ -90,9 +96,13 @@ def _plan_attention_mixer( Weight schema: - q_proj.weight: (q_size, hidden_size) + - q_proj.bias: (q_size,) [if query_layer.bias.enabled] - k_proj.weight: (kv_size, hidden_size) + - k_proj.bias: (kv_size,) [if key_layer.bias.enabled] - v_proj.weight: (kv_size, hidden_size) + - v_proj.bias: (kv_size,) [if value_layer.bias.enabled] - o_proj.weight: (hidden_size, q_size) + - o_proj.bias: (hidden_size,) [if dense_layer.bias.enabled] Args: prefix: Target weight path prefix. @@ -100,12 +110,28 @@ def _plan_attention_mixer( hidden_size: Model hidden size. source_prefix: If provided, passthrough from source. If None, random init. 
""" + # Check per-layer bias configuration + q_bias = _get_attention_bias_enabled(config, "query_layer") + k_bias = _get_attention_bias_enabled(config, "key_layer") + v_bias = _get_attention_bias_enabled(config, "value_layer") + o_bias = _get_attention_bias_enabled(config, "dense_layer") + if source_prefix is not None: - # Passthrough - return ExprPlan(mappings={ + # Passthrough weights + mappings: dict[W, Expr] = { prefix / proj / "weight": Ref(key=source_prefix / proj / "weight") for proj in ["q_proj", "k_proj", "v_proj", "o_proj"] - }) + } + # Passthrough biases if enabled + if q_bias: + mappings[prefix / "q_proj" / "bias"] = Ref(key=source_prefix / "q_proj" / "bias") + if k_bias: + mappings[prefix / "k_proj" / "bias"] = Ref(key=source_prefix / "k_proj" / "bias") + if v_bias: + mappings[prefix / "v_proj" / "bias"] = Ref(key=source_prefix / "v_proj" / "bias") + if o_bias: + mappings[prefix / "o_proj" / "bias"] = Ref(key=source_prefix / "o_proj" / "bias") + return ExprPlan(mappings=mappings) # Random init heads = config["heads"] @@ -114,12 +140,22 @@ def _plan_attention_mixer( q_size = heads * head_size kv_size = head_groups * head_size - return ExprPlan(mappings={ + mappings = { prefix / "q_proj" / "weight": Init(shape=(q_size, hidden_size), init_type="kaiming"), prefix / "k_proj" / "weight": Init(shape=(kv_size, hidden_size), init_type="kaiming"), prefix / "v_proj" / "weight": Init(shape=(kv_size, hidden_size), init_type="kaiming"), prefix / "o_proj" / "weight": Init(shape=(hidden_size, q_size), init_type="kaiming"), - }) + } + # Random init biases if enabled + if q_bias: + mappings[prefix / "q_proj" / "bias"] = Init(shape=(q_size,), init_type="zeros") + if k_bias: + mappings[prefix / "k_proj" / "bias"] = Init(shape=(kv_size,), init_type="zeros") + if v_bias: + mappings[prefix / "v_proj" / "bias"] = Init(shape=(kv_size,), init_type="zeros") + if o_bias: + mappings[prefix / "o_proj" / "bias"] = Init(shape=(hidden_size,), init_type="zeros") + return ExprPlan(mappings=mappings) def _plan_mamba_mixer( @@ -150,20 +186,22 @@ def _plan_mamba_mixer( """ if source_prefix is not None: # Passthrough - include all possible weights - return ExprPlan(mappings={ - prefix / name: Ref(key=source_prefix / name) - for name in [ - "in_proj.weight", - "out_proj.weight", - "dt_in_proj.weight", - "dt_proj.weight", - "dt_proj.bias", - "conv1d.weight", - "conv1d.bias", - "A_log", - "D", - ] - }) + return ExprPlan( + mappings={ + prefix / name: Ref(key=source_prefix / name) + for name in [ + "in_proj.weight", + "out_proj.weight", + "dt_in_proj.weight", + "dt_proj.weight", + "dt_proj.bias", + "conv1d.weight", + "conv1d.bias", + "A_log", + "D", + ] + } + ) # Random init d_inner = config["d_inner"] @@ -181,9 +219,7 @@ def _plan_mamba_mixer( conv_channels = d_inner if repeat_kv_before_conv else d_xb mappings: dict[W, Expr] = { - prefix / "in_proj" / "weight": Init( - shape=(2 * d_inner + 2 * d_xb, hidden_size), init_type="kaiming" - ), + prefix / "in_proj" / "weight": Init(shape=(2 * d_inner + 2 * d_xb, hidden_size), init_type="kaiming"), prefix / "out_proj" / "weight": Init(shape=(hidden_size, d_inner), init_type="kaiming"), prefix / "dt_in_proj" / "weight": Init(shape=(dt_rank, hidden_size), init_type="kaiming"), prefix / "dt_proj" / "weight": Init(shape=(d_inner, dt_rank), init_type="kaiming"), @@ -230,18 +266,20 @@ def _plan_gdn_mixer( """ if source_prefix is not None: # Passthrough - return ExprPlan(mappings={ - prefix / name: Ref(key=source_prefix / name) - for name in [ - "in_proj_qkvz.weight", - 
"in_proj_ba.weight", - "out_proj.weight", - "convolution.weight", - "A_log", - "dt_bias", - "norm.weight", - ] - }) + return ExprPlan( + mappings={ + prefix / name: Ref(key=source_prefix / name) + for name in [ + "in_proj_qkvz.weight", + "in_proj_ba.weight", + "out_proj.weight", + "convolution.weight", + "A_log", + "dt_bias", + "norm.weight", + ] + } + ) # Random init num_v_heads = config["value_heads"] @@ -255,17 +293,19 @@ def _plan_gdn_mixer( conv_dim = key_dim * 2 + value_dim qkvz_size = key_dim * 2 + value_dim * 2 # Q, K both key_dim; V, Z both value_dim - return ExprPlan(mappings={ - prefix / "in_proj_qkvz" / "weight": Init(shape=(qkvz_size, hidden_size), init_type="kaiming"), - prefix / "in_proj_ba" / "weight": Init(shape=(num_v_heads * 2, hidden_size), init_type="zeros"), - prefix / "out_proj" / "weight": Init(shape=(hidden_size, value_dim), init_type="kaiming"), - prefix / "convolution" / "weight": Init( - shape=(conv_dim, 1, conv_kernel_size), init_type="scaled_identity_conv" - ), - prefix / "A_log": Init(shape=(num_v_heads,), init_type="slow_decay"), - prefix / "dt_bias": Init(shape=(num_v_heads,), init_type="zeros"), - prefix / "norm" / "weight": Init(shape=(head_v_dim,), init_type="ones"), - }) + return ExprPlan( + mappings={ + prefix / "in_proj_qkvz" / "weight": Init(shape=(qkvz_size, hidden_size), init_type="kaiming"), + prefix / "in_proj_ba" / "weight": Init(shape=(num_v_heads * 2, hidden_size), init_type="zeros"), + prefix / "out_proj" / "weight": Init(shape=(hidden_size, value_dim), init_type="kaiming"), + prefix + / "convolution" + / "weight": Init(shape=(conv_dim, 1, conv_kernel_size), init_type="scaled_identity_conv"), + prefix / "A_log": Init(shape=(num_v_heads,), init_type="slow_decay"), + prefix / "dt_bias": Init(shape=(num_v_heads,), init_type="zeros"), + prefix / "norm" / "weight": Init(shape=(head_v_dim,), init_type="ones"), + } + ) def _plan_kda_mixer( @@ -298,26 +338,28 @@ def _plan_kda_mixer( """ if source_prefix is not None: # Passthrough - return ExprPlan(mappings={ - prefix / name: Ref(key=source_prefix / name) - for name in [ - "q_proj.weight", - "k_proj.weight", - "v_proj.weight", - "o_proj.weight", - "q_conv.weight", - "k_conv.weight", - "v_conv.weight", - "f_a_proj.weight", - "f_b_proj.weight", - "g_a_proj.weight", - "g_b_proj.weight", - "beta_proj.weight", - "A_log", - "dt_bias", - "norm.weight", - ] - }) + return ExprPlan( + mappings={ + prefix / name: Ref(key=source_prefix / name) + for name in [ + "q_proj.weight", + "k_proj.weight", + "v_proj.weight", + "o_proj.weight", + "q_conv.weight", + "k_conv.weight", + "v_conv.weight", + "f_a_proj.weight", + "f_b_proj.weight", + "g_a_proj.weight", + "g_b_proj.weight", + "beta_proj.weight", + "A_log", + "dt_bias", + "norm.weight", + ] + } + ) # Random init num_heads = config["heads"] @@ -325,36 +367,38 @@ def _plan_kda_mixer( projection_size = num_heads * head_dim conv_kernel_size = config.get("convolution_layer", {}).get("kernel_size", 4) - return ExprPlan(mappings={ - # Main projections - prefix / "q_proj" / "weight": Init(shape=(projection_size, hidden_size), init_type="kaiming"), - prefix / "k_proj" / "weight": Init(shape=(projection_size, hidden_size), init_type="kaiming"), - prefix / "v_proj" / "weight": Init(shape=(projection_size, hidden_size), init_type="kaiming"), - prefix / "o_proj" / "weight": Init(shape=(hidden_size, projection_size), init_type="kaiming"), - # Convolutions - prefix / "q_conv" / "weight": Init( - shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv" - ), - 
prefix / "k_conv" / "weight": Init( - shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv" - ), - prefix / "v_conv" / "weight": Init( - shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv" - ), - # Gate kernels (low-rank factorization) - prefix / "f_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"), - prefix / "f_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"), - # Output gate (low-rank factorization) - prefix / "g_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"), - prefix / "g_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"), - # Beta projection - prefix / "beta_proj" / "weight": Init(shape=(num_heads, hidden_size), init_type="kaiming"), - # Learnable parameters - prefix / "A_log": Init(shape=(num_heads,), init_type="slow_decay"), - prefix / "dt_bias": Init(shape=(projection_size,), init_type="zeros"), - # Normalization - prefix / "norm" / "weight": Init(shape=(head_dim,), init_type="ones"), - }) + return ExprPlan( + mappings={ + # Main projections + prefix / "q_proj" / "weight": Init(shape=(projection_size, hidden_size), init_type="kaiming"), + prefix / "k_proj" / "weight": Init(shape=(projection_size, hidden_size), init_type="kaiming"), + prefix / "v_proj" / "weight": Init(shape=(projection_size, hidden_size), init_type="kaiming"), + prefix / "o_proj" / "weight": Init(shape=(hidden_size, projection_size), init_type="kaiming"), + # Convolutions + prefix + / "q_conv" + / "weight": Init(shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"), + prefix + / "k_conv" + / "weight": Init(shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"), + prefix + / "v_conv" + / "weight": Init(shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"), + # Gate kernels (low-rank factorization) + prefix / "f_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"), + prefix / "f_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"), + # Output gate (low-rank factorization) + prefix / "g_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"), + prefix / "g_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"), + # Beta projection + prefix / "beta_proj" / "weight": Init(shape=(num_heads, hidden_size), init_type="kaiming"), + # Learnable parameters + prefix / "A_log": Init(shape=(num_heads,), init_type="slow_decay"), + prefix / "dt_bias": Init(shape=(projection_size,), init_type="zeros"), + # Normalization + prefix / "norm" / "weight": Init(shape=(head_dim,), init_type="ones"), + } + ) # Dispatcher for per-mixer plan functions @@ -409,16 +453,13 @@ def plan_mil_attention_to_mamba( exprs=( Init(shape=(d_inner, hidden_size), init_type="kaiming"), # z: random Slice( - expr=Ref(key=source_prefix / "v_proj" / "weight"), - slices=((0, d_xb, None), (None, None, None)) + expr=Ref(key=source_prefix / "v_proj" / "weight"), slices=((0, d_xb, None), (None, None, None)) ), # x <- V Slice( - expr=Ref(key=source_prefix / "k_proj" / "weight"), - slices=((0, d_xb, None), (None, None, None)) + expr=Ref(key=source_prefix / "k_proj" / "weight"), slices=((0, d_xb, None), (None, None, None)) ), # B <- K Slice( - expr=Ref(key=source_prefix / "q_proj" / "weight"), - slices=((0, d_inner, None), (None, None, None)) + expr=Ref(key=source_prefix / "q_proj" / "weight"), slices=((0, d_inner, None), (None, 
None, None)) ), # C <- Q ), dim=0, @@ -532,19 +573,21 @@ def plan_dil_attention_to_gdn( dim=0, ) - return ExprPlan(mappings={ - target_prefix / "in_proj_qkvz" / "weight": in_proj_qkvz_expr, - target_prefix / "in_proj_ba" / "weight": Init( - shape=(2 * num_v_heads, hidden_size), init_type="zeros" - ), # b=a=0 โ†’ ฮฒ=0.5 - target_prefix / "out_proj" / "weight": Ref(key=source_prefix / "o_proj" / "weight"), - target_prefix / "convolution" / "weight": Init( - shape=(conv_dim, 1, conv_kernel_size), init_type="scaled_identity_conv" - ), - target_prefix / "A_log": Init(shape=(num_v_heads,), init_type="slow_decay"), - target_prefix / "dt_bias": Init(shape=(num_v_heads,), init_type="zeros"), - target_prefix / "norm" / "weight": Init(shape=(head_v_dim,), init_type="ones"), - }) + return ExprPlan( + mappings={ + target_prefix / "in_proj_qkvz" / "weight": in_proj_qkvz_expr, + target_prefix + / "in_proj_ba" + / "weight": Init(shape=(2 * num_v_heads, hidden_size), init_type="zeros"), # b=a=0 โ†’ ฮฒ=0.5 + target_prefix / "out_proj" / "weight": Ref(key=source_prefix / "o_proj" / "weight"), + target_prefix + / "convolution" + / "weight": Init(shape=(conv_dim, 1, conv_kernel_size), init_type="scaled_identity_conv"), + target_prefix / "A_log": Init(shape=(num_v_heads,), init_type="slow_decay"), + target_prefix / "dt_bias": Init(shape=(num_v_heads,), init_type="zeros"), + target_prefix / "norm" / "weight": Init(shape=(head_v_dim,), init_type="ones"), + } + ) def plan_kil_attention_to_kda( @@ -595,9 +638,7 @@ def plan_kil_attention_to_kda( for h in range(num_heads): src_h = h % source_num_q_heads row_start = src_h * source_head_dim - q_slices.append( - Slice(expr=q_ref, slices=((row_start, row_start + head_dim, None), (None, None, None))) - ) + q_slices.append(Slice(expr=q_ref, slices=((row_start, row_start + head_dim, None), (None, None, None)))) q_expr = Concat(exprs=tuple(q_slices), dim=0) # K: tile source KV heads to fill target projection_size @@ -608,9 +649,7 @@ def plan_kil_attention_to_kda( for h in range(num_heads): src_h = h % source_num_kv_heads row_start = src_h * source_head_dim - k_slices.append( - Slice(expr=k_ref, slices=((row_start, row_start + head_dim, None), (None, None, None))) - ) + k_slices.append(Slice(expr=k_ref, slices=((row_start, row_start + head_dim, None), (None, None, None)))) k_expr = Concat(exprs=tuple(k_slices), dim=0) # V: tile source KV heads to fill target projection_size @@ -621,41 +660,41 @@ def plan_kil_attention_to_kda( for h in range(num_heads): src_h = h % source_num_kv_heads row_start = src_h * source_head_dim - v_slices.append( - Slice(expr=v_ref, slices=((row_start, row_start + head_dim, None), (None, None, None))) - ) + v_slices.append(Slice(expr=v_ref, slices=((row_start, row_start + head_dim, None), (None, None, None)))) v_expr = Concat(exprs=tuple(v_slices), dim=0) - return ExprPlan(mappings={ - # Transfer main projections - target_prefix / "q_proj" / "weight": q_expr, - target_prefix / "k_proj" / "weight": k_expr, - target_prefix / "v_proj" / "weight": v_expr, - target_prefix / "o_proj" / "weight": Ref(key=source_prefix / "o_proj" / "weight"), - # Random init: convolutions (scaled identity for near-passthrough initially) - target_prefix / "q_conv" / "weight": Init( - shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv" - ), - target_prefix / "k_conv" / "weight": Init( - shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv" - ), - target_prefix / "v_conv" / "weight": Init( - shape=(projection_size, 1, 
conv_kernel_size), init_type="scaled_identity_conv" - ), - # Random init: gate kernels (low-rank factorization) - target_prefix / "f_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"), - target_prefix / "f_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"), - # Random init: output gate (low-rank factorization) - target_prefix / "g_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"), - target_prefix / "g_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"), - # Random init: beta projection - target_prefix / "beta_proj" / "weight": Init(shape=(num_heads, hidden_size), init_type="kaiming"), - # Random init: learnable parameters - target_prefix / "A_log": Init(shape=(num_heads,), init_type="slow_decay"), - target_prefix / "dt_bias": Init(shape=(projection_size,), init_type="zeros"), - # Random init: normalization - target_prefix / "norm" / "weight": Init(shape=(head_dim,), init_type="ones"), - }) + return ExprPlan( + mappings={ + # Transfer main projections + target_prefix / "q_proj" / "weight": q_expr, + target_prefix / "k_proj" / "weight": k_expr, + target_prefix / "v_proj" / "weight": v_expr, + target_prefix / "o_proj" / "weight": Ref(key=source_prefix / "o_proj" / "weight"), + # Random init: convolutions (scaled identity for near-passthrough initially) + target_prefix + / "q_conv" + / "weight": Init(shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"), + target_prefix + / "k_conv" + / "weight": Init(shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"), + target_prefix + / "v_conv" + / "weight": Init(shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"), + # Random init: gate kernels (low-rank factorization) + target_prefix / "f_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"), + target_prefix / "f_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"), + # Random init: output gate (low-rank factorization) + target_prefix / "g_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"), + target_prefix / "g_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"), + # Random init: beta projection + target_prefix / "beta_proj" / "weight": Init(shape=(num_heads, hidden_size), init_type="kaiming"), + # Random init: learnable parameters + target_prefix / "A_log": Init(shape=(num_heads,), init_type="slow_decay"), + target_prefix / "dt_bias": Init(shape=(projection_size,), init_type="zeros"), + # Random init: normalization + target_prefix / "norm" / "weight": Init(shape=(head_dim,), init_type="ones"), + } + ) # ============================================================================= @@ -786,7 +825,70 @@ def plan_surgery( source_config: dict, target_config: dict, ) -> ExprPlan: - """Build plan for Apriel2โ†’Apriel2 surgery (MIL, DIL, KIL, stochastic mixers, etc.).""" + """Build a weight conversion plan: S ร— T โ†’ Plan. + + Creates an ExprPlan mapping target weight keys to expressions over source weights. + Handles same-type passthrough, cross-type conversions (MIL, DIL, KIL), and + stochastic mixer routing. + + Type Signature:: + + plan_surgery : S ร— T โ†’ Plan + + Where S is a state (source) and T is a transition spec (target with ``init`` fields). 
+ + The ``init`` Field + ------------------ + + The ``init`` field in ``target_config`` controls weight initialization: + + - ``init: transfer`` (or absent) โ†’ create Ref expressions (transfer from source) + - ``init: random`` โ†’ create Init expressions (random initialization) + + This is why ``target_config`` should be a transition spec (T) from ``compose_configs``, + not a stripped state (S). If ``init`` fields are missing, all components default to + transfer mode. + + Args: + source_config: State (S) - complete config describing source architecture. + Must have hidden_size, decoder, etc. No ``init`` fields expected. + target_config: Transition spec (T) - complete config with ``init`` fields. + Use ``compose_configs(source, surgery)`` to produce this. + + Returns: + ExprPlan mapping target weight keys to expressions over source weights. + + Example:: + + # Apply a surgery that wraps attention in a stochastic mixer + surgery_spec = { + "decoder": {"block": {"mixer": { + "type": "stochastic", + "mixers": { + "attention": {"init": "transfer"}, + "gdn": {"type": "gdn", "init": "random"}, + } + }}} + } + + # S ร— P โ†’ T + transition = compose_configs(source_config, surgery_spec) + + # S ร— T โ†’ Plan + plan = plan_surgery(source_config, transition) + + # Execute + new_weights = execute(plan, source_weights, seed=42) + + # T โ†’ S for saving + target_state = strip_init_fields(transition) + + Note: + Both arguments must be complete (have hidden_size and decoder). + The target_config should retain ``init`` fields from the surgery spec. + Passing a stripped state as target will cause all components to use + transfer mode, which may not be intended. + """ hidden_size = target_config.get("hidden_size", source_config.get("hidden_size")) assert hidden_size is not None, "hidden_size must be specified in source or target config" @@ -804,18 +906,24 @@ def plan_surgery( target_block = _get_block_config(target_decoder, target_layer_idx) plan += _plan_mixer( - target_layer_idx, source_layer_idx, - source_block.get("mixer", {}), target_block.get("mixer", {}), + target_layer_idx, + source_layer_idx, + source_block.get("mixer", {}), + target_block.get("mixer", {}), hidden_size, ) plan += _plan_mlp( - target_layer_idx, source_layer_idx, - source_block.get("mlp", {}), target_block.get("mlp", {}), + target_layer_idx, + source_layer_idx, + source_block.get("mlp", {}), + target_block.get("mlp", {}), hidden_size, ) plan += _plan_norms( - target_layer_idx, source_layer_idx, - source_block, target_block, + target_layer_idx, + source_layer_idx, + source_block, + target_block, hidden_size, ) @@ -839,14 +947,16 @@ def _plan_non_decoder_weights(config: dict) -> ExprPlan: embed = W("model", "embed_tokens", "weight") mappings[embed] = Ref(key=embed) - head = W("lm_head", "weight") - mappings[head] = Ref(key=head) + # lm_head only if not tied to embeddings + if not config.get("tie_word_embeddings", False): + head = W("lm_head", "weight") + mappings[head] = Ref(key=head) norm = W("model", "norm", "weight") mappings[norm] = Ref(key=norm) - if "vision_encoder" in config: - vision_config = config["vision_encoder"] + vision_config = config.get("vision_encoder") + if vision_config: vision = W("model", "vision_encoder") patch_emb = vision / "embeddings" / "patch_embeddings" / "weight" @@ -950,9 +1060,13 @@ def _plan_mixer( source_prefix = source_mixer_base plan += _plan_mixer_transfer( - matched_source_type, sub_type, - matched_source, sub_config, - source_prefix, target_prefix, hidden_size, + matched_source_type, + sub_type, 
+ matched_source, + sub_config, + source_prefix, + target_prefix, + hidden_size, ) # Passthrough source sub-mixers not in target spec @@ -963,8 +1077,13 @@ def _plan_mixer( source_prefix = source_layer / "mixer" / "mixers" / sub_name target_prefix = target_layer / "mixer" / "mixers" / sub_name plan += _plan_mixer_transfer( - sub_type, sub_type, sub_config, sub_config, - source_prefix, target_prefix, hidden_size, + sub_type, + sub_type, + sub_config, + sub_config, + source_prefix, + target_prefix, + hidden_size, ) return plan @@ -980,12 +1099,34 @@ def _plan_mixer( source_prefix = source_layer / "mixer" return _plan_mixer_transfer( - main_source_type, target_type, - main_source, target_mixer, - source_prefix, target_prefix, hidden_size, + main_source_type, + target_type, + main_source, + target_mixer, + source_prefix, + target_prefix, + hidden_size, ) +def _get_mlp_bias_enabled(config: dict, layer_name: str) -> bool: + """Get whether bias is enabled for an MLP layer. + + Checks per-layer bias config (e.g., layer_1.bias.enabled, layer_2.bias.enabled). + Falls back to add_linear_biases if not set. + + Note: layer_1 corresponds to gate_proj and up_proj (gated MLP) or just up_proj (non-gated) + layer_2 corresponds to down_proj + """ + layer_cfg = config.get(layer_name, {}) + bias_cfg = layer_cfg.get("bias", {}) + enabled = bias_cfg.get("enabled") + if enabled is not None: + return enabled + # Fall back to add_linear_biases + return config.get("add_linear_biases", False) + + def _plan_mlp( target_layer_idx: int, source_layer_idx: int, @@ -1006,7 +1147,7 @@ def _plan_mlp_transfer( target_mlp: dict, hidden_size: int, ) -> ExprPlan: - """Passthrough for MLP weights.""" + """Passthrough for MLP weights and biases.""" source_mlp_path = W("model", "decoder", "blocks", source_layer_idx, "mlp") target_mlp_path = W("model", "decoder", "blocks", target_layer_idx, "mlp") @@ -1019,10 +1160,36 @@ def _plan_mlp_transfer( f"Use 'init: random' to initialize randomly." 
) - return ExprPlan(mappings={ - target_mlp_path / proj / "weight": Ref(key=source_mlp_path / proj / "weight") - for proj in ["gate_proj", "up_proj", "down_proj"] - }) + # Check per-layer bias configuration + layer_1_bias = _get_mlp_bias_enabled(target_mlp, "layer_1") + layer_2_bias = _get_mlp_bias_enabled(target_mlp, "layer_2") + + # Check if gated MLP (has gate_proj) or non-gated (only up_proj) + gated = target_mlp.get("gated", True) # Default to gated for backwards compatibility + + # Passthrough weights + # layer_1 = gate_proj + up_proj (gated) or just up_proj (non-gated) + # layer_2 = down_proj + if gated: + weight_projs = ["gate_proj", "up_proj", "down_proj"] + else: + weight_projs = ["up_proj", "down_proj"] + + mappings: dict[W, Expr] = { + target_mlp_path / proj / "weight": Ref(key=source_mlp_path / proj / "weight") for proj in weight_projs + } + + # Passthrough biases if enabled + if layer_1_bias: + if gated: + mappings[target_mlp_path / "gate_proj" / "bias"] = Ref(key=source_mlp_path / "gate_proj" / "bias") + mappings[target_mlp_path / "up_proj" / "bias"] = Ref(key=source_mlp_path / "up_proj" / "bias") + + # layer_2 = down_proj + if layer_2_bias: + mappings[target_mlp_path / "down_proj" / "bias"] = Ref(key=source_mlp_path / "down_proj" / "bias") + + return ExprPlan(mappings=mappings) def _plan_random_mlp( @@ -1030,20 +1197,41 @@ def _plan_random_mlp( target_mlp: dict, hidden_size: int, ) -> ExprPlan: - """Random initialization for MLP.""" + """Random initialization for MLP weights and biases.""" target_mlp_path = W("model", "decoder", "blocks", target_layer_idx, "mlp") intermediate_size = target_mlp["intermediate_size"] - return ExprPlan(mappings={ - target_mlp_path / "gate_proj" / "weight": Init( - shape=(intermediate_size, hidden_size), init_type="kaiming" - ), - target_mlp_path / "up_proj" / "weight": Init( + + # Check per-layer bias configuration + layer_1_bias = _get_mlp_bias_enabled(target_mlp, "layer_1") + layer_2_bias = _get_mlp_bias_enabled(target_mlp, "layer_2") + + # Check if gated MLP (has gate_proj) or non-gated (only up_proj) + gated = target_mlp.get("gated", True) # Default to gated for backwards compatibility + + # Random init weights + mappings: dict[W, Expr] = {} + if gated: + mappings[target_mlp_path / "gate_proj" / "weight"] = Init( shape=(intermediate_size, hidden_size), init_type="kaiming" - ), - target_mlp_path / "down_proj" / "weight": Init( - shape=(hidden_size, intermediate_size), init_type="kaiming" - ), - }) + ) + mappings[target_mlp_path / "up_proj" / "weight"] = Init( + shape=(intermediate_size, hidden_size), init_type="kaiming" + ) + mappings[target_mlp_path / "down_proj" / "weight"] = Init( + shape=(hidden_size, intermediate_size), init_type="kaiming" + ) + + # Random init biases if enabled + if layer_1_bias: + if gated: + mappings[target_mlp_path / "gate_proj" / "bias"] = Init(shape=(intermediate_size,), init_type="zeros") + mappings[target_mlp_path / "up_proj" / "bias"] = Init(shape=(intermediate_size,), init_type="zeros") + + # layer_2 = down_proj + if layer_2_bias: + mappings[target_mlp_path / "down_proj" / "bias"] = Init(shape=(hidden_size,), init_type="zeros") + + return ExprPlan(mappings=mappings) def _plan_norms( @@ -1083,10 +1271,12 @@ def _plan_norms_transfer( f"Use 'init: random' to initialize randomly." 
) - return ExprPlan(mappings={ - target_layer / norm_name / "weight": Ref(key=source_layer / norm_name / "weight") - for norm_name in ["input_layernorm", "post_attention_layernorm"] - }) + return ExprPlan( + mappings={ + target_layer / norm_name / "weight": Ref(key=source_layer / norm_name / "weight") + for norm_name in ["input_layernorm", "post_attention_layernorm"] + } + ) def _plan_random_norms( @@ -1095,7 +1285,9 @@ def _plan_random_norms( ) -> ExprPlan: """Random initialization for normalization layers.""" target_layer = W("model", "decoder", "blocks", target_layer_idx) - return ExprPlan(mappings={ - target_layer / norm_name / "weight": Init(shape=(hidden_size,), init_type="ones") - for norm_name in ["input_layernorm", "post_attention_layernorm"] - }) + return ExprPlan( + mappings={ + target_layer / norm_name / "weight": Init(shape=(hidden_size,), init_type="ones") + for norm_name in ["input_layernorm", "post_attention_layernorm"] + } + ) diff --git a/fast_llm_external_models/apriel2/conversion/executor.py b/fast_llm_external_models/apriel2/conversion/executor.py index a6c5672f0..b0779c97f 100644 --- a/fast_llm_external_models/apriel2/conversion/executor.py +++ b/fast_llm_external_models/apriel2/conversion/executor.py @@ -29,7 +29,8 @@ from __future__ import annotations import hashlib -from typing import Callable, Iterator +from collections.abc import Iterator +from typing import Callable import torch from torch import Tensor @@ -81,8 +82,7 @@ def execute( break else: raise ValueError( - "Cannot infer device/dtype: plan has no source references. " - "Provide device and dtype explicitly." + "Cannot infer device/dtype: plan has no source references. " "Provide device and dtype explicitly." ) generator = torch.Generator(device=device) @@ -94,10 +94,7 @@ def execute( # Verify device/dtype consistency for key, tensor in sources.items(): if tensor.device != device or tensor.dtype != dtype: - raise ValueError( - f"Source {key} has {tensor.device}/{tensor.dtype}, " - f"expected {device}/{dtype}" - ) + raise ValueError(f"Source {key} has {tensor.device}/{tensor.dtype}, " f"expected {device}/{dtype}") # Deterministic per-target seed key_offset = int(hashlib.md5(str(target_key).encode()).hexdigest()[:8], 16) diff --git a/fast_llm_external_models/apriel2/conversion/expr.py b/fast_llm_external_models/apriel2/conversion/expr.py index 4867a27ae..34ea106fc 100644 --- a/fast_llm_external_models/apriel2/conversion/expr.py +++ b/fast_llm_external_models/apriel2/conversion/expr.py @@ -52,7 +52,8 @@ import math from collections import defaultdict -from typing import Annotated, Any, Callable, Iterator, Literal, TypedDict, Union, Unpack +from collections.abc import Iterator +from typing import Annotated, Any, Callable, Literal, TypedDict, Union, Unpack import torch from pydantic import BaseModel, ConfigDict, Field, GetCoreSchemaHandler, TypeAdapter @@ -60,7 +61,6 @@ from pydantic_core import CoreSchema, core_schema from torch import Tensor - # ============================================================================= # Weight Path Builder # ============================================================================= @@ -78,7 +78,7 @@ class W(str): mappings[q] = Ref(key=source_q) """ - def __new__(cls, *parts) -> "W": + def __new__(cls, *parts) -> W: # Join parts, stripping any leading/trailing dots from each cleaned = [] for p in parts: @@ -89,12 +89,12 @@ def __new__(cls, *parts) -> "W": cleaned.append(s) return super().__new__(cls, ".".join(cleaned)) - def __truediv__(self, other) -> "W": + def 
__truediv__(self, other) -> W: if isinstance(other, (list, tuple)): return W(self, *other) return W(self, other) - def __rtruediv__(self, other) -> "W": + def __rtruediv__(self, other) -> W: return W(other, self) @classmethod @@ -156,7 +156,7 @@ class Slice(BaseModel): model_config = ConfigDict(frozen=True) type: Literal["slice"] = "slice" - expr: "Expr" + expr: Expr slices: tuple[tuple[int | None, int | None, int | None], ...] def find_refs(self) -> set[W]: @@ -184,7 +184,7 @@ class Concat(BaseModel): model_config = ConfigDict(frozen=True) type: Literal["concat"] = "concat" - exprs: tuple["Expr", ...] + exprs: tuple[Expr, ...] dim: int = 0 def find_refs(self) -> set[W]: @@ -303,7 +303,7 @@ class Reshape(BaseModel): model_config = ConfigDict(frozen=True) type: Literal["reshape"] = "reshape" - expr: "Expr" + expr: Expr shape: tuple[int, ...] def find_refs(self) -> set[W]: @@ -442,10 +442,10 @@ def __getitem__(self, key: W) -> Expr: def __contains__(self, key: W) -> bool: return key in self.mappings - def __or__(self, other: "ExprPlan") -> "ExprPlan": + def __or__(self, other: ExprPlan) -> ExprPlan: return compose(self, other) - def __add__(self, other: "ExprPlan") -> "ExprPlan": + def __add__(self, other: ExprPlan) -> ExprPlan: return merge(self, other) def source_keys(self) -> set[str]: @@ -471,7 +471,7 @@ def summary(self) -> dict[str, Any]: "metadata": self.metadata, } - def fuse(self) -> "ExprPlan": + def fuse(self) -> ExprPlan: return ExprPlan( mappings={k: fuse(v) for k, v in self.mappings.items()}, source_format=self.source_format, diff --git a/fast_llm_external_models/apriel2/conversion/io.py b/fast_llm_external_models/apriel2/conversion/io.py index e1a261d7e..1f64df0b9 100644 --- a/fast_llm_external_models/apriel2/conversion/io.py +++ b/fast_llm_external_models/apriel2/conversion/io.py @@ -62,7 +62,7 @@ def __init__(self, files: list[Path], device: str = "cpu"): self._handles: dict[Path, Any] = {} self._key_index: dict[str, Path] = {} - def __enter__(self) -> "SafetensorLoader": + def __enter__(self) -> SafetensorLoader: # Pre-build index: key -> file (one-time O(nร—m), then O(1) lookups) for f in self.files: handle = safe_open(f, framework="pt", device=self.device) @@ -128,7 +128,7 @@ def __init__( self._finalized: bool = False self._result_path: Path | None = None - def __enter__(self) -> "ShardedSafetensorWriter": + def __enter__(self) -> ShardedSafetensorWriter: return self def __exit__(self, exc_type, exc_val, exc_tb) -> None: @@ -180,8 +180,7 @@ def _flush(self) -> None: shard_file = self.output_dir / f"{self.base_name}-{self._shard_index:05d}.safetensors.tmp" logger.debug( - f"Writing shard {self._shard_index}: {len(self._buffer)} tensors, " - f"{self._buffer_bytes / 1e9:.2f} GB" + f"Writing shard {self._shard_index}: {len(self._buffer)} tensors, " f"{self._buffer_bytes / 1e9:.2f} GB" ) save_file(self._buffer, shard_file) self._shard_files.append(shard_file) diff --git a/fast_llm_external_models/apriel2/conversion/llava/plan.py b/fast_llm_external_models/apriel2/conversion/llava/plan.py index df485efbd..a97e46c1a 100644 --- a/fast_llm_external_models/apriel2/conversion/llava/plan.py +++ b/fast_llm_external_models/apriel2/conversion/llava/plan.py @@ -1,11 +1,6 @@ """Llava to Apriel2 weight conversion plan.""" -from fast_llm_external_models.apriel2.conversion.expr import ( - Expr, - ExprPlan, - Ref, - W, -) +from fast_llm_external_models.apriel2.conversion.expr import Expr, ExprPlan, Ref, W def plan_llava_to_apriel2(llava_config: dict) -> ExprPlan: diff --git 
a/fast_llm_external_models/apriel2/conversion/qwen2/__init__.py b/fast_llm_external_models/apriel2/conversion/qwen2/__init__.py new file mode 100644 index 000000000..d0a0b8e6e --- /dev/null +++ b/fast_llm_external_models/apriel2/conversion/qwen2/__init__.py @@ -0,0 +1,6 @@ +"""Qwen2/Qwen2.5 to Apriel2 conversion module.""" + +from fast_llm_external_models.apriel2.conversion.qwen2.config import convert_config +from fast_llm_external_models.apriel2.conversion.qwen2.plan import plan_qwen2_to_apriel2 + +__all__ = ["convert_config", "plan_qwen2_to_apriel2"] diff --git a/fast_llm_external_models/apriel2/conversion/qwen2/config.py b/fast_llm_external_models/apriel2/conversion/qwen2/config.py new file mode 100644 index 000000000..70629fe0e --- /dev/null +++ b/fast_llm_external_models/apriel2/conversion/qwen2/config.py @@ -0,0 +1,79 @@ +"""Qwen2/Qwen2.5 to Apriel2 config conversion.""" + + +def convert_config(qwen2_config: dict) -> dict: + """Convert Qwen2/Qwen2.5 config to Apriel2TextConfig format. + + Qwen2.5 architecture: + - Standard transformer with GQA (grouped query attention) + - QKV bias enabled, O bias disabled + - MLP bias disabled + - Gated SwiGLU MLP + - RMSNorm + - RoPE embeddings + + Args: + qwen2_config: HuggingFace Qwen2Config as dict + + Returns: + Apriel2TextConfig-compatible dict + """ + hidden_size = qwen2_config["hidden_size"] + num_attention_heads = qwen2_config["num_attention_heads"] + num_key_value_heads = qwen2_config.get("num_key_value_heads", num_attention_heads) + head_dim = hidden_size // num_attention_heads + + # Qwen2 uses QKV bias but not O bias - mirror Fast-LLM's per-layer config + return { + "model_type": "apriel2_text", + "architectures": ["Apriel2ForCausalLM"], + "auto_map": { + "AutoConfig": "configuration_apriel2.Apriel2TextConfig", + "AutoModel": "modeling_apriel2.Apriel2TextModel", + "AutoModelForCausalLM": "modeling_apriel2.Apriel2ForCausalLM", + }, + "hidden_size": hidden_size, + "vocab_size": qwen2_config["vocab_size"], + "tie_word_embeddings": qwen2_config.get("tie_word_embeddings", False), + "decoder": { + "type": "fixed", + "num_blocks": qwen2_config["num_hidden_layers"], + "block": { + "mixer": { + "type": "attention", + "heads": num_attention_heads, + "head_groups": num_key_value_heads, + "head_size": head_dim, + # Per-layer bias config matching Fast-LLM structure + "query_layer": {"bias": {"enabled": True}}, + "key_layer": {"bias": {"enabled": True}}, + "value_layer": {"bias": {"enabled": True}}, + "dense_layer": {"bias": {"enabled": False}}, + "rotary": { + "type": "mistral_1d", + "theta": qwen2_config.get("rope_theta", 1000000.0), + }, + }, + "mlp": { + "type": "mlp", + "intermediate_size": qwen2_config["intermediate_size"], + "activation": qwen2_config.get("hidden_act", "silu"), + "gated": True, + "add_linear_biases": False, + }, + "normalization": { + "type": "rms_norm", + "epsilon": qwen2_config.get("rms_norm_eps", 1e-6), + }, + }, + }, + "head": { + "normalization": { + "type": "rms_norm", + "epsilon": qwen2_config.get("rms_norm_eps", 1e-6), + } + }, + "embeddings": { + "max_position_embeddings": qwen2_config.get("max_position_embeddings", 32768), + }, + } diff --git a/fast_llm_external_models/apriel2/conversion/qwen2/plan.py b/fast_llm_external_models/apriel2/conversion/qwen2/plan.py new file mode 100644 index 000000000..c1ec4af8b --- /dev/null +++ b/fast_llm_external_models/apriel2/conversion/qwen2/plan.py @@ -0,0 +1,100 @@ +"""Qwen2/Qwen2.5 to Apriel2 weight conversion plan.""" + +from fast_llm_external_models.apriel2.conversion.expr 
import Expr, ExprPlan, Ref, W + + +def plan_qwen2_to_apriel2(qwen2_config: dict) -> ExprPlan: + """Build an expression plan for Qwen2/Qwen2.5 to Apriel2 conversion. + + This is a pure mapping (all Ref expressions) since Qwen2โ†’Apriel2 + is just renaming keys. The weight tensors are identical. + + Key mapping (source keys have "model." prefix in safetensors): + Qwen2 (safetensor key) Apriel2 + ---------------------- ------- + model.embed_tokens.weight -> model.embed_tokens.weight + model.norm.weight -> model.norm.weight + model.layers.{i}.input_layernorm.weight -> model.decoder.blocks.{i}.input_layernorm.weight + model.layers.{i}.post_attention_layernorm.weight -> model.decoder.blocks.{i}.post_attention_layernorm.weight + model.layers.{i}.self_attn.q_proj.weight -> model.decoder.blocks.{i}.mixer.q_proj.weight + model.layers.{i}.self_attn.q_proj.bias -> model.decoder.blocks.{i}.mixer.q_proj.bias + model.layers.{i}.self_attn.k_proj.weight -> model.decoder.blocks.{i}.mixer.k_proj.weight + model.layers.{i}.self_attn.k_proj.bias -> model.decoder.blocks.{i}.mixer.k_proj.bias + model.layers.{i}.self_attn.v_proj.weight -> model.decoder.blocks.{i}.mixer.v_proj.weight + model.layers.{i}.self_attn.v_proj.bias -> model.decoder.blocks.{i}.mixer.v_proj.bias + model.layers.{i}.self_attn.o_proj.weight -> model.decoder.blocks.{i}.mixer.o_proj.weight + model.layers.{i}.mlp.gate_proj.weight -> model.decoder.blocks.{i}.mlp.gate_proj.weight + model.layers.{i}.mlp.up_proj.weight -> model.decoder.blocks.{i}.mlp.up_proj.weight + model.layers.{i}.mlp.down_proj.weight -> model.decoder.blocks.{i}.mlp.down_proj.weight + + Note: Qwen2 has QKV biases but no O bias. The Apriel2 config uses per-layer + bias settings (query_layer.bias.enabled=True, dense_layer.bias.enabled=False) + to match this exactly - no workarounds needed. + + Args: + qwen2_config: HuggingFace Qwen2Config as dict + + Returns: + ExprPlan with Ref mappings + """ + mappings: dict[str, Expr] = {} + + num_layers = qwen2_config["num_hidden_layers"] + + # Static mappings (embeddings and final norm) + # Note: Qwen2 safetensor keys have "model." 
prefix + static_mappings = [ + (W("model", "embed_tokens", "weight"), W("model", "embed_tokens", "weight")), + (W("model", "norm", "weight"), W("model", "norm", "weight")), + ] + + # lm_head - only if not tied + if not qwen2_config.get("tie_word_embeddings", False): + static_mappings.append((W("lm_head", "weight"), W("lm_head", "weight"))) + + for src, tgt in static_mappings: + mappings[tgt] = Ref(key=src) + + # Layer mappings + for layer in range(num_layers): + # Source has "model.layers.{i}" prefix + qwen_layer = W("model", "layers", layer) + apriel_layer = W("model", "decoder", "blocks", layer) + + # Attention projection weights + for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]: + src = qwen_layer / "self_attn" / proj / "weight" + tgt = apriel_layer / "mixer" / proj / "weight" + mappings[tgt] = Ref(key=src) + + # QKV biases (Qwen2 has these, but not O bias) + for proj in ["q_proj", "k_proj", "v_proj"]: + src = qwen_layer / "self_attn" / proj / "bias" + tgt = apriel_layer / "mixer" / proj / "bias" + mappings[tgt] = Ref(key=src) + + # Note: o_proj has no bias in Qwen2, and Apriel2 config has dense_layer.bias.enabled=False + + # MLP projections + for proj in ["gate_proj", "up_proj", "down_proj"]: + src = qwen_layer / "mlp" / proj / "weight" + tgt = apriel_layer / "mlp" / proj / "weight" + mappings[tgt] = Ref(key=src) + + # Layer norms + mappings[apriel_layer / "input_layernorm" / "weight"] = Ref(key=qwen_layer / "input_layernorm" / "weight") + mappings[apriel_layer / "post_attention_layernorm" / "weight"] = Ref( + key=qwen_layer / "post_attention_layernorm" / "weight" + ) + + return ExprPlan( + mappings=mappings, + source_format="qwen2", + target_format="apriel2", + metadata={ + "num_layers": num_layers, + "hidden_size": qwen2_config["hidden_size"], + "num_attention_heads": qwen2_config["num_attention_heads"], + "num_key_value_heads": qwen2_config.get("num_key_value_heads", qwen2_config["num_attention_heads"]), + }, + ) diff --git a/fast_llm_external_models/apriel2/conversion/render.py b/fast_llm_external_models/apriel2/conversion/render.py index d71fa03e1..f9a0c8ac1 100644 --- a/fast_llm_external_models/apriel2/conversion/render.py +++ b/fast_llm_external_models/apriel2/conversion/render.py @@ -8,17 +8,11 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING +from fast_llm_external_models.apriel2.conversion.expr import Concat, Init, Ref, Reshape, Slice + if TYPE_CHECKING: from fast_llm_external_models.apriel2.conversion.expr import Expr, ExprPlan -from fast_llm_external_models.apriel2.conversion.expr import ( - Concat, - Init, - Ref, - Reshape, - Slice, -) - @dataclass class PlanTreeNode: @@ -28,10 +22,10 @@ class PlanTreeNode: After merging, leaf nodes contain aggregated values from multiple siblings. """ - children: dict[str, "PlanTreeNode"] = field(default_factory=dict) + children: dict[str, PlanTreeNode] = field(default_factory=dict) # For leaf nodes: list of (sibling_key, expr) pairs # Before merge: single item, after merge: multiple items from merged siblings - values: list[tuple[str, "Expr"]] = field(default_factory=list) + values: list[tuple[str, Expr]] = field(default_factory=list) def is_leaf(self) -> bool: return len(self.children) == 0 @@ -61,7 +55,7 @@ def _build_plan_tree(plan: ExprPlan) -> PlanTreeNode: return root -def _expr_signature(expr: "Expr") -> tuple: +def _expr_signature(expr: Expr) -> tuple: """Get a signature for an expression that determines merge compatibility. Expressions with different signatures should not be merged together. 
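
As a minimal sketch of the plan machinery used above (the layer index, the single mapping, and the asserts are illustrative only, not part of the patch), a single Qwen2 -> Apriel2 renaming can be built and inspected like this:

    from fast_llm_external_models.apriel2.conversion.expr import ExprPlan, Ref, W

    # W joins its parts with "." and composes with "/", so a path reads like the
    # safetensor key it refers to.
    src = W("model", "layers", 0) / "self_attn" / "q_proj" / "weight"
    tgt = W("model", "decoder", "blocks", 0) / "mixer" / "q_proj" / "weight"
    assert src == "model.layers.0.self_attn.q_proj.weight"

    # A pure renaming plan: the target key is a Ref to the source key, no tensor math.
    plan = ExprPlan(mappings={tgt: Ref(key=src)}, source_format="qwen2", target_format="apriel2")
    assert plan.summary()["num_targets"] == 1
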
@@ -453,7 +447,7 @@ def _render_plan_tree( ) -def _format_aggregated_leaf(values: list[tuple[str, "Expr"]]) -> str: +def _format_aggregated_leaf(values: list[tuple[str, Expr]]) -> str: """Format a leaf with aggregated values using pattern discovery. Args: @@ -494,7 +488,7 @@ def _format_aggregated_leaf(values: list[tuple[str, "Expr"]]) -> str: return _format_single_expr(first_expr) -def _format_single_expr(expr: "Expr") -> str: +def _format_single_expr(expr: Expr) -> str: """Format a single expression using ML notation.""" match expr: case Ref(key=key): @@ -531,7 +525,7 @@ def _format_single_expr(expr: "Expr") -> str: return f"= {type(expr).__name__}" -def _format_concat_part(expr: "Expr") -> str: +def _format_concat_part(expr: Expr) -> str: """Format a single part of a concat (for short display).""" match expr: case Ref(key=key): @@ -570,7 +564,7 @@ def _format_slice_notation(slices: tuple) -> str: return f"[{', '.join(slice_strs)}]" -def _format_aggregated_concat(values: list[tuple[str, "Expr"]]) -> str: +def _format_aggregated_concat(values: list[tuple[str, Expr]]) -> str: """Format aggregated Concat expressions with pattern discovery.""" # Get the first concat to understand structure first_concat = values[0][1] @@ -590,7 +584,7 @@ def _format_aggregated_concat(values: list[tuple[str, "Expr"]]) -> str: return f"= [{sep.join(formatted_parts)}]" -def _format_aggregated_concat_part(values: list[tuple[str, "Expr"]]) -> str: +def _format_aggregated_concat_part(values: list[tuple[str, Expr]]) -> str: """Format a single part of an aggregated concat.""" if len(values) == 1: return _format_concat_part(values[0][1]) @@ -619,7 +613,7 @@ def _format_aggregated_concat_part(values: list[tuple[str, "Expr"]]) -> str: return _format_concat_part(first_expr) -def _format_aggregated_slice(values: list[tuple[str, "Expr"]]) -> str: +def _format_aggregated_slice(values: list[tuple[str, Expr]]) -> str: """Format aggregated Slice expressions with pattern discovery.""" first_slice = values[0][1] if not isinstance(first_slice, Slice): diff --git a/fast_llm_external_models/apriel2/convert.py b/fast_llm_external_models/apriel2/convert.py index cbf921b31..66c419dfd 100644 --- a/fast_llm_external_models/apriel2/convert.py +++ b/fast_llm_external_models/apriel2/convert.py @@ -15,6 +15,7 @@ Supported source formats: - llava: Llava/Pixtral models +- qwen2: Qwen2/Qwen2.5 models - apriel2: Apriel2 models (surgery-only mode - no conversion, just apply surgeries) """ @@ -29,10 +30,7 @@ import yaml from tqdm import tqdm -# Allow running as script or module -if __name__ == "__main__": - sys.path.insert(0, str(Path(__file__).parent.parent.parent)) - +# Import source-specific converters from fast_llm_external_models.apriel2.conversion import ( DEFAULT_MAX_SHARD_SIZE, ExprPlan, @@ -41,11 +39,16 @@ StreamingExecutor, compose, compose_configs, - plan_surgery, ) - -# Import source-specific converters from fast_llm_external_models.apriel2.conversion import llava as llava_converter +from fast_llm_external_models.apriel2.conversion import plan_surgery +from fast_llm_external_models.apriel2.conversion import qwen2 as qwen2_converter +from fast_llm_external_models.apriel2.conversion import strip_init_fields + +# Allow running as script or module +if __name__ == "__main__": + sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + logger = logging.getLogger(__name__) @@ -73,6 +76,7 @@ def _identity_plan(config: dict) -> ExprPlan: # Each entry maps format name to (config_converter, plan_builder) SOURCE_FORMATS: dict[str, 
tuple[Callable[[dict], dict], Callable[[dict], ExprPlan]]] = { "llava": (llava_converter.convert_config, llava_converter.plan_llava_to_apriel2), + "qwen2": (qwen2_converter.convert_config, qwen2_converter.plan_qwen2_to_apriel2), "apriel2": (_identity_config, _identity_plan), } @@ -88,8 +92,12 @@ def detect_source_format(config: dict) -> str | None: if model_type in ("llava", "pixtral") or "text_config" in config: return "llava" + # Qwen2/Qwen2.5 detection + if model_type == "qwen2": + return "qwen2" + # Apriel2 detection - check for Apriel2-specific structure - if model_type == "apriel2" or "decoder" in config: + if model_type in ("apriel2", "apriel2_text") or "decoder" in config: return "apriel2" return None @@ -142,15 +150,21 @@ def build_plan( # Apply surgery chain if requested if surgery_configs: for i, surgery_config in enumerate(surgery_configs, 1): - surgery_plan = plan_surgery(current_config, surgery_config) - logger.info(f"Built surgery plan [{i}/{len(surgery_configs)}]: {surgery_plan.summary()['num_targets']} targets") + # S ร— P โ†’ T: compose state with surgery to get transition spec + target_config = compose_configs(current_config, surgery_config) + + # S ร— T โ†’ Plan: build plan from source state and transition spec + surgery_plan = plan_surgery(current_config, target_config) + logger.info( + f"Built surgery plan [{i}/{len(surgery_configs)}]: {surgery_plan.summary()['num_targets']} targets" + ) - # Compose: current -> surgery + # Compose plans current_plan = compose(current_plan, surgery_plan) logger.info(f"Composed plan [{i}/{len(surgery_configs)}]: {current_plan.summary()['num_targets']} targets") - # Compose configs: merge surgery spec into current config - current_config = compose_configs(current_config, surgery_config) + # T โ†’ S: strip init for next iteration (init is consumed by plan_surgery) + current_config = strip_init_fields(target_config) return current_plan, current_config @@ -211,9 +225,7 @@ def convert( executor = StreamingExecutor(full_plan, loader) with ShardedSafetensorWriter(output_dir, max_shard_size=max_shard_size) as writer: - for target_key, tensor in tqdm( - executor.execute(seed), desc="Converting", total=len(full_plan) - ): + for target_key, tensor in tqdm(executor.execute(seed), desc="Converting", total=len(full_plan)): writer.add(target_key, tensor) return final_config @@ -282,9 +294,7 @@ def resolve_input(input_path: str) -> Path: def main(): - parser = argparse.ArgumentParser( - description="Convert HuggingFace checkpoint to Apriel2 HF format" - ) + parser = argparse.ArgumentParser(description="Convert HuggingFace checkpoint to Apriel2 HF format") parser.add_argument( "input", type=str, @@ -384,8 +394,7 @@ def main(): safetensor_files = sorted(input_dir.glob("*.safetensors")) if not safetensor_files: raise ValueError( - f"No safetensor files found in {input_dir}. " - "Plan-based conversion requires safetensor files." + f"No safetensor files found in {input_dir}. " "Plan-based conversion requires safetensor files." 
) # Convert using plan-based approach with streaming sharded output @@ -400,11 +409,11 @@ def main(): show_plan=args.show_plan or args.verbose, ) - # Save config + # Save config (build_plan returns S which has no init, but strip defensively) output_config_file = args.output_dir / "config.json" logger.info(f"Saving config to {output_config_file}") with open(output_config_file, "w") as f: - json.dump(apriel2_config, f, indent=2) + json.dump(strip_init_fields(apriel2_config), f, indent=2) # Copy tokenizer files copy_tokenizer_files(input_dir, args.output_dir) diff --git a/fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml b/fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml new file mode 100644 index 000000000..34672916c --- /dev/null +++ b/fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml @@ -0,0 +1,103 @@ +# Dataset preparation config for Tulu 3 SFT mixture with Qwen2 tokenizer +# +# This config converts the Tulu 3 SFT dataset (conversation format) to +# Fast-LLM's memmap format, with automatic loss masking span computation +# to train only on assistant responses. +# +# ============================================================================= +# TOKENIZER SETUP (one-time) +# ============================================================================= +# +# The tokenizer must have a chat template with {% generation %} markers. +# Qwen2's default template doesn't have these, so we need to patch it. +# +# IMPORTANT: The entire assistant turn (opening tag + content + closing tag) +# must be inside the {% generation %} block. This ensures the model learns to +# produce the full assistant response including special tokens like <|im_end|>. +# Reference: https://huggingface.co/HuggingFaceTB/SmolLM3-3B/blob/main/chat_template.jinja +# +# Run this Python script to create a patched tokenizer: +# +# from transformers import AutoTokenizer +# +# tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") +# +# # Patch chat template: wrap ENTIRE assistant turn in generation markers +# tokenizer.chat_template = '''{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system +# You are a helpful assistant.<|im_end|> +# ' }}{% endif %}{% if message['role'] == 'assistant' %}{% generation %}{{ '<|im_start|>assistant +# ' + message['content'] + '<|im_end|> +# ' }}{% endgeneration %}{% else %}{{ '<|im_start|>' + message['role'] + ' +# ' + message['content'] + '<|im_end|> +# ' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant +# ' }}{% endif %}''' +# +# tokenizer.save_pretrained("/path/to/qwen2-instruct-with-markers") +# +# ============================================================================= +# DATA PREPARATION +# ============================================================================= +# +# Small dataset (for testing): +# +# fast-llm prepare gpt_memmap \ +# -c fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml \ +# dataset.split=train[:1000] \ +# output_path=/path/to/tulu3-prepared-small +# +# Full dataset (~939K samples, ~6 minutes): +# +# fast-llm prepare gpt_memmap \ +# -c fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml +# +# ============================================================================= +# VERIFICATION +# ============================================================================= +# +# To verify the prepared dataset has loss masking spans: +# +# import pathlib +# from fast_llm.data.dataset.memmap import MemmapDataset +# from 
fast_llm.data.sample.language_model import LanguageModelSample +# from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig +# +# dataset = MemmapDataset[LanguageModelSample]( +# 'tulu3', +# pathlib.Path('/path/to/tulu3-prepared/shard_0_0.fast_llm_dataset'), +# LanguageModelPreprocessingConfig(use_loss_masking_spans=True) +# ) +# +# doc = dataset.get_document(0) +# print(f'Tokens: {len(doc.tokens.tokens)}') +# print(f'Loss masking spans: {doc.loss_masking_spans.ranges}') +# +# ============================================================================= + +# Dataset configuration +dataset: + # Tulu 3 SFT mixture from AllenAI + path: allenai/tulu-3-sft-mixture + split: train + + # Source schema for conversation format + source_schema: + # Use conversation type (vs default "document" type) + type: conversation + + # Column containing the messages list + messages: messages + +# Tokenizer configuration +# IMPORTANT: Must use a tokenizer with {% generation %} markers in its chat template. +# See instructions above to create a patched Qwen2 tokenizer. +tokenizer: + path: /path/to/qwen2-instruct-with-markers + # Qwen2 doesn't have a BOS token by default, use <|endoftext|> as BOS + bos_token: "<|endoftext|>" + +# Output configuration +output_path: /path/to/tulu3-prepared + +# Processing configuration +num_workers: 8 +documents_per_shard: 100000 diff --git a/fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml b/fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml new file mode 100644 index 000000000..aad168713 --- /dev/null +++ b/fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml @@ -0,0 +1,193 @@ +# Training config for Qwen2-based Apriel2 stochastic supernet on Tulu 3 SFT data +# +# This config trains a stochastic supernet where each layer can sample from +# multiple mixer types (attention, sliding window, gated delta net, KDA). +# Only the mixer weights are trained; all other weights are frozen. +# Activation-level distillation from a teacher model guides the training. +# +# ============================================================================= +# PREREQUISITES +# ============================================================================= +# +# 1. TOKENIZER SETUP +# +# Qwen2's default chat template doesn't have generation markers needed for +# loss masking. Create a patched tokenizer following the SmolLM3 pattern: +# https://huggingface.co/HuggingFaceTB/SmolLM3-3B/blob/main/chat_template.jinja +# +# IMPORTANT: The ENTIRE assistant turn (opening tag + content + closing tag) +# must be inside {% generation %}...{% endgeneration %} markers. +# +# from transformers import AutoTokenizer +# tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") +# # Wrap entire assistant turn in generation markers (NOT just content!) +# tokenizer.chat_template = '''{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system +# You are a helpful assistant.<|im_end|> +# ' }}{% endif %}{% if message['role'] == 'assistant' %}{% generation %}{{ '<|im_start|>assistant +# ' + message['content'] + '<|im_end|> +# ' }}{% endgeneration %}{% else %}{{ '<|im_start|>' + message['role'] + ' +# ' + message['content'] + '<|im_end|> +# ' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant +# ' }}{% endif %}''' +# tokenizer.save_pretrained("/path/to/qwen2-instruct-with-markers") +# +# 2. 
PREPARE TULU 3 DATASET +# +# Small dataset (for testing): +# +# fast-llm prepare gpt_memmap \ +# -c fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml \ +# tokenizer.path=/path/to/qwen2-instruct-with-markers \ +# dataset.split=train[:1000] \ +# output_path=/path/to/tulu3-prepared-small +# +# Full dataset (~939K samples, ~6 minutes): +# +# fast-llm prepare gpt_memmap \ +# -c fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml \ +# tokenizer.path=/path/to/qwen2-instruct-with-markers \ +# output_path=/path/to/tulu3-prepared +# +# 3. CONVERT QWEN2 TO APRIEL2 SUPERNET (student model) +# +# This creates a stochastic supernet with multiple mixer types per layer: +# +# python fast_llm_external_models/apriel2/convert.py \ +# Qwen/Qwen2.5-0.5B-Instruct \ +# /path/to/qwen2-supernet \ +# --surgery fast_llm_external_models/apriel2/examples/stochastic_supernet.yaml +# +# 4. CONVERT QWEN2 TO APRIEL2 (teacher model) +# +# The teacher is the original model without surgery, used for distillation: +# +# python fast_llm_external_models/apriel2/convert.py \ +# Qwen/Qwen2.5-0.5B-Instruct \ +# /path/to/qwen2-teacher +# +# 5. RUN TRAINING +# +# Update paths below and run: +# +# fast-llm train gpt \ +# -c fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml +# +# For long runs, use nohup: +# +# nohup fast-llm train gpt \ +# -c fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml \ +# > training.log 2>&1 & +# tail -f training.log +# +# ============================================================================= +# PERFORMANCE TUNING +# ============================================================================= +# +# Default config uses seq=2048, micro_batch=2, batch=64 (~131k tokens/iter). +# Adjust settings based on your GPU memory: +# - Reduce micro_batch_size or sequence_length if OOM +# - Increase micro_batch_size or sequence_length if memory available +# +# ============================================================================= +# OUTPUT +# ============================================================================= +# +# Checkpoints: /path/to/qwen2-supernet-trained/checkpoints/{iteration}/ +# Exports: /path/to/qwen2-supernet-trained/export/apriel2_text/{iteration}/ +# +# ============================================================================= + +# Load pretrained model (Qwen2 converted to Apriel2 supernet) +pretrained: + path: /path/to/qwen2-supernet + format: apriel2_text + model_weights: true + load_config: model + +# Model config +model: + base_model: + # Freeze all components except the mixer + decoder: + block: + mlp: + lr_scale: 0.0 # Freeze MLP + normalization: + lr_scale: 0.0 # Freeze layer norms + distillation_model: teacher + activation_distillation_factor: 0.5 + embeddings: + lr_scale: 0.0 # Freeze word embeddings + head: + lr_scale: 0.0 # Freeze output head + # cross_entropy_implementation: torch + distillation_model: teacher + distillation_loss_factor: 1.0 + distillation_loss_implementation: reverse_kl + multi_stage: + zero_stage: 2 + distributed: + compute_dtype: bf16 + seed: 42 + +# Teacher model for activation-level distillation +reference_models: + teacher: + model: + type: gpt + pretrained: + path: /path/to/qwen2-teacher + format: apriel2_text + model_weights: true + load_config: model + +# Batch configuration +batch: + sequence_length: 2048 + micro_batch_size: 2 + batch_size: 64 + truncate_documents: false + use_loss_masking_spans: true + +# Data configuration (prepared Tulu 3 dataset) +data: + datasets: + training: + type: 
file + path: /path/to/tulu3-prepared/fast_llm_config.yaml + +# Optimizer configuration +optimizer: + learning_rate: + base: 3.0e-05 + decay_style: cosine + warmup_iterations: 100 + decay_iterations: 10000 + minimum: 1.0e-06 + weight_decay: 0.1 + beta_1: 0.9 + beta_2: 0.95 + +# Training configuration +# At seq=2048, batch=64: ~131k tokens/iter +training: + train_iters: 10000 + num_workers: 4 + logs: + interval: 10 + checkpoint: + interval: 100 + export: + interval: 100 + format: apriel2_text + test_iters: 0 + evaluators: {} + # Weights & Biases configuration (optional, uncomment to enable) + # wandb: + # entity_name: your-entity + # project_name: your-project + # group_name: your-group + +# Experiment directory +run: + experiment_dir: /path/to/qwen2-supernet-trained diff --git a/fast_llm_external_models/apriel2/examples/train_supernet_small.yaml b/fast_llm_external_models/apriel2/examples/train_supernet_small.yaml index 78c22e57f..be4d06e0a 100644 --- a/fast_llm_external_models/apriel2/examples/train_supernet_small.yaml +++ b/fast_llm_external_models/apriel2/examples/train_supernet_small.yaml @@ -107,7 +107,7 @@ model: lr_scale: 0.0 # Freeze layer norms (norm_1 and norm_2 in each block) # Activation-level distillation: teach mixers to mimic teacher's attention outputs distillation_model: teacher - activation_distillation_factor: 0.1 + activation_distillation_factor: 0.8 embeddings: lr_scale: 0.0 # Freeze word embeddings head: diff --git a/fast_llm_external_models/apriel2/modeling_apriel2.py b/fast_llm_external_models/apriel2/modeling_apriel2.py index 4c263b4e2..240240cd6 100644 --- a/fast_llm_external_models/apriel2/modeling_apriel2.py +++ b/fast_llm_external_models/apriel2/modeling_apriel2.py @@ -24,8 +24,8 @@ is_torch_flex_attn_available, ) -from fast_llm_external_models.apriel2.cache import Apriel2Cache -from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config, Apriel2TextConfig +from .cache import Apriel2Cache +from .configuration_apriel2 import Apriel2Config, Apriel2TextConfig # GDN implementation - matches Fast-LLM's gdn.py exactly try: @@ -395,14 +395,30 @@ def __init__(self, d_model: int, mixer_config: dict, layer_idx: int, config): # cross_document_attention: if False, use cu_seqlens to isolate sequences (e.g., images) self.cross_document_attention = mixer_config.get("cross_document_attention", True) - # Whether to add biases to linear projections - add_bias = mixer_config.get("add_linear_biases", False) - - # Projections (Fast-LLM weight names: q_proj, k_proj, v_proj, o_proj) - self.q_proj = nn.Linear(d_model, self.num_heads * self.head_dim, bias=add_bias) - self.k_proj = nn.Linear(d_model, self.num_key_value_heads * self.head_dim, bias=add_bias) - self.v_proj = nn.Linear(d_model, self.num_key_value_heads * self.head_dim, bias=add_bias) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, d_model, bias=add_bias) + # Bias configuration mirroring Fast-LLM's structure: + # - add_linear_biases: bool (default for all projections) + # - query_layer: {"bias": {"enabled": bool}} (per-layer override) + # - key_layer: {"bias": {"enabled": bool}} + # - value_layer: {"bias": {"enabled": bool}} + # - dense_layer: {"bias": {"enabled": bool}} + default_bias = mixer_config.get("add_linear_biases", False) + + def get_layer_bias(layer_name: str) -> bool: + layer_cfg = mixer_config.get(layer_name, {}) + bias_cfg = layer_cfg.get("bias", {}) + enabled = bias_cfg.get("enabled") + return default_bias if enabled is None else enabled + + q_bias = get_layer_bias("query_layer") 
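+        # Example resolution (hypothetical mixer_config): with
+        #   {"add_linear_biases": False, "query_layer": {"bias": {"enabled": True}}}
+        # q_bias resolves to True via the per-layer override, while the other
+        # projections fall back to the add_linear_biases default (False).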
+ k_bias = get_layer_bias("key_layer") + v_bias = get_layer_bias("value_layer") + o_bias = get_layer_bias("dense_layer") + + # Projections + self.q_proj = nn.Linear(d_model, self.num_heads * self.head_dim, bias=q_bias) + self.k_proj = nn.Linear(d_model, self.num_key_value_heads * self.head_dim, bias=k_bias) + self.v_proj = nn.Linear(d_model, self.num_key_value_heads * self.head_dim, bias=v_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, d_model, bias=o_bias) @classmethod def setup( @@ -1017,6 +1033,8 @@ def torch_chunk_gated_delta_rule( if not output_final_state: last_recurrent_state = None + elif last_recurrent_state is not None: + last_recurrent_state = last_recurrent_state.to(initial_dtype) core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1]) core_attn_out = core_attn_out[:, :, :sequence_length] core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) @@ -1225,7 +1243,9 @@ def forward(self, hidden_states: torch.Tensor, past_key_values=None, attention_m mixed_qkv = self.convolution.update( mixed_qkv.squeeze(2), # [batch, conv_dim, 1] -> [batch, conv_dim] conv_state, - ).unsqueeze(2) # [batch, conv_dim] -> [batch, conv_dim, 1] + ).unsqueeze( + 2 + ) # [batch, conv_dim] -> [batch, conv_dim, 1] else: # Prefill mode use_cache = past_key_values is not None @@ -1270,8 +1290,14 @@ def forward(self, hidden_states: torch.Tensor, past_key_values=None, attention_m output_final_state=past_key_values is not None, use_qk_l2norm_in_kernel=True, ) + # Ensure state is in same dtype as hidden_states (fla kernel may return float32) + if last_recurrent_state is not None: + last_recurrent_state = last_recurrent_state.to(hidden_states.dtype) else: # Recurrent mode for single token decode + # Convert recurrent_state to match hidden_states dtype if needed + if recurrent_state is not None and recurrent_state.dtype != hidden_states.dtype: + recurrent_state = recurrent_state.to(hidden_states.dtype) output, last_recurrent_state = self._recurrent_gated_delta_rule( query, key, value, g, beta_gate, recurrent_state ) @@ -1294,7 +1320,16 @@ def forward(self, hidden_states: torch.Tensor, past_key_values=None, attention_m return (output,) def _recurrent_gated_delta_rule(self, query, key, value, g, beta, state): - """Single-step recurrent update for cached inference.""" + """Single-step recurrent update for cached inference. 
+ + Input shapes: [batch, seq=1, heads, dim] + Need shapes: [batch, heads, dim] for einsum operations + """ + # Transpose from [batch, seq, heads, dim] to [batch, heads, seq, dim] + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + # L2 normalize query and key query = _l2norm(query, dim=-1, eps=1e-6) key = _l2norm(key, dim=-1, eps=1e-6) @@ -1307,7 +1342,9 @@ def _recurrent_gated_delta_rule(self, query, key, value, g, beta, state): beta = beta.squeeze(1) # Update state: S = exp(g) * S + beta * k^T @ v - decay = g.exp().unsqueeze(-1).unsqueeze(-1) # [batch, heads, 1, 1] + # Keep everything in the same dtype as input (exp() returns float32, need to convert back) + input_dtype = query.dtype + decay = g.exp().to(input_dtype).unsqueeze(-1).unsqueeze(-1) # [batch, heads, 1, 1] k_outer_v = torch.einsum("bhk,bhv->bhkv", key * beta.unsqueeze(-1), value) state = decay * state + k_outer_v @@ -1315,6 +1352,12 @@ def _recurrent_gated_delta_rule(self, query, key, value, g, beta, state): output = torch.einsum("bhk,bhkv->bhv", query, state) output = output.unsqueeze(2) # [batch, heads, 1, v_dim] + # Transpose back to [batch, seq=1, heads, v_dim] + output = output.transpose(1, 2) + + # Ensure state matches output dtype + state = state.to(output.dtype) + return output, state @classmethod @@ -1447,9 +1490,7 @@ def __init__( # Normalization - use GatedRMSNormalization (same wrapper as GDN, with sigmoid activation) self.norm = GatedRMSNormalization(self.head_dim, eps=self.norm_eps, activation=self.norm_activation) - def _apply_conv( - self, x: torch.Tensor, conv: CausalConv1d, conv_state: torch.Tensor | None, use_cache: bool - ): + def _apply_conv(self, x: torch.Tensor, conv: CausalConv1d, conv_state: torch.Tensor | None, use_cache: bool): """ Apply causal convolution with cache support. @@ -1828,16 +1869,36 @@ def __init__( self.post_attention_layernorm = self._create_norm(norm_config, hidden_size, rms_norm_eps) def _create_mlp(self, mlp_config: dict, hidden_size: int): - """Create MLP based on config.""" + """Create MLP based on config. 
+ + Supports per-layer bias configuration mirroring Fast-LLM: + - add_linear_biases: default bias setting for all layers + - layer_1.bias.enabled: override for up_proj/gate_proj + - layer_2.bias.enabled: override for down_proj + """ mlp_type = mlp_config.get("type", "mlp") if mlp_type == "mlp": intermediate_size = mlp_config["intermediate_size"] activation = mlp_config.get("activation", "silu") - gated = mlp_config["gated"] - bias = mlp_config.get("add_linear_biases", False) + gated = mlp_config.get("gated", False) + + # Per-layer bias configuration (mirrors Fast-LLM structure) + default_bias = mlp_config.get("add_linear_biases", False) + + def get_layer_bias(layer_name: str) -> bool: + layer_cfg = mlp_config.get(layer_name, {}) + bias_cfg = layer_cfg.get("bias", {}) + enabled = bias_cfg.get("enabled") + return default_bias if enabled is None else enabled + + layer_1_bias = get_layer_bias("layer_1") + layer_2_bias = get_layer_bias("layer_2") if gated: + # MistralMLP uses gate_proj, up_proj, down_proj (all bias controlled together) + # For now, we use the default bias setting for gated MLPs + # TODO: Add per-layer bias support to gated MLP mlp_cfg = SimpleNamespace( hidden_size=hidden_size, intermediate_size=intermediate_size, @@ -1845,7 +1906,13 @@ def _create_mlp(self, mlp_config: dict, hidden_size: int): ) return MistralMLP(mlp_cfg) else: - return SimpleMLP(hidden_size, intermediate_size, activation, bias) + return SimpleMLP( + hidden_size, + intermediate_size, + activation, + layer_1_bias=layer_1_bias, + layer_2_bias=layer_2_bias, + ) else: raise ValueError(f"Unknown MLP type: {mlp_type}") @@ -2179,6 +2246,8 @@ def forward( class Apriel2ForCausalLM(Apriel2PreTrainedModel, GenerationMixin): """Apriel2 model with a language modeling head (text-only).""" + _tied_weights_keys = ["lm_head.weight"] + def __init__(self, config: Apriel2TextConfig): super().__init__(config) self.model = Apriel2TextModel(config) @@ -2186,6 +2255,7 @@ def __init__(self, config: Apriel2TextConfig): self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing + # post_init() calls init_weights() which calls tie_weights() if config.tie_word_embeddings self.post_init() def get_input_embeddings(self): @@ -2583,14 +2653,26 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: class SimpleMLP(nn.Module): - """Non-gated MLP: up_proj -> activation -> down_proj.""" + """Non-gated MLP: up_proj -> activation -> down_proj. 
+ + Supports per-layer bias configuration mirroring Fast-LLM: + - layer_1_bias: bias for up_proj (layer_1 in Fast-LLM naming) + - layer_2_bias: bias for down_proj (layer_2 in Fast-LLM naming) + """ - def __init__(self, hidden_size: int, intermediate_size: int, activation: str = "silu", bias: bool = False): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + activation: str = "silu", + layer_1_bias: bool = False, + layer_2_bias: bool = False, + ): super().__init__() from transformers.activations import ACT2FN - self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=bias) - self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=bias) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=layer_1_bias) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=layer_2_bias) self.act_fn = ACT2FN[activation] def forward(self, x): diff --git a/fast_llm_external_models/tests/test_apriel2/conftest.py b/fast_llm_external_models/tests/test_apriel2/conftest.py index 8585aec65..21b90b097 100644 --- a/fast_llm_external_models/tests/test_apriel2/conftest.py +++ b/fast_llm_external_models/tests/test_apriel2/conftest.py @@ -1,23 +1,44 @@ """Test fixtures for Apriel2 model tests.""" +from collections.abc import Generator from pathlib import Path -from typing import Generator import pytest import torch from transformers import LlavaConfig, LlavaForConditionalGeneration, MistralConfig +from fast_llm_external_models.apriel2.cache import _AttentionCache, _SSMCache + + +# Register custom marks +def pytest_configure(config): + config.addinivalue_line("markers", "slow: marks tests as slow (deselect with '-m \"not slow\"')") + + +def _can_import_fast_llm(): + """Check if Fast-LLM is available.""" + try: + return True + except ImportError: + return False + # Skip marker for tests that require CUDA for Mamba forward pass requires_cuda = pytest.mark.skipif( - not torch.cuda.is_available(), - reason="SSM mixers (Mamba) require CUDA for forward pass" + not torch.cuda.is_available(), reason="SSM mixers (Mamba) require CUDA for forward pass" ) +# Skip marker for tests that require Fast-LLM +requires_fastllm = pytest.mark.skipif(not _can_import_fast_llm(), reason="Fast-LLM not available") -@pytest.fixture(autouse=True) + +@pytest.fixture(scope="module", autouse=True) def set_default_device(): - """Set default device to CUDA for all tests (Mamba requires CUDA).""" + """Set default device to CUDA for all tests (Mamba requires CUDA). + + Module-scoped to ensure it runs before any module-scoped fixtures + that load models (e.g., qwen2_model_and_tokenizer). + """ if torch.cuda.is_available(): old_device = torch.get_default_device() torch.set_default_device("cuda") @@ -27,9 +48,12 @@ def set_default_device(): yield -@pytest.fixture(autouse=True) +@pytest.fixture(scope="module", autouse=True) def set_default_dtype(): - """Set default dtype to float32 for numerical comparison tests.""" + """Set default dtype to float32 for numerical comparison tests. + + Module-scoped to ensure it runs before any module-scoped fixtures. 
+ """ old_dtype = torch.get_default_dtype() torch.set_default_dtype(torch.float32) yield @@ -135,14 +159,11 @@ def model_pair(request, small_pixtral_model, tmp_path): tuple: (source_model, target_model, expected_atol, variant_name) """ import json + from safetensors import safe_open from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config - from fast_llm_external_models.apriel2.conversion import ( - convert_llava_config, - execute, - plan_llava_to_apriel2, - ) + from fast_llm_external_models.apriel2.conversion import convert_llava_config, execute, plan_llava_to_apriel2 from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForConditionalGeneration source = small_pixtral_model @@ -638,12 +659,12 @@ def apriel2_config_comprehensive(): "type": "pattern", "num_blocks": 6, "pattern": [ - "attn", # 0: pure full attention - "swa", # 1: pure sliding window attention - "mamba", # 2: pure mamba - "gdn", # 3: pure gated delta net - "stoch_attn_mamba", # 4: stochastic attention + mamba - "stoch_swa_gdn", # 5: stochastic swa + gated delta net + "attn", # 0: pure full attention + "swa", # 1: pure sliding window attention + "mamba", # 2: pure mamba + "gdn", # 3: pure gated delta net + "stoch_attn_mamba", # 4: stochastic attention + mamba + "stoch_swa_gdn", # 5: stochastic swa + gated delta net ], "blocks": { "attn": { @@ -761,6 +782,52 @@ def apriel2_config_comprehensive(): ) +@pytest.fixture +def apriel2_config_with_bias(): + """Apriel2 config with Qwen-style per-layer bias and non-gated MLP. + + This config exercises: + - Per-layer attention bias (QKV bias enabled, O bias disabled) + - Non-gated MLP with per-layer bias (layer_1 enabled, layer_2 disabled) + - Config structure parity with Fast-LLM's AffineLinearConfig + + Critical for testing bias inheritance through surgery operations. + """ + from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config + + return Apriel2Config( + vocab_size=100, + hidden_size=64, + decoder={ + "type": "fixed", + "num_blocks": 2, + "block": { + "mixer": { + "type": "attention", + "heads": 4, + "head_groups": 2, + "head_size": 16, + "rotary": {"type": "mistral_1d", "theta": 10000.0}, + # Qwen-style: QKV bias enabled, O bias disabled + "query_layer": {"bias": {"enabled": True}}, + "key_layer": {"bias": {"enabled": True}}, + "value_layer": {"bias": {"enabled": True}}, + "dense_layer": {"bias": {"enabled": False}}, + }, + "mlp": { + "type": "mlp", + "intermediate_size": 256, + "gated": False, # Non-gated MLP (SimpleMLP) + # Per-layer MLP bias + "layer_1": {"bias": {"enabled": True}}, + "layer_2": {"bias": {"enabled": False}}, + }, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + ) + + @pytest.fixture def apriel2_cache(apriel2_config_tiny): """Create empty Apriel2Cache from tiny config.""" @@ -863,6 +930,77 @@ def additive_surgery_chain(): ] +@pytest.fixture +def bias_surgery_chain(): + """Surgery chain that exercises bias inheritance through surgery operations. + + Designed to be used with apriel2_config_with_bias as the source config. 
+ Tests that per-layer bias settings (Qwen-style QKV bias, non-gated MLP bias) + are correctly inherited through: + - Stochastic wrapper creation + - Adding new sub-mixers that inherit from source + - Cross-type derivation (attention โ†’ sliding_window) + + Source config has: + - Attention: query/key/value bias enabled, dense bias disabled + - MLP: layer_1 bias enabled, layer_2 bias disabled (non-gated) + """ + return [ + # S1: Wrap in stochastic - bias should transfer to attention sub-mixer + { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + }, + }, + "mlp": {"init": "transfer"}, + "normalization": {"init": "transfer"}, + }, + }, + }, + # S2: Add sliding_window - should inherit bias from source attention + { + "decoder": { + "block": { + "mixer": { + "mixers": { + "sliding_window": { + "type": "attention", + "init": "transfer", + "window_size": 512, + }, + }, + }, + }, + }, + }, + # S3: Add new attention with DIFFERENT bias config (random init) + { + "decoder": { + "block": { + "mixer": { + "mixers": { + "full_bias_attn": { + "type": "attention", + "init": "random", + "heads": 4, + "head_groups": 2, + "head_size": 16, + "rotary": {"type": "mistral_1d", "theta": 10000.0}, + "add_linear_biases": True, # All biases enabled + }, + }, + }, + }, + }, + }, + ] + + @pytest.fixture def comprehensive_torture_chain(): """Comprehensive torture chain exercising ALL conversion paths. @@ -885,7 +1023,7 @@ def comprehensive_torture_chain(): # MIL requires: d_inner <= Q rows (256), d_xb <= K/V rows (128) mamba_params = { "d_inner": 256, # Must be <= heads*head_size = 256 - "d_xb": 64, # Must be <= head_groups*head_size = 128 + "d_xb": 64, # Must be <= head_groups*head_size = 128 "dt_rank": 16, "d_state": 16, "d_conv": 4, @@ -1532,3 +1670,330 @@ def torture_surgery_chain(): }, }, ] + + +# ============================================================================= +# Shared Config Dict Fixtures (for compose_configs / plan_surgery tests) +# ============================================================================= + + +@pytest.fixture +def base_config_dict(): + """Complete Apriel2 config dict without biases (Llama-style). + + Use this as the base config for testing compose_configs and plan_surgery. + Returns a dict (not Apriel2Config) for direct use with compose_configs. + """ + return { + "model_type": "apriel2", + "hidden_size": 256, + "vocab_size": 1000, + "bos_token_id": 1, + "eos_token_id": 2, + "tie_word_embeddings": False, + "decoder": { + "type": "fixed", + "num_blocks": 2, + "block": { + "mixer": { + "type": "attention", + "heads": 8, + "head_groups": 4, + "head_size": 32, + "rotary": {"type": "mistral_1d", "theta": 10000.0}, + }, + "mlp": {"type": "mlp", "intermediate_size": 512, "gated": True}, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + } + + +@pytest.fixture +def base_config_with_bias_dict(): + """Complete Apriel2 config dict with Qwen-style biases. + + - QKV bias enabled, O bias disabled + - Gated MLP (no per-layer bias control in this style) + + Use this for testing bias inheritance through surgery operations. + Returns a dict (not Apriel2Config) for direct use with compose_configs. 
+ """ + return { + "model_type": "apriel2", + "hidden_size": 256, + "vocab_size": 1000, + "bos_token_id": 1, + "eos_token_id": 2, + "tie_word_embeddings": False, + "decoder": { + "type": "fixed", + "num_blocks": 2, + "block": { + "mixer": { + "type": "attention", + "heads": 8, + "head_groups": 4, + "head_size": 32, + "rotary": {"type": "mistral_1d", "theta": 10000.0}, + "query_layer": {"bias": {"enabled": True}}, + "key_layer": {"bias": {"enabled": True}}, + "value_layer": {"bias": {"enabled": True}}, + "dense_layer": {"bias": {"enabled": False}}, + }, + "mlp": {"type": "mlp", "intermediate_size": 512, "gated": True}, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + } + + +def make_weights_for_config(config: dict) -> dict: + """Create random weights matching a config's expected schema. + + This is a helper function (not a fixture) for creating test weights. + Use it in tests that need weights for plan execution. + + Args: + config: Complete Apriel2 config dict + + Returns: + Dict mapping weight key strings to torch tensors + """ + from fast_llm_external_models.apriel2.conversion import W + + hidden = config["hidden_size"] + vocab = config["vocab_size"] + decoder = config["decoder"] + num_blocks = decoder["num_blocks"] + block = decoder["block"] + mixer = block["mixer"] + mlp = block["mlp"] + + heads = mixer["heads"] + head_groups = mixer["head_groups"] + head_size = mixer["head_size"] + inter = mlp["intermediate_size"] + + # Check bias settings + has_q_bias = mixer.get("query_layer", {}).get("bias", {}).get("enabled", False) + has_k_bias = mixer.get("key_layer", {}).get("bias", {}).get("enabled", False) + has_v_bias = mixer.get("value_layer", {}).get("bias", {}).get("enabled", False) + + weights = {} + weights["model.embed_tokens.weight"] = torch.randn(vocab, hidden) + + for i in range(num_blocks): + p = f"model.decoder.blocks.{i}" + + # Attention + weights[f"{p}.mixer.q_proj.weight"] = torch.randn(heads * head_size, hidden) + weights[f"{p}.mixer.k_proj.weight"] = torch.randn(head_groups * head_size, hidden) + weights[f"{p}.mixer.v_proj.weight"] = torch.randn(head_groups * head_size, hidden) + weights[f"{p}.mixer.o_proj.weight"] = torch.randn(hidden, heads * head_size) + + if has_q_bias: + weights[f"{p}.mixer.q_proj.bias"] = torch.randn(heads * head_size) + if has_k_bias: + weights[f"{p}.mixer.k_proj.bias"] = torch.randn(head_groups * head_size) + if has_v_bias: + weights[f"{p}.mixer.v_proj.bias"] = torch.randn(head_groups * head_size) + + # MLP + weights[f"{p}.mlp.up_proj.weight"] = torch.randn(inter, hidden) + weights[f"{p}.mlp.gate_proj.weight"] = torch.randn(inter, hidden) + weights[f"{p}.mlp.down_proj.weight"] = torch.randn(hidden, inter) + + # Norms + weights[f"{p}.input_layernorm.weight"] = torch.randn(hidden) + weights[f"{p}.post_attention_layernorm.weight"] = torch.randn(hidden) + + weights["model.norm.weight"] = torch.randn(hidden) + weights["lm_head.weight"] = torch.randn(vocab, hidden) + + return {W(k): v for k, v in weights.items()} + + +# ============================================================================= +# Cache Test Fixtures - Tensor Dimensions +# ============================================================================= + + +@pytest.fixture +def batch_size(): + """Default batch size for cache tests.""" + return 2 + + +@pytest.fixture +def num_heads(): + """Default number of attention heads for cache tests.""" + return 4 + + +@pytest.fixture +def head_dim(): + """Default head dimension for cache tests.""" + return 16 + + +@pytest.fixture 
+def make_kv(batch_size, num_heads, head_dim): + """Factory fixture for creating KV tensors.""" + + def _make_kv(seq_len): + return ( + torch.randn(batch_size, num_heads, seq_len, head_dim), + torch.randn(batch_size, num_heads, seq_len, head_dim), + ) + + return _make_kv + + +# ============================================================================= +# Cache Test Fixtures - HuggingFace Cache Layers +# ============================================================================= + + +@pytest.fixture +def hf_dynamic_layer(): + """HuggingFace DynamicLayer for full attention contract testing.""" + from transformers.cache_utils import DynamicLayer + + return DynamicLayer() + + +@pytest.fixture +def hf_sliding_layer(window_size): + """HuggingFace DynamicSlidingWindowLayer for sliding window contract testing.""" + from transformers.cache_utils import DynamicSlidingWindowLayer + + return DynamicSlidingWindowLayer(sliding_window=window_size) + + +# ============================================================================= +# Cache Test Fixtures - Apriel2 Low-level Caches +# ============================================================================= + + +@pytest.fixture +def apriel_attention_cache(): + """Apriel2 attention cache without window (full attention).""" + return _AttentionCache(window=None) + + +@pytest.fixture +def apriel_sliding_cache(window_size): + """Apriel2 attention cache with sliding window.""" + return _AttentionCache(window=window_size) + + +@pytest.fixture +def ssm_cache(): + """Apriel2 SSM cache for Mamba/GDN/KDA layers.""" + return _SSMCache() + + +# ============================================================================= +# Cache Test Fixtures - Apriel2 Configs (Simple Versions) +# ============================================================================= + + +@pytest.fixture +def attention_config(): + """Pure attention config (2 layers, no sliding window).""" + from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config + + return Apriel2Config( + vocab_size=100, + hidden_size=64, + decoder={ + "type": "fixed", + "num_blocks": 2, + "block": { + "mixer": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16}, + "mlp": {"type": "mlp", "intermediate_size": 256}, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + ) + + +@pytest.fixture +def swa_config(): + """Sliding window attention config (2 layers, window=8).""" + from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config + + return Apriel2Config( + vocab_size=100, + hidden_size=64, + decoder={ + "type": "fixed", + "num_blocks": 2, + "block": { + "mixer": { + "type": "attention", + "heads": 4, + "head_groups": 2, + "head_size": 16, + "window_size": 8, + }, + "mlp": {"type": "mlp", "intermediate_size": 256}, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + ) + + +@pytest.fixture +def ssm_config(): + """Pure SSM config (2 layers).""" + from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config + + return Apriel2Config( + vocab_size=100, + hidden_size=64, + decoder={ + "type": "fixed", + "num_blocks": 2, + "block": { + "mixer": {"type": "mamba", "state_size": 16}, + "mlp": {"type": "mlp", "intermediate_size": 256}, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + ) + + +@pytest.fixture +def stochastic_config(): + """Stochastic mixer config with attention and mamba (2 layers).""" + from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config + + return 
Apriel2Config( + vocab_size=100, + hidden_size=64, + decoder={ + "type": "fixed", + "num_blocks": 2, + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16}, + "mamba": {"type": "mamba", "state_size": 16}, + }, + }, + "mlp": {"type": "mlp", "intermediate_size": 256}, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + ) + + +# Parameterized window size fixture (used by hf_sliding_layer and apriel_sliding_cache) +@pytest.fixture(params=[4, 8, 16, 32]) +def window_size(request): + """Parameterized window sizes for sliding window tests.""" + return request.param diff --git a/fast_llm_external_models/tests/test_apriel2/test_cache.py b/fast_llm_external_models/tests/test_apriel2/test_cache.py deleted file mode 100644 index ca8158b4f..000000000 --- a/fast_llm_external_models/tests/test_apriel2/test_cache.py +++ /dev/null @@ -1,1258 +0,0 @@ -"""Comprehensive tests for Apriel2Cache. - -Architecture Overview -===================== -Apriel2Cache manages state for autoregressive generation across different mixer types: - -1. **Attention Cache** (_AttentionCache): Stores key/value states - - Supports sliding window (window_size) for SWA - - Efficient roll optimization for single-token decode - -2. **SSM Cache** (_SSMCache): Stores conv and recurrent states - - Used by Mamba, GDN, KDA - - KDA uses tuple conv states (q, k, v), others use single tensor - -3. **Stochastic Mixer Routing**: For layers with multiple mixer options - - Each mixer has independent cache (no sharing) - - active_mixer pointer routes operations to correct sub-cache - - Switching mixers preserves each mixer's independent state - -Cache Invalidation Semantics -============================ -When switching between mixers in a stochastic layer: -- Each mixer maintains its OWN independent history -- Switching does NOT invalidate the previous mixer's cache -- Switching does NOT copy state between mixers -- To invalidate: call reset() explicitly - -This is intentional for training with stochastic sampling where each mixer -should learn from its own history. For inference, main_mixer_name is fixed. - -Test Organization -================= -1. CREATION & PROPERTIES - Cache initialization, config parsing -2. ATTENTION CACHE - Updates, sliding window, concatenation -3. SSM CACHE - Conv states, recurrent states, KDA tuples -4. STOCHASTIC ROUTING - Active mixer, isolation, switching -5. CACHE INVALIDATION - Reset, per-mixer reset, coherence -6. BEAM SEARCH - batch_repeat, reorder, select -7. HF INTEGRATION - get_mask_sizes, indexing, properties -8. GENERATION PATTERNS - Prefillโ†’decode, cropโ†’continue -9. 
ERROR HANDLING - Guards, bounds, invalid operations -""" - -import pytest -import torch - -from fast_llm_external_models.apriel2.cache import ( - Apriel2Cache, - _AttentionCache, - _SSMCache, -) - - -# ============================================================================= -# FIXTURES - Configs and Sample Data -# ============================================================================= - - -@pytest.fixture -def tiny_attention_config(): - """Minimal config with pure attention layers.""" - from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config - - return Apriel2Config( - vocab_size=100, - hidden_size=64, - decoder={ - "type": "fixed", - "num_blocks": 2, - "block": { - "mixer": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16}, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - ) - - -@pytest.fixture -def swa_config(): - """Config with sliding window attention.""" - from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config - - return Apriel2Config( - vocab_size=100, - hidden_size=64, - decoder={ - "type": "fixed", - "num_blocks": 2, - "block": { - "mixer": { - "type": "attention", - "heads": 4, - "head_groups": 2, - "head_size": 16, - "window_size": 8, # Small for testing - }, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - ) - - -@pytest.fixture -def ssm_config(): - """Config with pure SSM layers (mamba).""" - from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config - - return Apriel2Config( - vocab_size=100, - hidden_size=64, - decoder={ - "type": "fixed", - "num_blocks": 2, - "block": { - "mixer": { - "type": "mamba", - "d_inner": 128, - "d_state": 16, - "dt_rank": 4, - "d_conv": 4, - }, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - ) - - -@pytest.fixture -def kda_config(): - """Config with pure KDA layers.""" - from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config - - return Apriel2Config( - vocab_size=100, - hidden_size=64, - decoder={ - "type": "fixed", - "num_blocks": 2, - "block": { - "mixer": { - "type": "kda", - "heads": 4, - "head_dim": 16, - "convolution_layer": {"kernel_size": 4}, - "normalization": {"epsilon": 1e-5}, - }, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - ) - - -@pytest.fixture -def stochastic_config(): - """Config with stochastic mixer (attention + mamba).""" - from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config - - return Apriel2Config( - vocab_size=100, - hidden_size=64, - decoder={ - "type": "pattern", - "num_blocks": 2, - "pattern": ["attn", "stochastic"], - "blocks": { - "attn": { - "mixer": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16}, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - "stochastic": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": { - "attention": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16}, - "mamba": {"type": "mamba", "d_inner": 128, "d_state": 16, "dt_rank": 4, "d_conv": 4}, - }, - }, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - }, - ) - - -@pytest.fixture -def 
all_mixers_config(): - """Config with stochastic mixer containing all 5 mixer types.""" - from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config - - return Apriel2Config( - vocab_size=100, - hidden_size=64, - decoder={ - "type": "pattern", - "num_blocks": 2, - "pattern": ["attn", "all_mixers"], - "blocks": { - "attn": { - "mixer": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16}, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - "all_mixers": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": { - "attention": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16}, - "swa": { - "type": "attention", - "heads": 4, - "head_groups": 2, - "head_size": 16, - "window_size": 1024, - }, - "mamba": {"type": "mamba", "d_inner": 128, "d_state": 16, "dt_rank": 4, "d_conv": 4}, - "gdn": { - "type": "gdn", - "value_heads": 4, - "key_heads": 2, - "key_head_dim": 16, - "value_head_dim": 16, - "convolution_layer": {"kernel_size": 4}, - }, - "kda": { - "type": "kda", - "heads": 4, - "head_dim": 16, - "convolution_layer": {"kernel_size": 4}, - "normalization": {"epsilon": 1e-5}, - }, - }, - }, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - }, - ) - - -@pytest.fixture -def multi_window_config(): - """Config with multiple different window sizes.""" - from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config - - return Apriel2Config( - vocab_size=100, - hidden_size=64, - decoder={ - "type": "pattern", - "num_blocks": 3, - "pattern": ["full", "small_window", "large_window"], - "blocks": { - "full": { - "mixer": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16}, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - "small_window": { - "mixer": { - "type": "attention", - "heads": 4, - "head_groups": 2, - "head_size": 16, - "window_size": 512, - }, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - "large_window": { - "mixer": { - "type": "attention", - "heads": 4, - "head_groups": 2, - "head_size": 16, - "window_size": 2048, - }, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - }, - ) - - -@pytest.fixture -def sample_kv(): - """Sample key/value tensors: [batch=2, heads=4, seq=10, head_dim=16].""" - return torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16) - - -@pytest.fixture -def sample_conv_single(): - """Sample single-tensor conv state: [batch=2, d_inner=128, kernel=4].""" - return torch.randn(2, 128, 4) - - -@pytest.fixture -def sample_conv_tuple(): - """Sample tuple conv state for KDA: (q, k, v) each [batch=2, d=64, kernel=3].""" - return (torch.randn(2, 64, 3), torch.randn(2, 64, 3), torch.randn(2, 64, 3)) - - -@pytest.fixture -def sample_recurrent(): - """Sample recurrent state: [batch=2, heads=4, head_dim=16, d_state=16].""" - return torch.randn(2, 4, 16, 16) - - -# ============================================================================= -# SECTION 1: CACHE CREATION & PROPERTIES -# ============================================================================= - - -class TestCacheCreation: - """Test cache initialization from config.""" - - def test_attention_cache_creation(self, tiny_attention_config): - """Create cache for pure 
attention config.""" - cache = Apriel2Cache(tiny_attention_config) - - assert len(cache) == 2 - assert cache.mixer_types == ["attention", "attention"] - assert all(isinstance(l, _AttentionCache) for l in cache.layers) - - def test_ssm_cache_creation(self, ssm_config): - """Create cache for pure SSM config.""" - cache = Apriel2Cache(ssm_config) - - assert len(cache) == 2 - assert cache.mixer_types == ["mamba", "mamba"] - assert all(isinstance(l, _SSMCache) for l in cache.layers) - - def test_kda_cache_creation(self, kda_config): - """Create cache for pure KDA config.""" - cache = Apriel2Cache(kda_config) - - assert len(cache) == 2 - assert cache.mixer_types == ["kda", "kda"] - assert all(isinstance(l, _SSMCache) for l in cache.layers) - - def test_stochastic_cache_creation(self, stochastic_config): - """Create cache for stochastic mixer config.""" - cache = Apriel2Cache(stochastic_config) - - assert len(cache) == 2 - # Layer 0: pure attention, Layer 1: stochastic (dict) - assert isinstance(cache.layers[0], _AttentionCache) - assert isinstance(cache.layers[1], dict) - assert set(cache.layers[1].keys()) == {"attention", "mamba"} - - def test_swa_window_captured(self, swa_config): - """Verify sliding window size is captured.""" - cache = Apriel2Cache(swa_config) - - assert cache.layers[0].window == 8 - assert cache.is_sliding == [True, True] - - def test_active_mixers_initialized_none(self, stochastic_config): - """Verify active_mixers starts as None for all layers.""" - cache = Apriel2Cache(stochastic_config) - - assert cache.active_mixers == [None, None] - - -class TestCacheProperties: - """Test cache property accessors.""" - - def test_empty_cache_properties(self, tiny_attention_config): - """Test properties of uninitialized cache.""" - cache = Apriel2Cache(tiny_attention_config) - - assert cache.is_initialized == False - assert cache.has_previous_state == False - assert cache.max_batch_size is None - assert cache.max_cache_len is None - assert cache.is_compileable == False - - def test_is_initialized_attention(self, tiny_attention_config, sample_kv): - """is_initialized detects attention cache.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - - assert cache.is_initialized == True - - def test_is_initialized_ssm(self, ssm_config, sample_conv_single): - """is_initialized detects SSM cache.""" - cache = Apriel2Cache(ssm_config) - cache.conv_states[0] = sample_conv_single - - assert cache.is_initialized == True - - def test_has_previous_state_ssm_only(self, ssm_config, sample_conv_single): - """has_previous_state only looks at SSM conv states.""" - cache = Apriel2Cache(ssm_config) - - assert cache.has_previous_state == False - cache.conv_states[0] = sample_conv_single - assert cache.has_previous_state == True - - def test_has_previous_state_ignores_attention(self, tiny_attention_config, sample_kv): - """has_previous_state ignores attention cache.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - - # Attention cache is set, but has_previous_state only checks SSM - assert cache.has_previous_state == False - - def test_max_batch_size_from_attention(self, tiny_attention_config, sample_kv): - """max_batch_size from attention cache.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - - assert cache.max_batch_size == 2 - - def test_max_batch_size_from_ssm(self, ssm_config, sample_conv_single): - """max_batch_size from SSM cache.""" - cache = Apriel2Cache(ssm_config) - 
cache.conv_states[0] = sample_conv_single - - assert cache.max_batch_size == 2 - - def test_max_batch_size_from_kda_tuple(self, kda_config, sample_conv_tuple): - """max_batch_size from KDA tuple conv state.""" - cache = Apriel2Cache(kda_config) - cache.conv_states[0] = sample_conv_tuple - - assert cache.max_batch_size == 2 - - def test_max_cache_len_single_window(self, swa_config): - """max_cache_len with single window size.""" - cache = Apriel2Cache(swa_config) - assert cache.max_cache_len == 8 - - def test_max_cache_len_multiple_windows(self, multi_window_config): - """max_cache_len returns minimum window.""" - cache = Apriel2Cache(multi_window_config) - assert cache.max_cache_len == 512 # min(512, 2048) - - def test_max_cache_len_no_windows(self, tiny_attention_config): - """max_cache_len is None when no windows.""" - cache = Apriel2Cache(tiny_attention_config) - assert cache.max_cache_len is None - - def test_is_sliding_mixed(self, multi_window_config): - """is_sliding reflects per-layer window presence.""" - cache = Apriel2Cache(multi_window_config) - assert cache.is_sliding == [False, True, True] - - -# ============================================================================= -# SECTION 2: ATTENTION CACHE OPERATIONS -# ============================================================================= - - -class TestAttentionCacheBasics: - """Test basic attention cache operations.""" - - def test_update_stores_kv(self, tiny_attention_config, sample_kv): - """update() stores key/value states.""" - cache = Apriel2Cache(tiny_attention_config) - key, value = sample_kv - - k_out, v_out = cache.update(key, value, layer_idx=0) - - torch.testing.assert_close(k_out, key) - torch.testing.assert_close(v_out, value) - assert cache.get_seq_length(0) == 10 - - def test_update_concatenates(self, tiny_attention_config, sample_kv): - """Subsequent updates concatenate.""" - cache = Apriel2Cache(tiny_attention_config) - key, value = sample_kv - - cache.update(key, value, layer_idx=0) - k_out, v_out = cache.update(key, value, layer_idx=0) - - assert k_out.shape[-2] == 20 - assert cache.get_seq_length(0) == 20 - - def test_key_value_cache_accessors(self, tiny_attention_config, sample_kv): - """Test key_cache and value_cache accessors.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - - assert cache.key_cache[0] is not None - assert cache.value_cache[0] is not None - torch.testing.assert_close(cache.key_cache[0], sample_kv[0]) - - -class TestSlidingWindowAttention: - """Test sliding window attention behavior.""" - - def test_initial_within_window(self, swa_config): - """Initial sequence within window is kept.""" - cache = Apriel2Cache(swa_config) - key = torch.randn(2, 4, 5, 16) # seq=5 < window=8 - value = torch.randn(2, 4, 5, 16) - - cache.update(key, value, layer_idx=0) - - assert cache.get_seq_length(0) == 5 - - def test_initial_exceeds_window(self, swa_config): - """Initial sequence > window is truncated to last window tokens.""" - cache = Apriel2Cache(swa_config) - key = torch.arange(12).float().view(1, 1, 12, 1).expand(2, 4, 12, 16) - value = key.clone() - - k_out, v_out = cache.update(key, value, layer_idx=0) - - assert cache.get_seq_length(0) == 8 - # Should keep tokens 4-11 (last 8) - assert k_out[0, 0, 0, 0].item() == 4.0 - - def test_single_token_roll_path(self, swa_config): - """Single token decode with full window uses efficient roll.""" - cache = Apriel2Cache(swa_config) - - # Fill window exactly - key1 = torch.arange(8).float().view(1, 1, 8, 
1).expand(2, 4, 8, 16) - cache.update(key1, key1.clone(), layer_idx=0) - - # Decode single token - key2 = torch.full((2, 4, 1, 16), 8.0) - k_out, _ = cache.update(key2, key2.clone(), layer_idx=0) - - assert cache.get_seq_length(0) == 8 - assert k_out[0, 0, 0, 0].item() == 1.0 # Token 0 rolled out - assert k_out[0, 0, 7, 0].item() == 8.0 # New token at end - - def test_multi_token_cat_slice_path(self, swa_config): - """Multiple tokens use cat+slice path.""" - cache = Apriel2Cache(swa_config) - - # Fill window - key1 = torch.randn(2, 4, 8, 16) - cache.update(key1, key1.clone(), layer_idx=0) - - # Add 3 tokens - key2 = torch.randn(2, 4, 3, 16) - k_out, _ = cache.update(key2, key2.clone(), layer_idx=0) - - assert cache.get_seq_length(0) == 8 - torch.testing.assert_close(k_out[..., -3:, :], key2) - - def test_partial_then_fill_then_overflow(self, swa_config): - """Progressive filling: partial โ†’ full โ†’ overflow.""" - cache = Apriel2Cache(swa_config) - - cache.update(torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16), layer_idx=0) - assert cache.get_seq_length(0) == 5 - - cache.update(torch.randn(2, 4, 3, 16), torch.randn(2, 4, 3, 16), layer_idx=0) - assert cache.get_seq_length(0) == 8 - - cache.update(torch.randn(2, 4, 2, 16), torch.randn(2, 4, 2, 16), layer_idx=0) - assert cache.get_seq_length(0) == 8 - - def test_contiguous_output(self, swa_config): - """Outputs are contiguous after windowing.""" - cache = Apriel2Cache(swa_config) - - cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0) - cache.update(torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16), layer_idx=0) - - assert cache.layers[0].key.is_contiguous() - assert cache.layers[0].value.is_contiguous() - - -# ============================================================================= -# SECTION 3: SSM CACHE OPERATIONS -# ============================================================================= - - -class TestSSMCacheBasics: - """Test basic SSM cache operations.""" - - def test_conv_states_accessor(self, ssm_config, sample_conv_single): - """Test conv_states accessor.""" - cache = Apriel2Cache(ssm_config) - - cache.conv_states[0] = sample_conv_single - torch.testing.assert_close(cache.conv_states[0], sample_conv_single) - - def test_recurrent_states_accessor(self, ssm_config, sample_recurrent): - """Test recurrent_states accessor.""" - cache = Apriel2Cache(ssm_config) - - cache.recurrent_states[0] = sample_recurrent - torch.testing.assert_close(cache.recurrent_states[0], sample_recurrent) - - def test_ssm_seq_length_always_zero(self, ssm_config, sample_conv_single): - """get_seq_length returns 0 for SSM (no KV cache).""" - cache = Apriel2Cache(ssm_config) - cache.conv_states[0] = sample_conv_single - - assert cache.get_seq_length(0) == 0 - - -class TestKDACache: - """Test KDA-specific cache operations with tuple conv states.""" - - def test_tuple_conv_storage(self, kda_config, sample_conv_tuple): - """KDA stores tuple conv states.""" - cache = Apriel2Cache(kda_config) - - cache.conv_states[0] = sample_conv_tuple - - assert isinstance(cache.conv_states[0], tuple) - assert len(cache.conv_states[0]) == 3 - for i in range(3): - torch.testing.assert_close(cache.conv_states[0][i], sample_conv_tuple[i]) - - def test_tuple_with_recurrent(self, kda_config, sample_conv_tuple, sample_recurrent): - """KDA can have both tuple conv and recurrent states.""" - cache = Apriel2Cache(kda_config) - - cache.conv_states[0] = sample_conv_tuple - cache.recurrent_states[0] = sample_recurrent - - assert 
isinstance(cache.conv_states[0], tuple) - assert cache.recurrent_states[0] is not None - - def test_has_previous_state_detects_tuple(self, kda_config, sample_conv_tuple): - """has_previous_state works with tuple conv states.""" - cache = Apriel2Cache(kda_config) - - assert cache.has_previous_state == False - cache.conv_states[0] = sample_conv_tuple - assert cache.has_previous_state == True - - -# ============================================================================= -# SECTION 4: STOCHASTIC ROUTING -# ============================================================================= - - -class TestStochasticRouting: - """Test stochastic mixer cache routing.""" - - def test_set_active_mixer(self, stochastic_config): - """set_active_mixer sets the pointer.""" - cache = Apriel2Cache(stochastic_config) - - cache.set_active_mixer(1, "attention") - assert cache.active_mixers[1] == "attention" - - cache.set_active_mixer(1, "mamba") - assert cache.active_mixers[1] == "mamba" - - def test_operations_route_to_active(self, stochastic_config, sample_kv): - """Operations route to currently active mixer.""" - cache = Apriel2Cache(stochastic_config) - - cache.set_active_mixer(1, "attention") - cache.update(*sample_kv, layer_idx=1) - attn_len = cache.get_seq_length(1) - - cache.set_active_mixer(1, "mamba") - mamba_len = cache.get_seq_length(1) - - assert attn_len == 10 - assert mamba_len == 0 # Mamba cache is separate and empty - - def test_each_mixer_independent_cache(self, stochastic_config, sample_kv, sample_conv_single): - """Each mixer maintains independent cache.""" - cache = Apriel2Cache(stochastic_config) - - # Fill attention cache - cache.set_active_mixer(1, "attention") - cache.update(*sample_kv, layer_idx=1) - - # Fill mamba cache - cache.set_active_mixer(1, "mamba") - cache.conv_states[1] = sample_conv_single - - # Both preserved - cache.set_active_mixer(1, "attention") - assert cache.get_seq_length(1) == 10 - - cache.set_active_mixer(1, "mamba") - torch.testing.assert_close(cache.conv_states[1], sample_conv_single) - - -class TestMixerSwitching: - """Test behavior when switching between mixers mid-generation.""" - - def test_switch_preserves_previous_state(self, stochastic_config, sample_kv): - """Switching mixers preserves previous mixer's state.""" - cache = Apriel2Cache(stochastic_config) - - cache.set_active_mixer(1, "attention") - cache.update(*sample_kv, layer_idx=1) - original_key = cache.layers[1]["attention"].key.clone() - - # Switch to mamba, do something - cache.set_active_mixer(1, "mamba") - cache.conv_states[1] = torch.randn(2, 128, 4) - - # Switch back - attention unchanged - cache.set_active_mixer(1, "attention") - torch.testing.assert_close(cache.layers[1]["attention"].key, original_key) - - def test_switch_does_not_copy_state(self, stochastic_config, sample_kv): - """Switching does NOT copy state between mixers.""" - cache = Apriel2Cache(stochastic_config) - - # Fill attention with 10 tokens - cache.set_active_mixer(1, "attention") - cache.update(*sample_kv, layer_idx=1) - - # Switch to mamba - it has NO history from attention - cache.set_active_mixer(1, "mamba") - assert cache.conv_states[1] is None - assert cache.recurrent_states[1] is None - - def test_has_previous_state_checks_all_sub_caches(self, stochastic_config): - """has_previous_state checks ALL sub-caches, not just active.""" - cache = Apriel2Cache(stochastic_config) - - cache.set_active_mixer(1, "mamba") - cache.conv_states[1] = torch.randn(2, 128, 4) - - # Even if we switch away, has_previous_state still 
detects it - cache.set_active_mixer(1, "attention") - assert cache.has_previous_state == True - - -class TestAllMixerTypes: - """Test cache isolation across all 5 mixer types.""" - - def test_all_five_mixer_types_isolated(self, all_mixers_config): - """All 5 mixer types maintain isolated caches.""" - cache = Apriel2Cache(all_mixers_config) - layer_idx = 1 # Stochastic layer - - # Fill each mixer's cache - cache.set_active_mixer(layer_idx, "attention") - attn_kv = (torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16)) - cache.update(*attn_kv, layer_idx=layer_idx) - - cache.set_active_mixer(layer_idx, "swa") - swa_kv = (torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16)) - cache.update(*swa_kv, layer_idx=layer_idx) - - cache.set_active_mixer(layer_idx, "mamba") - mamba_conv = torch.randn(2, 128, 4) - cache.conv_states[layer_idx] = mamba_conv - - cache.set_active_mixer(layer_idx, "gdn") - gdn_conv = torch.randn(2, 64, 3) - cache.conv_states[layer_idx] = gdn_conv - - cache.set_active_mixer(layer_idx, "kda") - kda_conv = (torch.randn(2, 64, 3), torch.randn(2, 64, 3), torch.randn(2, 64, 3)) - cache.conv_states[layer_idx] = kda_conv - - # Verify all preserved - cache.set_active_mixer(layer_idx, "attention") - assert cache.get_seq_length(layer_idx) == 10 - - cache.set_active_mixer(layer_idx, "swa") - assert cache.get_seq_length(layer_idx) == 5 - - cache.set_active_mixer(layer_idx, "mamba") - torch.testing.assert_close(cache.conv_states[layer_idx], mamba_conv) - - cache.set_active_mixer(layer_idx, "gdn") - torch.testing.assert_close(cache.conv_states[layer_idx], gdn_conv) - - cache.set_active_mixer(layer_idx, "kda") - assert isinstance(cache.conv_states[layer_idx], tuple) - - -# ============================================================================= -# SECTION 5: CACHE INVALIDATION -# ============================================================================= - - -class TestCacheInvalidation: - """Test cache invalidation and reset semantics. - - Key principle: Each mixer maintains independent state. 
To invalidate: - - reset() clears ALL caches across ALL layers and mixers - - There is no per-mixer reset (by design - each mixer is independent) - """ - - def test_reset_clears_attention(self, tiny_attention_config, sample_kv): - """reset() clears attention cache.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - - cache.reset() - - assert cache.is_initialized == False - assert cache.get_seq_length(0) == 0 - - def test_reset_clears_ssm(self, ssm_config, sample_conv_single, sample_recurrent): - """reset() clears SSM cache.""" - cache = Apriel2Cache(ssm_config) - cache.conv_states[0] = sample_conv_single - cache.recurrent_states[0] = sample_recurrent - - cache.reset() - - assert cache.has_previous_state == False - assert cache.conv_states[0] is None - assert cache.recurrent_states[0] is None - - def test_reset_clears_kda_tuple(self, kda_config, sample_conv_tuple): - """reset() clears KDA tuple conv states.""" - cache = Apriel2Cache(kda_config) - cache.conv_states[0] = sample_conv_tuple - - cache.reset() - - assert cache.conv_states[0] is None - - def test_reset_clears_all_stochastic_mixers(self, all_mixers_config): - """reset() clears ALL mixer caches in stochastic layer.""" - cache = Apriel2Cache(all_mixers_config) - layer_idx = 1 - - # Fill all mixers - cache.set_active_mixer(layer_idx, "attention") - cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=layer_idx) - - cache.set_active_mixer(layer_idx, "mamba") - cache.conv_states[layer_idx] = torch.randn(2, 128, 4) - - cache.set_active_mixer(layer_idx, "kda") - cache.conv_states[layer_idx] = (torch.randn(2, 64, 3),) * 3 - - cache.reset() - - # All cleared - assert cache.layers[layer_idx]["attention"].key is None - assert cache.layers[layer_idx]["mamba"].conv is None - assert cache.layers[layer_idx]["kda"].conv is None - - def test_crop_truncates_attention(self, tiny_attention_config, sample_kv): - """crop() truncates attention cache to max_length.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - - cache.crop(5) - - assert cache.get_seq_length(0) == 5 - - def test_crop_affects_all_layers(self, tiny_attention_config, sample_kv): - """crop() affects all layers.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - cache.update(*sample_kv, layer_idx=1) - - cache.crop(3) - - assert cache.get_seq_length(0) == 3 - assert cache.get_seq_length(1) == 3 - - def test_crop_ignores_ssm(self, ssm_config, sample_conv_single): - """crop() only affects attention, not SSM.""" - cache = Apriel2Cache(ssm_config) - cache.conv_states[0] = sample_conv_single - - cache.crop(5) # Should not crash - - # Conv state unchanged - torch.testing.assert_close(cache.conv_states[0], sample_conv_single) - - -# ============================================================================= -# SECTION 6: BEAM SEARCH OPERATIONS -# ============================================================================= - - -class TestBatchRepeatInterleave: - """Test batch_repeat_interleave for beam search expansion.""" - - def test_repeat_attention(self, tiny_attention_config, sample_kv): - """Repeat attention cache for beam search.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - - cache.batch_repeat_interleave(3) - - assert cache.max_batch_size == 6 # 2 * 3 - - def test_repeat_ssm(self, ssm_config, sample_conv_single, sample_recurrent): - """Repeat SSM cache for beam search.""" - cache = 
Apriel2Cache(ssm_config) - cache.conv_states[0] = sample_conv_single - cache.recurrent_states[0] = sample_recurrent - - cache.batch_repeat_interleave(4) - - assert cache.conv_states[0].shape[0] == 8 # 2 * 4 - assert cache.recurrent_states[0].shape[0] == 8 - - def test_repeat_kda_tuple(self, kda_config, sample_conv_tuple): - """Repeat KDA tuple conv states.""" - cache = Apriel2Cache(kda_config) - cache.conv_states[0] = sample_conv_tuple - - cache.batch_repeat_interleave(3) - - for c in cache.conv_states[0]: - assert c.shape[0] == 6 - - def test_repeat_stochastic_all_mixers(self, all_mixers_config): - """Repeat all mixer caches in stochastic layer.""" - cache = Apriel2Cache(all_mixers_config) - layer_idx = 1 - - cache.set_active_mixer(layer_idx, "attention") - cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=layer_idx) - - cache.set_active_mixer(layer_idx, "mamba") - cache.conv_states[layer_idx] = torch.randn(2, 128, 4) - - cache.batch_repeat_interleave(2) - - cache.set_active_mixer(layer_idx, "attention") - assert cache.layers[layer_idx]["attention"].key.shape[0] == 4 - - cache.set_active_mixer(layer_idx, "mamba") - assert cache.conv_states[layer_idx].shape[0] == 4 - - def test_repeat_skips_none(self, tiny_attention_config): - """Repeat gracefully skips None caches.""" - cache = Apriel2Cache(tiny_attention_config) - # Don't fill anything - - cache.batch_repeat_interleave(3) # Should not crash - - assert cache.max_batch_size is None - - -class TestReorderCache: - """Test reorder_cache for beam search hypothesis selection.""" - - def test_reorder_attention(self, tiny_attention_config, sample_kv): - """Reorder attention cache.""" - cache = Apriel2Cache(tiny_attention_config) - key, value = sample_kv - # Make batches distinguishable - key = torch.arange(2).float().view(2, 1, 1, 1).expand(2, 4, 10, 16) - cache.update(key, key.clone(), layer_idx=0) - - beam_idx = torch.tensor([1, 0]) - cache.reorder_cache(beam_idx) - - assert cache.layers[0].key[0, 0, 0, 0].item() == 1.0 - assert cache.layers[0].key[1, 0, 0, 0].item() == 0.0 - - def test_reorder_ssm(self, ssm_config): - """Reorder SSM cache.""" - cache = Apriel2Cache(ssm_config) - conv = torch.arange(2).float().view(2, 1, 1).expand(2, 128, 4) - cache.conv_states[0] = conv.clone() - - beam_idx = torch.tensor([1, 0]) - cache.reorder_cache(beam_idx) - - assert cache.conv_states[0][0, 0, 0].item() == 1.0 - - def test_reorder_kda_tuple(self, kda_config): - """Reorder KDA tuple conv states.""" - cache = Apriel2Cache(kda_config) - conv_q = torch.arange(2).float().view(2, 1, 1).expand(2, 64, 3) - cache.conv_states[0] = (conv_q.clone(), conv_q.clone(), conv_q.clone()) - - beam_idx = torch.tensor([1, 0]) - cache.reorder_cache(beam_idx) - - for c in cache.conv_states[0]: - assert c[0, 0, 0].item() == 1.0 - - -class TestBatchSelectIndices: - """Test batch_select_indices for beam selection.""" - - def test_select_attention(self, tiny_attention_config, sample_kv): - """Select subset of attention cache.""" - cache = Apriel2Cache(tiny_attention_config) - key = torch.arange(4).float().view(4, 1, 1, 1).expand(4, 4, 10, 16) - cache.update(key, key.clone(), layer_idx=0) - - indices = torch.tensor([0, 3]) - cache.batch_select_indices(indices) - - assert cache.max_batch_size == 2 - assert cache.layers[0].key[0, 0, 0, 0].item() == 0.0 - assert cache.layers[0].key[1, 0, 0, 0].item() == 3.0 - - def test_select_kda_tuple(self, kda_config): - """Select subset of KDA tuple conv states.""" - cache = Apriel2Cache(kda_config) - conv = 
tuple(torch.arange(4).float().view(4, 1, 1).expand(4, 64, 3).clone() for _ in range(3)) - cache.conv_states[0] = conv - - indices = torch.tensor([1, 2]) - cache.batch_select_indices(indices) - - for c in cache.conv_states[0]: - assert c.shape[0] == 2 - assert c[0, 0, 0].item() == 1.0 - - -# ============================================================================= -# SECTION 7: HUGGINGFACE INTEGRATION -# ============================================================================= - - -class TestGetMaskSizes: - """Test get_mask_sizes() for attention mask computation.""" - - def test_empty_cache(self, tiny_attention_config): - """Mask sizes with empty cache.""" - cache = Apriel2Cache(tiny_attention_config) - cache_position = torch.arange(10) - - kv_length, kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0) - - assert kv_length == 10 - assert kv_offset == 0 - - def test_with_cached_tokens(self, tiny_attention_config, sample_kv): - """Mask sizes with cached tokens.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) # 10 tokens - - cache_position = torch.arange(5) - kv_length, kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0) - - assert kv_length == 15 # 10 + 5 - assert kv_offset == 10 - - def test_single_token_decode(self, tiny_attention_config, sample_kv): - """Mask sizes for single token decode.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - - cache_position = torch.arange(1) - kv_length, kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0) - - assert kv_length == 11 - assert kv_offset == 10 - - def test_ssm_returns_query_only(self, ssm_config, sample_conv_single): - """SSM layers return query_length (no KV cache).""" - cache = Apriel2Cache(ssm_config) - cache.conv_states[0] = sample_conv_single - - cache_position = torch.arange(5) - kv_length, kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0) - - assert kv_length == 5 - assert kv_offset == 0 - - -class TestCacheIndexing: - """Test cache[idx] indexing.""" - - def test_attention_returns_kv(self, tiny_attention_config, sample_kv): - """Indexing attention layer returns (key, value).""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - - result = cache[0] - - assert isinstance(result, tuple) - torch.testing.assert_close(result[0], sample_kv[0]) - - def test_empty_returns_empty_tensors(self, tiny_attention_config): - """Indexing empty layer returns empty tensors.""" - cache = Apriel2Cache(tiny_attention_config) - - result = cache[0] - - assert result[0].numel() == 0 - assert result[1].numel() == 0 - - def test_ssm_returns_empty(self, ssm_config, sample_conv_single): - """Indexing SSM layer returns empty (no KV).""" - cache = Apriel2Cache(ssm_config) - cache.conv_states[0] = sample_conv_single - - result = cache[0] - - assert result[0].numel() == 0 - - def test_stochastic_attention_returns_kv(self, stochastic_config, sample_kv): - """Indexing stochastic with attention active returns KV.""" - cache = Apriel2Cache(stochastic_config) - cache.set_active_mixer(1, "attention") - cache.update(*sample_kv, layer_idx=1) - - result = cache[1] - - torch.testing.assert_close(result[0], sample_kv[0]) - - -# ============================================================================= -# SECTION 8: GENERATION PATTERNS -# ============================================================================= - - -class TestGenerationPatterns: - """Test real-world generation patterns.""" - - def 
test_prefill_then_decode(self, tiny_attention_config, sample_kv): - """Prefill with long prompt, then decode token-by-token.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) # Prefill 10 tokens - - for _ in range(5): - new_kv = (torch.randn(2, 4, 1, 16), torch.randn(2, 4, 1, 16)) - cache.update(*new_kv, layer_idx=0) - - assert cache.get_seq_length(0) == 15 - - def test_crop_then_continue(self, tiny_attention_config, sample_kv): - """Crop old context, continue generation.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - cache.update(*sample_kv, layer_idx=0) # 20 tokens - - cache.crop(5) # Keep last 5 - cache.update(torch.randn(2, 4, 3, 16), torch.randn(2, 4, 3, 16), layer_idx=0) - - assert cache.get_seq_length(0) == 8 - - def test_reset_between_generations(self, tiny_attention_config, sample_kv): - """Reset between independent generations.""" - cache = Apriel2Cache(tiny_attention_config) - - # First generation - cache.update(*sample_kv, layer_idx=0) - assert cache.is_initialized == True - - # Reset - cache.reset() - assert cache.is_initialized == False - - # Second generation - cache.update(*sample_kv, layer_idx=0) - assert cache.get_seq_length(0) == 10 - - def test_multi_layer_consistency(self, tiny_attention_config, sample_kv): - """All layers updated consistently.""" - cache = Apriel2Cache(tiny_attention_config) - - for layer_idx in range(2): - cache.update(*sample_kv, layer_idx=layer_idx) - cache.update(torch.randn(2, 4, 1, 16), torch.randn(2, 4, 1, 16), layer_idx=layer_idx) - - for layer_idx in range(2): - assert cache.get_seq_length(layer_idx) == 11 - - -# ============================================================================= -# SECTION 9: ERROR HANDLING -# ============================================================================= - - -class TestErrorHandling: - """Test error conditions and guards.""" - - def test_stochastic_update_without_active_mixer(self, stochastic_config): - """update() on stochastic without active_mixer raises.""" - cache = Apriel2Cache(stochastic_config) - - with pytest.raises(RuntimeError, match="needs active_mixer set"): - cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=1) - - def test_stochastic_accessor_without_active_mixer(self, stochastic_config): - """Accessing stochastic cache without active_mixer raises.""" - cache = Apriel2Cache(stochastic_config) - - with pytest.raises(RuntimeError, match="requires set_active_mixer"): - _ = cache.conv_states[1] - - def test_accessor_error_lists_available_mixers(self, stochastic_config): - """Error message lists available mixers.""" - cache = Apriel2Cache(stochastic_config) - - with pytest.raises(RuntimeError, match="Available mixers:"): - _ = cache.key_cache[1] - - def test_invalid_mixer_name(self, stochastic_config): - """Invalid mixer name raises KeyError on access.""" - cache = Apriel2Cache(stochastic_config) - cache.set_active_mixer(1, "nonexistent") - - with pytest.raises(KeyError): - cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=1) - - def test_layer_idx_out_of_bounds(self, tiny_attention_config): - """Out-of-bounds layer_idx raises IndexError.""" - cache = Apriel2Cache(tiny_attention_config) - - with pytest.raises(IndexError): - cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=999) - - -# ============================================================================= -# SECTION 10: INTERNAL CLASSES -# 
============================================================================= - - -class TestAttentionCacheInternal: - """Test internal _AttentionCache class directly.""" - - def test_unbounded_growth(self): - """No window allows unbounded growth.""" - cache = _AttentionCache(window=None) - - for _ in range(10): - cache.update(torch.randn(2, 4, 100, 16), torch.randn(2, 4, 100, 16)) - - assert cache.key.shape[-2] == 1000 - - def test_window_enforced(self): - """Window caps cache size.""" - cache = _AttentionCache(window=50) - - for _ in range(10): - cache.update(torch.randn(2, 4, 100, 16), torch.randn(2, 4, 100, 16)) - - assert cache.key.shape[-2] == 50 - - -class TestSSMCacheInternal: - """Test internal _SSMCache class directly.""" - - def test_initial_none(self): - """Initial states are None.""" - cache = _SSMCache() - - assert cache.conv is None - assert cache.recurrent is None - - def test_stores_tuple(self): - """Can store tuple (for KDA).""" - cache = _SSMCache() - cache.conv = (torch.randn(2, 64, 3),) * 3 - - assert isinstance(cache.conv, tuple) diff --git a/fast_llm_external_models/tests/test_apriel2/test_cache_apriel2_specific.py b/fast_llm_external_models/tests/test_apriel2/test_cache_apriel2_specific.py new file mode 100644 index 000000000..b45779454 --- /dev/null +++ b/fast_llm_external_models/tests/test_apriel2/test_cache_apriel2_specific.py @@ -0,0 +1,341 @@ +"""Tests for Apriel2-specific cache behaviors with no HuggingFace equivalent. + +This module tests features unique to Apriel2Cache that cannot be validated +against upstream HF implementations: + +1. Stochastic mixer routing (switching between attention/SSM per layer) +2. Multi-mixer layer support +3. Error handling and guard rails +4. Beam search operations (batch_repeat, reorder, select) +5. 
Crop operation + +Fixtures used from conftest.py: + - stochastic_config: Stochastic mixer config with attention and mamba + - attention_config: Pure attention config + - ssm_config: Pure SSM config +""" + +import pytest +import torch + +from fast_llm_external_models.apriel2.cache import Apriel2Cache + +# ============================================================================= +# STOCHASTIC MIXER ROUTING +# ============================================================================= + + +class TestStochasticMixerRouting: + """Test routing operations to correct sub-cache in stochastic layers.""" + + def test_set_active_mixer(self, stochastic_config): + """set_active_mixer updates routing for layer.""" + cache = Apriel2Cache(stochastic_config) + + cache.set_active_mixer(0, "attention") + assert cache.active_mixers[0] == "attention" + + cache.set_active_mixer(0, "mamba") + assert cache.active_mixers[0] == "mamba" + + def test_update_routes_to_active_mixer(self, stochastic_config): + """update() stores in correct sub-cache based on active_mixer.""" + cache = Apriel2Cache(stochastic_config) + + # Route to attention + cache.set_active_mixer(0, "attention") + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0) + + # Attention sub-cache should have data + assert cache.layers[0]["attention"].key is not None + # Mamba sub-cache should be empty + assert cache.layers[0]["mamba"].conv is None + + def test_each_mixer_has_independent_cache(self, stochastic_config): + """Each mixer in a stochastic layer has its own independent state.""" + cache = Apriel2Cache(stochastic_config) + + # Store in attention + cache.set_active_mixer(0, "attention") + cache.update(torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16), layer_idx=0) + + # Switch to mamba and store + cache.set_active_mixer(0, "mamba") + cache.layers[0]["mamba"].conv = torch.randn(2, 64, 4) + + # Attention data should be unchanged + assert cache.layers[0]["attention"].cumulative_length == 5 + + def test_switching_preserves_all_states(self, stochastic_config): + """Switching active_mixer doesn't clear other mixer's state.""" + cache = Apriel2Cache(stochastic_config) + + # Build up attention state + cache.set_active_mixer(0, "attention") + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0) + attn_key = cache.layers[0]["attention"].key.clone() + + # Switch to mamba + cache.set_active_mixer(0, "mamba") + + # Attention state preserved + torch.testing.assert_close(cache.layers[0]["attention"].key, attn_key) + + +# ============================================================================= +# ERROR HANDLING +# ============================================================================= + + +class TestErrorHandling: + """Test guard rails and error messages.""" + + def test_update_without_active_mixer_raises(self, stochastic_config): + """update() on stochastic layer without active_mixer raises RuntimeError.""" + cache = Apriel2Cache(stochastic_config) + + with pytest.raises(RuntimeError, match="needs active_mixer set"): + cache.update(torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16), layer_idx=0) + + def test_accessor_without_active_mixer_raises(self, stochastic_config): + """Accessing key_cache/value_cache without active_mixer raises RuntimeError.""" + cache = Apriel2Cache(stochastic_config) + + with pytest.raises(RuntimeError, match="requires set_active_mixer"): + _ = cache.key_cache[0] + + def test_error_message_lists_available_mixers(self, stochastic_config): + """Error message includes 
list of available mixers.""" + cache = Apriel2Cache(stochastic_config) + + with pytest.raises(RuntimeError, match="attention.*mamba|mamba.*attention"): + _ = cache.key_cache[0] + + +# ============================================================================= +# BEAM SEARCH OPERATIONS +# ============================================================================= + + +class TestBeamSearchOperations: + """Test batch manipulation for beam search.""" + + def test_batch_repeat_interleave_attention(self, attention_config): + """batch_repeat_interleave expands batch dimension.""" + cache = Apriel2Cache(attention_config) + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0) + + cache.batch_repeat_interleave(3) + + assert cache.layers[0].key.shape[0] == 6 # 2 * 3 + + def test_batch_repeat_interleave_ssm(self, ssm_config): + """batch_repeat_interleave works for SSM caches.""" + cache = Apriel2Cache(ssm_config) + cache.layers[0].conv = torch.randn(2, 64, 4) + + cache.batch_repeat_interleave(3) + + assert cache.layers[0].conv.shape[0] == 6 + + def test_batch_repeat_interleave_kda_tuple(self, ssm_config): + """batch_repeat_interleave handles KDA tuple conv states.""" + cache = Apriel2Cache(ssm_config) + cache.layers[0].conv = (torch.randn(2, 64, 4),) * 3 + + cache.batch_repeat_interleave(3) + + assert cache.layers[0].conv[0].shape[0] == 6 + + def test_reorder_cache_attention(self, attention_config): + """reorder_cache reorders batch dimension.""" + cache = Apriel2Cache(attention_config) + k = torch.arange(4).float().view(4, 1, 1, 1).expand(4, 4, 10, 16) + cache.update(k, k.clone(), layer_idx=0) + + beam_idx = torch.tensor([3, 2, 1, 0]) + cache.reorder_cache(beam_idx) + + # Check reordering + assert cache.layers[0].key[0, 0, 0, 0].item() == 3.0 + assert cache.layers[0].key[3, 0, 0, 0].item() == 0.0 + + def test_batch_select_indices(self, attention_config): + """batch_select_indices selects subset of batch.""" + cache = Apriel2Cache(attention_config) + cache.update(torch.randn(4, 4, 10, 16), torch.randn(4, 4, 10, 16), layer_idx=0) + + indices = torch.tensor([0, 2]) + cache.batch_select_indices(indices) + + assert cache.layers[0].key.shape[0] == 2 + + def test_reorder_cache_ssm_tuple(self, ssm_config): + """reorder_cache handles KDA tuple conv states.""" + cache = Apriel2Cache(ssm_config) + # Create distinguishable tensors for each batch position + conv0 = torch.full((1, 64, 4), 0.0) + conv1 = torch.full((1, 64, 4), 1.0) + conv2 = torch.full((1, 64, 4), 2.0) + cache.layers[0].conv = ( + torch.cat([conv0, conv1, conv2], dim=0), + torch.cat([conv0, conv1, conv2], dim=0), + torch.cat([conv0, conv1, conv2], dim=0), + ) + + beam_idx = torch.tensor([2, 1, 0]) + cache.reorder_cache(beam_idx) + + # Check reordering: batch[0] should now have value 2.0 + assert cache.layers[0].conv[0][0, 0, 0].item() == 2.0 + assert cache.layers[0].conv[0][2, 0, 0].item() == 0.0 + + def test_batch_select_indices_ssm_tuple(self, ssm_config): + """batch_select_indices handles KDA tuple conv states.""" + cache = Apriel2Cache(ssm_config) + cache.layers[0].conv = (torch.randn(4, 64, 4),) * 3 + + indices = torch.tensor([0, 2]) + cache.batch_select_indices(indices) + + assert cache.layers[0].conv[0].shape[0] == 2 + assert cache.layers[0].conv[1].shape[0] == 2 + assert cache.layers[0].conv[2].shape[0] == 2 + + +# ============================================================================= +# CROP OPERATION +# ============================================================================= + + +class 
TestCropOperation: + """Test cache truncation.""" + + def test_crop_truncates_attention(self, attention_config): + """crop() truncates attention cache.""" + cache = Apriel2Cache(attention_config) + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0) + + cache.crop(5) + + assert cache.layers[0].key.shape[-2] == 5 + assert cache.get_seq_length(0) == 5 + + def test_crop_affects_all_layers(self, attention_config): + """crop() affects all layers.""" + cache = Apriel2Cache(attention_config) + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0) + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=1) + + cache.crop(3) + + assert cache.layers[0].key.shape[-2] == 3 + assert cache.layers[1].key.shape[-2] == 3 + + def test_crop_ignores_ssm(self, ssm_config): + """crop() doesn't affect SSM caches (they don't have seq dimension).""" + cache = Apriel2Cache(ssm_config) + cache.layers[0].conv = torch.randn(2, 64, 4) + + # Should not raise + cache.crop(5) + + # SSM state unchanged + assert cache.layers[0].conv.shape == (2, 64, 4) + + +# ============================================================================= +# CACHE PROPERTIES +# ============================================================================= + + +class TestCacheProperties: + """Test cache property methods.""" + + def test_is_initialized_attention(self, attention_config): + """is_initialized True after update.""" + cache = Apriel2Cache(attention_config) + assert not cache.is_initialized + + cache.update(torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16), layer_idx=0) + assert cache.is_initialized + + def test_is_initialized_ssm(self, ssm_config): + """is_initialized True after setting conv state.""" + cache = Apriel2Cache(ssm_config) + assert not cache.is_initialized + + cache.layers[0].conv = torch.randn(2, 64, 4) + assert cache.is_initialized + + def test_has_previous_state_ssm_only(self, ssm_config): + """has_previous_state checks SSM conv states.""" + cache = Apriel2Cache(ssm_config) + assert not cache.has_previous_state + + cache.layers[0].conv = torch.randn(2, 64, 4) + assert cache.has_previous_state + + def test_has_previous_state_ignores_attention(self, attention_config): + """has_previous_state ignores attention caches.""" + cache = Apriel2Cache(attention_config) + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0) + + # Attention-only cache returns False for has_previous_state + assert not cache.has_previous_state + + def test_reset_clears_ssm_states(self, ssm_config): + """reset() clears SSM conv and recurrent states.""" + cache = Apriel2Cache(ssm_config) + cache.layers[0].conv = torch.randn(2, 64, 4) + cache.layers[0].recurrent = torch.randn(2, 64, 16) + + cache.reset() + + assert cache.layers[0].conv is None + assert cache.layers[0].recurrent is None + + def test_max_batch_size_from_ssm_tuple(self, ssm_config): + """max_batch_size works with KDA tuple conv states.""" + cache = Apriel2Cache(ssm_config) + cache.layers[0].conv = (torch.randn(3, 64, 4),) * 3 + + assert cache.max_batch_size == 3 + + def test_max_batch_size(self, attention_config): + """max_batch_size returns batch dimension.""" + cache = Apriel2Cache(attention_config) + cache.update(torch.randn(3, 4, 10, 16), torch.randn(3, 4, 10, 16), layer_idx=0) + + assert cache.max_batch_size == 3 + + def test_len_returns_num_layers(self, attention_config): + """__len__ returns number of layers.""" + cache = Apriel2Cache(attention_config) + assert len(cache) == 2 + + 
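+# =============================================================================
+# GENERATION USAGE SKETCH
+# =============================================================================
+
+
+class TestStochasticGenerationUsage:
+    """Illustrative prefill -> decode flow on a stochastic layer.
+
+    This is a usage sketch rather than an additional contract: it only chains
+    operations already exercised above (set_active_mixer, update, sub-cache
+    inspection) to show how a generation loop is expected to drive the cache.
+    """
+
+    def test_prefill_then_decode_on_stochastic_layer(self, stochastic_config):
+        """Prefill a 10-token prompt, then decode 3 tokens one at a time."""
+        cache = Apriel2Cache(stochastic_config)
+
+        # Route layer 0 to its attention sub-mixer before the first update.
+        cache.set_active_mixer(0, "attention")
+
+        # Prefill: [batch=2, heads=4, seq=10, head_dim=16], as elsewhere in this file.
+        cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0)
+        assert cache.layers[0]["attention"].cumulative_length == 10
+
+        # Decode: one token per step; the attention sub-cache concatenates.
+        for _ in range(3):
+            cache.update(torch.randn(2, 4, 1, 16), torch.randn(2, 4, 1, 16), layer_idx=0)
+
+        assert cache.layers[0]["attention"].cumulative_length == 13
+        # The unused mamba sub-mixer keeps its own (still empty) state.
+        assert cache.layers[0]["mamba"].conv is None
+
+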
+# ============================================================================= +# INDEXING +# ============================================================================= + + +class TestCacheIndexing: + """Test __getitem__ for HF compatibility.""" + + def test_getitem_returns_kv_tuple(self, attention_config): + """cache[idx] returns (key, value) tuple.""" + cache = Apriel2Cache(attention_config) + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0) + + k, v = cache[0] + assert k.shape == (2, 4, 10, 16) + assert v.shape == (2, 4, 10, 16) + + def test_getitem_empty_returns_empty_tensors(self, attention_config): + """cache[idx] on empty cache returns empty tensors.""" + cache = Apriel2Cache(attention_config) + + k, v = cache[0] + assert k.numel() == 0 + assert v.numel() == 0 diff --git a/fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py b/fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py new file mode 100644 index 000000000..8ceabfb91 --- /dev/null +++ b/fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py @@ -0,0 +1,591 @@ +"""Contract tests for Apriel2Cache against HuggingFace cache implementations. + +This module tests that Apriel2Cache components behave equivalently to their +HuggingFace counterparts. This ensures compatibility with HF's generation +infrastructure (mask creation, beam search, etc.). + +Mapping: + Apriel2 Component HuggingFace Equivalent + ----------------- ---------------------- + _AttentionCache (no window) -> DynamicLayer + _AttentionCache (window) -> DynamicSlidingWindowLayer + _SSMCache -> MambaCache (different interface, same concept) + +Apriel2-specific features (stochastic routing, multi-mixer layers) are tested +separately in test_cache_apriel2_specific.py since they have no HF equivalent. + +Fixtures used from conftest.py: + - batch_size, num_heads, head_dim: Tensor dimensions + - hf_dynamic_layer: HuggingFace DynamicLayer + - hf_sliding_layer: HuggingFace DynamicSlidingWindowLayer (parameterized by window_size) + - apriel_attention_cache: Apriel2 _AttentionCache (no window) + - apriel_sliding_cache: Apriel2 _AttentionCache (with window, parameterized) + - window_size: Parameterized window sizes [4, 8, 16, 32] + - attention_config, swa_config: Apriel2 configs +""" + +import pytest +import torch + +from fast_llm_external_models.apriel2.cache import Apriel2Cache, _AttentionCache + +# ============================================================================= +# SECTION 1: FULL ATTENTION - _AttentionCache vs DynamicLayer +# ============================================================================= + + +class TestFullAttentionContract: + """Test _AttentionCache (no window) matches HuggingFace DynamicLayer. + + DynamicLayer is the standard cache for full causal attention. + We test that our cache produces identical mask parameters. 
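+
+    In short, the contract checked here for full (windowless) attention is:
+
+        get_seq_length()               == total number of tokens seen so far
+        get_mask_sizes(cache_position) -> (get_seq_length() + query_length, 0)
+
+    i.e. kv_length keeps growing with the cache and kv_offset stays 0.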
+ """ + + # ------------------------------------------------------------------------- + # get_seq_length: Must match exactly for generation to work + # ------------------------------------------------------------------------- + + @pytest.mark.parametrize("seq_len", [1, 5, 10, 50, 100]) + def test_get_seq_length_after_prefill( + self, hf_dynamic_layer, apriel_attention_cache, batch_size, num_heads, head_dim, seq_len + ): + """After prefill, cumulative_length matches HF get_seq_length.""" + key = torch.randn(batch_size, num_heads, seq_len, head_dim) + value = torch.randn(batch_size, num_heads, seq_len, head_dim) + + hf_dynamic_layer.update(key.clone(), value.clone()) + apriel_attention_cache.update(key.clone(), value.clone()) + + assert apriel_attention_cache.cumulative_length == hf_dynamic_layer.get_seq_length() + + @pytest.mark.parametrize("prefill_len", [1, 5, 10]) + @pytest.mark.parametrize("decode_steps", [1, 5, 10, 20]) + def test_get_seq_length_during_decode( + self, hf_dynamic_layer, apriel_attention_cache, batch_size, num_heads, head_dim, prefill_len, decode_steps + ): + """During decode, cumulative_length tracks total tokens seen.""" + # Prefill + key = torch.randn(batch_size, num_heads, prefill_len, head_dim) + value = torch.randn(batch_size, num_heads, prefill_len, head_dim) + hf_dynamic_layer.update(key.clone(), value.clone()) + apriel_attention_cache.update(key.clone(), value.clone()) + + # Decode + for step in range(decode_steps): + key = torch.randn(batch_size, num_heads, 1, head_dim) + value = torch.randn(batch_size, num_heads, 1, head_dim) + hf_dynamic_layer.update(key.clone(), value.clone()) + apriel_attention_cache.update(key.clone(), value.clone()) + + assert ( + apriel_attention_cache.cumulative_length == hf_dynamic_layer.get_seq_length() + ), f"Mismatch at decode step {step}" + + # ------------------------------------------------------------------------- + # get_mask_sizes: Verify HF behavior for documentation + # ------------------------------------------------------------------------- + + @pytest.mark.parametrize("prefill_len", [1, 5, 10]) + @pytest.mark.parametrize("decode_steps", [0, 1, 5, 10]) + def test_hf_mask_sizes_kv_length( + self, hf_dynamic_layer, apriel_attention_cache, batch_size, num_heads, head_dim, prefill_len, decode_steps + ): + """Document HF's kv_length behavior and verify cumulative_length tracks correctly. + + For full attention, kv_length = cumulative_length + query_length. + This test verifies our cache tracks tokens identically to HF. 
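+
+        Worked example (numbers chosen for illustration): after a 5-token
+        prefill and 3 single-token decode steps, cumulative_length == 8, so a
+        1-token query gives kv_length == 8 + 1 == 9 (and kv_offset == 0, see
+        the next test).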
+ """ + # Prefill + key = torch.randn(batch_size, num_heads, prefill_len, head_dim) + value = torch.randn(batch_size, num_heads, prefill_len, head_dim) + hf_dynamic_layer.update(key.clone(), value.clone()) + apriel_attention_cache.update(key.clone(), value.clone()) + + # Decode + for _ in range(decode_steps): + key = torch.randn(batch_size, num_heads, 1, head_dim) + value = torch.randn(batch_size, num_heads, 1, head_dim) + hf_dynamic_layer.update(key.clone(), value.clone()) + apriel_attention_cache.update(key.clone(), value.clone()) + + # Verify cumulative_length matches HF + assert apriel_attention_cache.cumulative_length == hf_dynamic_layer.get_seq_length() + + # Verify HF's kv_length follows the expected formula + cache_position = torch.arange(1) # Single token decode + hf_kv_len, hf_kv_offset = hf_dynamic_layer.get_mask_sizes(cache_position) + expected_kv_len = hf_dynamic_layer.get_seq_length() + cache_position.shape[0] + assert hf_kv_len == expected_kv_len + + def test_hf_kv_offset_always_zero(self, hf_dynamic_layer, batch_size, num_heads, head_dim): + """Document that HF DynamicLayer always returns kv_offset=0. + + For full attention, all cached KV pairs map to absolute positions + starting from 0, so kv_offset is always 0. + """ + # Add many tokens + for _ in range(20): + key = torch.randn(batch_size, num_heads, 5, head_dim) + value = torch.randn(batch_size, num_heads, 5, head_dim) + hf_dynamic_layer.update(key.clone(), value.clone()) + + cache_position = torch.arange(1) + _, hf_kv_offset = hf_dynamic_layer.get_mask_sizes(cache_position) + + assert hf_kv_offset == 0, "DynamicLayer always returns kv_offset=0" + + # ------------------------------------------------------------------------- + # update: Output shape and values must match + # ------------------------------------------------------------------------- + + @pytest.mark.parametrize("seq_len", [1, 5, 10]) + def test_update_returns_same_shape( + self, hf_dynamic_layer, apriel_attention_cache, batch_size, num_heads, head_dim, seq_len + ): + """update() returns tensors with matching shapes.""" + key = torch.randn(batch_size, num_heads, seq_len, head_dim) + value = torch.randn(batch_size, num_heads, seq_len, head_dim) + + hf_k, hf_v = hf_dynamic_layer.update(key.clone(), value.clone()) + apr_k, apr_v = apriel_attention_cache.update(key.clone(), value.clone()) + + assert hf_k.shape == apr_k.shape + assert hf_v.shape == apr_v.shape + + def test_update_concatenates_identically( + self, hf_dynamic_layer, apriel_attention_cache, batch_size, num_heads, head_dim + ): + """Multiple updates produce identical concatenated states.""" + # Use deterministic values for comparison + k1 = torch.arange(10).float().view(1, 1, 10, 1).expand(batch_size, num_heads, 10, head_dim) + v1 = k1.clone() + + hf_dynamic_layer.update(k1.clone(), v1.clone()) + apriel_attention_cache.update(k1.clone(), v1.clone()) + + k2 = torch.arange(10, 15).float().view(1, 1, 5, 1).expand(batch_size, num_heads, 5, head_dim) + v2 = k2.clone() + + hf_k, hf_v = hf_dynamic_layer.update(k2.clone(), v2.clone()) + apr_k, apr_v = apriel_attention_cache.update(k2.clone(), v2.clone()) + + torch.testing.assert_close(hf_k, apr_k) + torch.testing.assert_close(hf_v, apr_v) + + +# ============================================================================= +# SECTION 2: SLIDING WINDOW - _AttentionCache vs DynamicSlidingWindowLayer +# ============================================================================= + + +class TestSlidingWindowContract: + """Test _AttentionCache (with window) 
matches HuggingFace DynamicSlidingWindowLayer. + + DynamicSlidingWindowLayer is used for sliding window attention (e.g., Mistral). + Critical behaviors: + - cumulative_length tracks ALL tokens seen (not just cached) + - kv_offset increases once window is exceeded + - kv_length is capped at window size + + Uses fixtures from conftest.py: + - window_size: parameterized [4, 8, 16, 32] + - hf_sliding_layer: DynamicSlidingWindowLayer + - apriel_sliding_cache: _AttentionCache with window + """ + + # ------------------------------------------------------------------------- + # cumulative_length: Must track total tokens, not cached tokens + # ------------------------------------------------------------------------- + + @pytest.mark.parametrize("prefill_len", [1, 3, 5, 10, 20]) + def test_cumulative_length_matches_after_prefill( + self, hf_sliding_layer, apriel_sliding_cache, batch_size, num_heads, head_dim, prefill_len + ): + """cumulative_length matches HF get_seq_length after prefill.""" + key = torch.randn(batch_size, num_heads, prefill_len, head_dim) + value = torch.randn(batch_size, num_heads, prefill_len, head_dim) + + hf_sliding_layer.update(key.clone(), value.clone()) + apriel_sliding_cache.update(key.clone(), value.clone()) + + assert apriel_sliding_cache.cumulative_length == hf_sliding_layer.get_seq_length() + + def test_cumulative_length_continues_past_window( + self, hf_sliding_layer, apriel_sliding_cache, window_size, batch_size, num_heads, head_dim + ): + """cumulative_length keeps growing even after window is full.""" + total_tokens = window_size * 3 # Way past window + + for i in range(total_tokens): + key = torch.randn(batch_size, num_heads, 1, head_dim) + value = torch.randn(batch_size, num_heads, 1, head_dim) + + hf_sliding_layer.update(key.clone(), value.clone()) + apriel_sliding_cache.update(key.clone(), value.clone()) + + expected = i + 1 + assert apriel_sliding_cache.cumulative_length == expected + assert hf_sliding_layer.get_seq_length() == expected + + # ------------------------------------------------------------------------- + # get_mask_sizes: kv_offset must increase once window is exceeded + # ------------------------------------------------------------------------- + + def test_kv_offset_zero_before_window_full( + self, hf_sliding_layer, apriel_sliding_cache, window_size, batch_size, num_heads, head_dim + ): + """kv_offset is 0 while cumulative < window. + + Before the window is full, kv_offset should be 0 because all cached tokens + correspond to absolute positions starting from 0. + """ + # Add tokens up to window-1 + for i in range(window_size - 1): + key = torch.randn(batch_size, num_heads, 1, head_dim) + value = torch.randn(batch_size, num_heads, 1, head_dim) + + hf_sliding_layer.update(key.clone(), value.clone()) + apriel_sliding_cache.update(key.clone(), value.clone()) + + cache_position = torch.arange(1) + hf_kv_len, hf_kv_offset = hf_sliding_layer.get_mask_sizes(cache_position) + + # Verify HF returns 0 offset before window full + assert hf_kv_offset == 0, f"HF offset should be 0 at step {i}" + # Verify Apriel cache tracks cumulative correctly + assert apriel_sliding_cache.cumulative_length == i + 1 + + def test_kv_offset_increases_after_window_full( + self, hf_sliding_layer, apriel_sliding_cache, window_size, batch_size, num_heads, head_dim + ): + """kv_offset increases once cumulative >= window. + + Once the window is full, the cache discards oldest tokens. kv_offset tracks + which absolute position KV[0] corresponds to. 
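+
+        Worked example (window_size == 4 assumed for illustration): after 6
+        tokens have been seen, KV[0] corresponds to absolute position 3, i.e.
+        kv_offset == 6 - 4 + 1 == 3. This matches the assertions below:
+        offset == 1 exactly at the window boundary, then +1 per decoded token.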
+ """ + # Fill to exactly window + for _ in range(window_size): + key = torch.randn(batch_size, num_heads, 1, head_dim) + value = torch.randn(batch_size, num_heads, 1, head_dim) + hf_sliding_layer.update(key.clone(), value.clone()) + apriel_sliding_cache.update(key.clone(), value.clone()) + + cache_position = torch.arange(1) + hf_kv_len, hf_kv_offset = hf_sliding_layer.get_mask_sizes(cache_position) + + # At window boundary, offset should be 1 + assert hf_kv_offset == 1, "HF offset should be 1 at window boundary" + assert apriel_sliding_cache.cumulative_length == window_size + + # Add more tokens and verify offset keeps increasing with HF + for i in range(5): + key = torch.randn(batch_size, num_heads, 1, head_dim) + value = torch.randn(batch_size, num_heads, 1, head_dim) + hf_sliding_layer.update(key.clone(), value.clone()) + apriel_sliding_cache.update(key.clone(), value.clone()) + + hf_kv_len, hf_kv_offset = hf_sliding_layer.get_mask_sizes(cache_position) + + expected_offset = i + 2 + assert hf_kv_offset == expected_offset + assert apriel_sliding_cache.cumulative_length == window_size + i + 1 + + def test_kv_length_capped_at_window( + self, hf_sliding_layer, apriel_sliding_cache, window_size, batch_size, num_heads, head_dim + ): + """kv_length is capped at window size once exceeded. + + For a query of length 1 after the window is full, kv_length = window + (window-1 cached tokens + 1 query token). + """ + # Way past window + for _ in range(window_size * 2): + key = torch.randn(batch_size, num_heads, 1, head_dim) + value = torch.randn(batch_size, num_heads, 1, head_dim) + hf_sliding_layer.update(key.clone(), value.clone()) + apriel_sliding_cache.update(key.clone(), value.clone()) + + cache_position = torch.arange(1) + hf_kv_len, _ = hf_sliding_layer.get_mask_sizes(cache_position) + + # HF returns window (window-1 cached + 1 query) + assert hf_kv_len == window_size + # Verify our cache tracked cumulative correctly + assert apriel_sliding_cache.cumulative_length == window_size * 2 + + # ------------------------------------------------------------------------- + # Full sequence length tracking through generation + # ------------------------------------------------------------------------- + + @pytest.mark.parametrize("prefill_len", [1, 3, 5, 10, 20]) + def test_cumulative_length_tracks_all_tokens( + self, hf_sliding_layer, apriel_sliding_cache, window_size, batch_size, num_heads, head_dim, prefill_len + ): + """cumulative_length tracks total tokens seen through prefill + decode. + + This is the foundation for correct mask size computation. We verify that + our _AttentionCache tracks tokens identically to HuggingFace's DynamicSlidingWindowLayer. + The actual get_mask_sizes computation is tested in TestApriel2CacheIntegration. 
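+
+        For example (illustrative numbers): with prefill_len == 5 and
+        window_size == 8, the loop below decodes 18 more tokens, so both
+        caches must report a cumulative length of 23 even though only the
+        most recent window of tokens stays cached.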
+ """ + # Prefill + key = torch.randn(batch_size, num_heads, prefill_len, head_dim) + value = torch.randn(batch_size, num_heads, prefill_len, head_dim) + hf_sliding_layer.update(key.clone(), value.clone()) + apriel_sliding_cache.update(key.clone(), value.clone()) + + assert apriel_sliding_cache.cumulative_length == hf_sliding_layer.get_seq_length() + + # Decode past window + for i in range(window_size + 10): + key = torch.randn(batch_size, num_heads, 1, head_dim) + value = torch.randn(batch_size, num_heads, 1, head_dim) + hf_sliding_layer.update(key.clone(), value.clone()) + apriel_sliding_cache.update(key.clone(), value.clone()) + + assert ( + apriel_sliding_cache.cumulative_length == hf_sliding_layer.get_seq_length() + ), f"cumulative_length mismatch at step {i}" + + +# ============================================================================= +# SECTION 3: SSM CACHE - _SSMCache vs MambaCache concept +# ============================================================================= + + +class TestSSMCacheContract: + """Document _SSMCache interface and verify basic contract. + + Unlike attention caches which have HF equivalents (DynamicLayer, DynamicSlidingWindowLayer), + SSM caches have no direct HF counterpart with matching interface. HF's MambaCache uses + different methods (update_conv_state, update_ssm_state), so we can't do direct comparison. + + These tests document the interface contract: + 1. `conv` and `recurrent` attributes for storing states + 2. Both support None (lazy initialization) + 3. `conv` can be tuple (for KDA which has separate q/k/v conv states) + + Higher-level operations (reorder, batch_repeat, reset) are tested in + TestBeamSearchOperations in test_cache_apriel2_specific.py. + """ + + def test_conv_state_storage(self, ssm_cache): + """conv attribute stores conv states (batch, intermediate, kernel_size).""" + conv = torch.randn(2, 64, 4) + ssm_cache.conv = conv + torch.testing.assert_close(ssm_cache.conv, conv) + + def test_recurrent_state_storage(self, ssm_cache): + """recurrent attribute stores SSM states (batch, intermediate, state_size).""" + recurrent = torch.randn(2, 64, 16) + ssm_cache.recurrent = recurrent + torch.testing.assert_close(ssm_cache.recurrent, recurrent) + + def test_conv_state_tuple_for_kda(self, ssm_cache): + """conv can be tuple for KDA's separate q/k/v convolutions.""" + conv_tuple = (torch.randn(2, 64, 4), torch.randn(2, 64, 4), torch.randn(2, 64, 4)) + ssm_cache.conv = conv_tuple + assert isinstance(ssm_cache.conv, tuple) + assert len(ssm_cache.conv) == 3 + + def test_initial_states_none(self, ssm_cache): + """States are None initially (lazy initialization pattern).""" + assert ssm_cache.conv is None + assert ssm_cache.recurrent is None + + def test_states_independent(self, ssm_cache): + """conv and recurrent states are independent.""" + ssm_cache.conv = torch.randn(2, 64, 4) + assert ssm_cache.recurrent is None # recurrent unchanged + + ssm_cache.recurrent = torch.randn(2, 64, 16) + assert ssm_cache.conv is not None # conv unchanged + + +# ============================================================================= +# SECTION 4: APRIEL2CACHE INTEGRATION +# ============================================================================= + + +class TestApriel2CacheIntegration: + """Test Apriel2Cache correctly delegates to underlying caches. 
+ + Uses fixtures from conftest.py: + - attention_config: Pure attention config + - swa_config: Sliding window attention config (window=8) + """ + + def test_get_seq_length_matches_dynamic_layer(self, attention_config): + """Apriel2Cache.get_seq_length matches DynamicLayer for full attention.""" + from transformers.cache_utils import DynamicLayer + + cache = Apriel2Cache(attention_config) + hf_layer = DynamicLayer() + + key = torch.randn(2, 4, 10, 16) + value = torch.randn(2, 4, 10, 16) + + cache.update(key.clone(), value.clone(), layer_idx=0) + hf_layer.update(key.clone(), value.clone()) + + assert cache.get_seq_length(0) == hf_layer.get_seq_length() + + def test_get_mask_sizes_matches_dynamic_layer(self, attention_config): + """Apriel2Cache.get_mask_sizes matches DynamicLayer.""" + from transformers.cache_utils import DynamicLayer + + cache = Apriel2Cache(attention_config) + hf_layer = DynamicLayer() + + key = torch.randn(2, 4, 10, 16) + value = torch.randn(2, 4, 10, 16) + + cache.update(key.clone(), value.clone(), layer_idx=0) + hf_layer.update(key.clone(), value.clone()) + + cache_position = torch.arange(1) + hf_kv_len, hf_kv_offset = hf_layer.get_mask_sizes(cache_position) + apr_kv_len, apr_kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0) + + assert apr_kv_len == hf_kv_len + assert apr_kv_offset == hf_kv_offset + + def test_get_mask_sizes_matches_sliding_layer(self, swa_config): + """Apriel2Cache.get_mask_sizes matches DynamicSlidingWindowLayer.""" + from transformers.cache_utils import DynamicSlidingWindowLayer + + cache = Apriel2Cache(swa_config) + hf_layer = DynamicSlidingWindowLayer(sliding_window=8) + + # Fill past window + for _ in range(15): + key = torch.randn(2, 4, 1, 16) + value = torch.randn(2, 4, 1, 16) + cache.update(key.clone(), value.clone(), layer_idx=0) + hf_layer.update(key.clone(), value.clone()) + + cache_position = torch.arange(1) + hf_kv_len, hf_kv_offset = hf_layer.get_mask_sizes(cache_position) + apr_kv_len, apr_kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0) + + assert apr_kv_len == hf_kv_len + assert apr_kv_offset == hf_kv_offset + + def test_reset_clears_cumulative_length(self, attention_config): + """reset() clears cumulative_length (matches DynamicLayer.reset).""" + cache = Apriel2Cache(attention_config) + + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0) + assert cache.get_seq_length(0) == 10 + + cache.reset() + assert cache.get_seq_length(0) == 0 + + +# ============================================================================= +# SECTION 5: MASK CORRECTNESS (SEMANTIC TESTS) +# ============================================================================= + + +class TestMaskCorrectness: + """Test that mask parameters produce semantically correct masks. + + These tests verify the END RESULT: masks created with our parameters + allow the correct attention patterns. 
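+
+    The masks are built through HuggingFace's sdpa_mask helper, roughly:
+
+        mask = sdpa_mask(batch_size=1, cache_position=..., kv_length=...,
+                         kv_offset=..., mask_function=causal_mask_function)
+
+    and the tests then check which (query, key) positions the returned
+    boolean mask allows.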
+ """ + + def test_full_attention_decode_can_attend_to_all(self): + """During decode, query can attend to all cached positions.""" + from transformers.masking_utils import causal_mask_function, sdpa_mask + + cache = _AttentionCache(window=None) + + # Prefill + decode + for _ in range(10): + cache.update(torch.randn(1, 1, 1, 16), torch.randn(1, 1, 1, 16)) + + # Mask for decode step + cache_position = torch.tensor([10]) # Position of new token + kv_length = cache.cumulative_length + 1 + kv_offset = 0 + + mask = sdpa_mask( + batch_size=1, + cache_position=cache_position, + kv_length=kv_length, + kv_offset=kv_offset, + mask_function=causal_mask_function, + ) + + if mask is not None: + # Query at position 10 should attend to positions 0-10 + query_mask = mask[0, 0, 0, :] + for kv_idx in range(kv_length): + assert query_mask[kv_idx].item() == True, f"Should attend to position {kv_idx}" + + @pytest.mark.parametrize("window_size", [4, 8, 16]) + def test_sliding_window_decode_respects_window(self, window_size): + """During decode, query only attends within sliding window.""" + from transformers.masking_utils import sdpa_mask, sliding_window_causal_mask_function + + cache = _AttentionCache(window=window_size) + + # Fill way past window + total_tokens = window_size * 2 + for _ in range(total_tokens): + cache.update(torch.randn(1, 1, 1, 16), torch.randn(1, 1, 1, 16)) + + # Mask for decode step + cache_position = torch.tensor([total_tokens]) + cumulative = cache.cumulative_length + kv_offset = max(cumulative - window_size + 1, 0) + kv_length = window_size - 1 + 1 # cached + query + + mask = sdpa_mask( + batch_size=1, + cache_position=cache_position, + kv_length=kv_length, + kv_offset=kv_offset, + mask_function=sliding_window_causal_mask_function(window_size), + ) + + if mask is not None: + query_mask = mask[0, 0, 0, :] + query_pos = cache_position[0].item() + + for kv_idx in range(kv_length): + abs_pos = kv_offset + kv_idx + in_window = abs_pos > query_pos - window_size + causal = abs_pos <= query_pos + expected = in_window and causal + + assert ( + query_mask[kv_idx].item() == expected + ), f"Position {abs_pos}: expected {expected}, got {query_mask[kv_idx].item()}" + + def test_prefill_has_causal_pattern(self): + """During prefill, mask has proper causal (lower triangular) pattern.""" + from transformers.masking_utils import causal_mask_function, sdpa_mask + + cache = _AttentionCache(window=None) + cache.update(torch.randn(1, 1, 5, 16), torch.randn(1, 1, 5, 16)) + + cache_position = torch.arange(5) + kv_length = cache.cumulative_length + kv_offset = 0 + + mask = sdpa_mask( + batch_size=1, + cache_position=cache_position, + kv_length=kv_length, + kv_offset=kv_offset, + mask_function=causal_mask_function, + allow_is_causal_skip=False, # Force mask creation + ) + + if mask is not None: + # Check causal pattern + for q_idx in range(5): + for kv_idx in range(5): + expected = kv_idx <= q_idx + actual = mask[0, 0, q_idx, kv_idx].item() + assert actual == expected, f"q={q_idx}, kv={kv_idx}: expected {expected}" diff --git a/fast_llm_external_models/tests/test_apriel2/test_causal_conv1d.py b/fast_llm_external_models/tests/test_apriel2/test_causal_conv1d.py index ec6abc1d2..0567cd76e 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_causal_conv1d.py +++ b/fast_llm_external_models/tests/test_apriel2/test_causal_conv1d.py @@ -24,7 +24,6 @@ from fast_llm_external_models.apriel2.modeling_apriel2 import CausalConv1d, _causal_conv1d_fn - # 
============================================================================= # Fixtures # ============================================================================= @@ -63,6 +62,7 @@ def kernel_size(): def to_device(conv: CausalConv1d, device: str) -> CausalConv1d: """Create a copy of conv on the specified device.""" import copy + return copy.deepcopy(conv).to(device) @@ -71,7 +71,9 @@ def prefill(conv: CausalConv1d, x: torch.Tensor, state: torch.Tensor = None) -> return conv(x, conv_state=state, return_final_state=True) -def decode_sequence(conv: CausalConv1d, tokens: torch.Tensor, state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: +def decode_sequence( + conv: CausalConv1d, tokens: torch.Tensor, state: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: """Decode multiple tokens one-by-one, return (stacked_outputs, final_state). Args: @@ -223,7 +225,7 @@ def test_chunked_prefill_cpu(self, conv, dim, total_len, chunk_size): outputs = [] state = None for start in range(0, total_len, chunk_size): - chunk = x[:, :, start:start + chunk_size] + chunk = x[:, :, start : start + chunk_size] out, state = prefill(conv, chunk, state) outputs.append(out) @@ -248,7 +250,7 @@ def test_chunked_prefill_cuda(self, conv, dim, total_len, chunk_size): outputs = [] state = None for start in range(0, total_len, chunk_size): - chunk = x[:, :, start:start + chunk_size].cuda() + chunk = x[:, :, start : start + chunk_size].cuda() out, state = prefill(conv_cuda, chunk, state) outputs.append(out) @@ -329,7 +331,7 @@ def test_all_cpu_paths_match(self, conv, dim): outputs = [] state = None for start in range(0, total_len, chunk_size): - chunk = x[:, :, start:start + chunk_size] + chunk = x[:, :, start : start + chunk_size] out, state = prefill(conv, chunk, state) outputs.append(out) path1 = torch.cat(outputs, dim=-1) @@ -374,7 +376,7 @@ def test_all_paths_match_cross_device(self, conv, dim): # CPU chunked outputs, state = [], None for start in range(0, total_len, chunk_size): - out, state = prefill(conv, x[:, :, start:start + chunk_size], state) + out, state = prefill(conv, x[:, :, start : start + chunk_size], state) outputs.append(out) results["cpu_chunked"] = torch.cat(outputs, dim=-1) @@ -393,7 +395,7 @@ def test_all_paths_match_cross_device(self, conv, dim): # CUDA chunked outputs, state = [], None for start in range(0, total_len, chunk_size): - out, state = prefill(conv_cuda, x[:, :, start:start + chunk_size].cuda(), state) + out, state = prefill(conv_cuda, x[:, :, start : start + chunk_size].cuda(), state) outputs.append(out.cpu()) results["cuda_chunked"] = torch.cat(outputs, dim=-1) @@ -431,8 +433,7 @@ def test_all_paths_match_cross_device(self, conv, dim): for name, result in results.items(): tol = tolerances[name] torch.testing.assert_close( - result, reference, atol=tol, rtol=tol, - msg=f"Path '{name}' diverged from reference" + result, reference, atol=tol, rtol=tol, msg=f"Path '{name}' diverged from reference" ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") @@ -468,8 +469,8 @@ def test_long_decode_no_drift(self, conv, dim): # Check no systematic drift (errors shouldn't consistently increase) decode_errors = errors[prefill_len:] - first_half = decode_errors[:len(decode_errors)//2].mean() - second_half = decode_errors[len(decode_errors)//2:].mean() + first_half = decode_errors[: len(decode_errors) // 2].mean() + second_half = decode_errors[len(decode_errors) // 2 :].mean() assert second_half < first_half * 2, "Errors growing over decode steps (drift detected)" diff 
--git a/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py b/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py index 0bd6ac88d..3413b9d25 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py +++ b/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py @@ -20,7 +20,7 @@ import yaml from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config -from fast_llm_external_models.apriel2.conversion.config import apply_surgery, compose_configs +from fast_llm_external_models.apriel2.conversion.config import compose_configs class TestComposeConfigsLaws: @@ -75,14 +75,10 @@ def source_config(self): }, } - def test_identity_empty_surgery(self, source_config): - """Law 1: compose_configs(config, {}) == config""" - result = compose_configs(source_config, {}) - assert result == source_config - - def test_identity_none_surgery(self, source_config): - """Law 1: compose_configs(config, None) == config""" - result = compose_configs(source_config, None) + @pytest.mark.parametrize("empty_surgery", [{}, None]) + def test_identity(self, source_config, empty_surgery): + """Law 1: compose_configs(config, empty) == config for empty in [{}, None]""" + result = compose_configs(source_config, empty_surgery) assert result == source_config def test_override_explicit_values(self, source_config): @@ -114,7 +110,7 @@ def test_same_type_inheritance(self, source_config): assert mixer["head_size"] == 32 # Inherited assert mixer["rope_theta"] == 10000.0 # Inherited assert mixer["window_size"] == 512 # Added - assert "init" not in mixer # Stripped by apply_surgery + # init is preserved for plan_surgery to see (stripped only at final output) def test_cross_type_attention_to_gdn(self, source_config): """Law 5: attentionโ†’gdn derives GDN dims from attention geometry.""" @@ -239,8 +235,14 @@ def test_null_deletion(self, source_config): assert "vision_encoder" not in result - def test_init_stripped_from_result(self, source_config): - """Verify `init` keys are stripped from final result.""" + def test_init_preserved_for_plan_surgery(self, source_config): + """Verify `init` keys are preserved so plan_surgery can see them. + + The `init` field controls weight initialization (transfer vs random). + It's preserved through composition and only stripped at final output. 
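+
+        Schematically: composing with a surgery that contains
+        {"attention": {"init": "transfer"}} keeps that key in the composed
+        config, and strip_init_fields() removes it afterwards, as asserted
+        below.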
+ """ + from fast_llm_external_models.apriel2.conversion.config import strip_init_fields + surgery = { "decoder": { "block": { @@ -252,20 +254,20 @@ def test_init_stripped_from_result(self, source_config): "gdn": {"type": "gdn", "init": "random", "convolution_layer": {"kernel_size": 4}}, }, }, - "mlp": {"init": "transfer"}, - "normalization": {"init": "transfer"}, }, }, } result = compose_configs(source_config, surgery) - def check_no_init(d, path=""): - assert "init" not in d, f"Found 'init' key at {path}" - for k, v in d.items(): - if isinstance(v, dict): - check_no_init(v, f"{path}.{k}") + # init is preserved in composed config + mixers = result["decoder"]["block"]["mixer"]["mixers"] + assert mixers["attention"].get("init") == "transfer" + assert mixers["gdn"].get("init") == "random" - check_no_init(result) + # strip_init_fields removes them for final output + stripped = strip_init_fields(result) + assert "init" not in stripped["decoder"]["block"]["mixer"]["mixers"]["attention"] + assert "init" not in stripped["decoder"]["block"]["mixer"]["mixers"]["gdn"] def test_init_random_still_inherits_config(self, source_config): """init: random is for weights only - config params still inherited.""" @@ -287,6 +289,212 @@ def test_init_random_still_inherits_config(self, source_config): assert mixer["head_groups"] == 4 assert mixer["window_size"] == 512 + # ========================================================================= + # Monoid Laws: compose_configs forms a monoid action on configs + # ========================================================================= + + def test_surgery_monoid_associativity(self): + """MONOID: merge(merge(A, B), C) == merge(A, merge(B, C)) for partial configs.""" + surgery_a = {"decoder": {"block": {"mixer": {"type": "stochastic", "main_mixer_name": "attention"}}}} + surgery_b = {"decoder": {"block": {"mixer": {"mixers": {"sliding_window": {"window_size": 512}}}}}} + surgery_c = {"decoder": {"block": {"mixer": {"mixers": {"gdn": {"type": "gdn"}}}}}} + + # Left-associated: (A โˆ˜ B) โˆ˜ C + ab_c = compose_configs(compose_configs(surgery_a, surgery_b), surgery_c) + # Right-associated: A โˆ˜ (B โˆ˜ C) + a_bc = compose_configs(surgery_a, compose_configs(surgery_b, surgery_c)) + + assert ab_c == a_bc, "Surgery monoid should be associative" + + @pytest.mark.parametrize("num_surgeries", [2, 3]) + def test_monoid_action_compatibility(self, source_config, num_surgeries): + """MONOID ACTION: apply(apply(c, A), B) == apply(c, merge(A, B)) + + This is the key law: applying surgeries sequentially equals merging first. + Parameterized to test with 2 and 3 surgeries. + """ + surgeries = [ + { + "decoder": { + "block": { + "mixer": {"type": "stochastic", "main_mixer_name": "attention", "mixers": {"attention": {}}} + } + } + }, + {"decoder": {"block": {"mixer": {"mixers": {"sliding_window": {"window_size": 512}}}}}}, + {"decoder": {"block": {"mixer": {"mixers": {"gdn": {"type": "gdn"}}}}}}, + ][:num_surgeries] + + # Sequential: ((c โŠณ A) โŠณ B) โŠณ ... + result_sequential = source_config + for s in surgeries: + result_sequential = compose_configs(result_sequential, s) + + # Merged: c โŠณ (A โˆ˜ B โˆ˜ ...) + merged = surgeries[0] + for s in surgeries[1:]: + merged = compose_configs(merged, s) + result_merged = compose_configs(source_config, merged) + + assert result_sequential == result_merged, f"Monoid action compatibility failed for {num_surgeries} surgeries" + + +class TestBiasConfigInheritance: + """Test per-layer bias inheritance through surgery composition. 
+ + These tests verify that the per-layer bias configuration (mirroring Fast-LLM's + AffineLinearConfig) is correctly inherited through surgery operations: + - query_layer.bias.enabled, key_layer.bias.enabled, etc. for attention + - layer_1.bias.enabled, layer_2.bias.enabled for MLP + """ + + @pytest.fixture + def source_config_with_bias(self): + """Source config with Qwen-style bias (QKV enabled, O disabled).""" + return { + "model_type": "apriel2", + "architectures": ["Apriel2ForCausalLM"], + "hidden_size": 256, + "vocab_size": 1000, + "decoder": { + "type": "fixed", + "num_blocks": 4, + "block": { + "mixer": { + "type": "attention", + "heads": 8, + "head_groups": 4, + "head_size": 32, + "rotary": {"type": "mistral_1d", "theta": 10000.0}, + # Qwen-style per-layer bias + "query_layer": {"bias": {"enabled": True}}, + "key_layer": {"bias": {"enabled": True}}, + "value_layer": {"bias": {"enabled": True}}, + "dense_layer": {"bias": {"enabled": False}}, + }, + "mlp": { + "type": "mlp", + "intermediate_size": 512, + "gated": False, + # Per-layer MLP bias + "layer_1": {"bias": {"enabled": True}}, + "layer_2": {"bias": {"enabled": False}}, + }, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + } + + def test_same_type_inherits_attention_bias(self, source_config_with_bias): + """Same-type surgery inherits per-layer attention bias settings.""" + surgery = { + "decoder": { + "block": { + "mixer": { + "window_size": 512, # Add sliding window behavior + }, + }, + }, + } + result = compose_configs(source_config_with_bias, surgery) + + mixer = result["decoder"]["block"]["mixer"] + assert mixer["query_layer"]["bias"]["enabled"] is True + assert mixer["key_layer"]["bias"]["enabled"] is True + assert mixer["value_layer"]["bias"]["enabled"] is True + assert mixer["dense_layer"]["bias"]["enabled"] is False + + def test_same_type_inherits_mlp_bias(self, source_config_with_bias): + """Same-type surgery inherits per-layer MLP bias settings.""" + surgery = { + "decoder": { + "block": { + "mlp": { + "intermediate_size": 1024, # Change size + }, + }, + }, + } + result = compose_configs(source_config_with_bias, surgery) + + mlp = result["decoder"]["block"]["mlp"] + assert mlp["layer_1"]["bias"]["enabled"] is True + assert mlp["layer_2"]["bias"]["enabled"] is False + assert mlp["intermediate_size"] == 1024 + + def test_cross_type_attention_to_sliding_window_preserves_bias(self, source_config_with_bias): + """attentionโ†’sliding_window cross-type preserves per-layer bias.""" + surgery = { + "decoder": { + "block": { + "mixer": { + "type": "sliding_window", # Cross-type derivation + "window_size": 512, + }, + }, + }, + } + result = compose_configs(source_config_with_bias, surgery) + + mixer = result["decoder"]["block"]["mixer"] + assert mixer["type"] == "sliding_window" + # Bias settings preserved through cross-type + assert mixer["query_layer"]["bias"]["enabled"] is True + assert mixer["key_layer"]["bias"]["enabled"] is True + assert mixer["value_layer"]["bias"]["enabled"] is True + assert mixer["dense_layer"]["bias"]["enabled"] is False + + def test_stochastic_wrapper_inherits_bias(self, source_config_with_bias): + """Wrapping in stochastic inherits bias settings to all sub-mixers.""" + surgery = { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + "sliding_window": { + "type": "sliding_window", + "window_size": 512, + "init": "transfer", + }, + }, + }, + }, + }, + } + result = 
compose_configs(source_config_with_bias, surgery) + + mixers = result["decoder"]["block"]["mixer"]["mixers"] + + # Attention sub-mixer inherits bias + assert mixers["attention"]["query_layer"]["bias"]["enabled"] is True + assert mixers["attention"]["dense_layer"]["bias"]["enabled"] is False + + # Sliding window sub-mixer also inherits bias + assert mixers["sliding_window"]["query_layer"]["bias"]["enabled"] is True + assert mixers["sliding_window"]["dense_layer"]["bias"]["enabled"] is False + + def test_surgery_can_override_bias(self, source_config_with_bias): + """Surgery can explicitly override inherited bias settings.""" + surgery = { + "decoder": { + "block": { + "mixer": { + "dense_layer": {"bias": {"enabled": True}}, # Override O bias + }, + }, + }, + } + result = compose_configs(source_config_with_bias, surgery) + + mixer = result["decoder"]["block"]["mixer"] + # Q/K/V unchanged + assert mixer["query_layer"]["bias"]["enabled"] is True + # O bias overridden + assert mixer["dense_layer"]["bias"]["enabled"] is True + class TestComposeConfigsRealYAML: """Test compose_configs with real YAML surgery files.""" @@ -398,160 +606,12 @@ def test_build_plan_returns_complete_config(self, llava_pixtral_checkpoint): mixer = config.decoder["block"]["mixer"] assert mixer["type"] == "stochastic" - # Each sub-mixer should have complete config (no init keys) + # Each sub-mixer should have complete config + # (init is preserved for plan_surgery, stripped only at final output) for name, sub_mixer in mixer["mixers"].items(): - assert "init" not in sub_mixer, f"Sub-mixer {name} still has 'init' key" assert "type" in sub_mixer -class TestMonoidLaws: - """Test the algebraic laws of compose_configs. - - Surgery specs form a MONOID under deep-merge: - - Identity: {} - - Operation: deep merge (overlay wins) - - Associativity: merge(merge(A, B), C) == merge(A, merge(B, C)) - - compose_configs is a MONOID ACTION on configs: - - Identity action: apply(config, {}) == config - - Compatibility: apply(apply(c, A), B) == apply(c, merge(A, B)) - """ - - @pytest.fixture - def complete_config(self): - """A complete Apriel2 config.""" - return { - "model_type": "apriel2", - "architectures": ["Apriel2ForConditionalGeneration"], - "hidden_size": 256, - "vocab_size": 1000, - "bos_token_id": 1, - "eos_token_id": 2, - "tie_word_embeddings": False, - "image_token_index": 100, - "decoder": { - "type": "fixed", - "num_blocks": 4, - "block": { - "mixer": { - "type": "attention", - "heads": 8, - "head_groups": 4, - "head_size": 32, - "rope_theta": 10000.0, - }, - "mlp": {"type": "mlp", "intermediate_size": 512}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - } - - @pytest.fixture - def surgery_a(self): - """First surgery: wrap in stochastic with attention.""" - return { - "decoder": { - "block": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": { - "attention": {"init": "transfer"}, - }, - }, - }, - }, - } - - @pytest.fixture - def surgery_b(self): - """Second surgery: add sliding window mixer.""" - return { - "decoder": { - "block": { - "mixer": { - "mixers": { - "sliding_window": {"init": "transfer", "window_size": 512}, - }, - }, - }, - }, - } - - def test_identity_action(self, complete_config): - """apply(config, {}) == config""" - result = compose_configs(complete_config, {}) - assert result == complete_config - - def test_surgery_monoid_associativity(self, surgery_a, surgery_b): - """merge(merge(A, B), C) == merge(A, merge(B, C)) for partial configs.""" - surgery_c = { - 
"decoder": { - "block": { - "mixer": { - "mixers": { - "gdn": {"type": "gdn", "init": "transfer", "convolution_layer": {"kernel_size": 4}}, - }, - }, - }, - }, - } - - # Left-associated: (A โˆ˜ B) โˆ˜ C - ab = compose_configs(surgery_a, surgery_b) - ab_c = compose_configs(ab, surgery_c) - - # Right-associated: A โˆ˜ (B โˆ˜ C) - bc = compose_configs(surgery_b, surgery_c) - a_bc = compose_configs(surgery_a, bc) - - assert ab_c == a_bc, "Surgery monoid should be associative" - - def test_monoid_action_compatibility(self, complete_config, surgery_a, surgery_b): - """apply(apply(c, A), B) == apply(c, merge(A, B)) - - This is the key law: applying surgeries sequentially should equal - merging the surgeries first, then applying once. - """ - # Sequential application: (c โŠณ A) โŠณ B - result_sequential = compose_configs(compose_configs(complete_config, surgery_a), surgery_b) - - # Merged application: c โŠณ (A โˆ˜ B) - merged_surgery = compose_configs(surgery_a, surgery_b) - result_merged = compose_configs(complete_config, merged_surgery) - - # These should be equivalent - assert result_sequential == result_merged, "Monoid action should satisfy compatibility law" - - def test_three_way_compatibility(self, complete_config, surgery_a, surgery_b): - """Test with three surgeries for stronger confidence.""" - surgery_c = { - "decoder": { - "block": { - "mixer": { - "mixers": { - "gdn": {"type": "gdn", "init": "transfer", "convolution_layer": {"kernel_size": 4}}, - }, - }, - }, - }, - } - - # Sequential: ((c โŠณ A) โŠณ B) โŠณ C - seq = compose_configs( - compose_configs(compose_configs(complete_config, surgery_a), surgery_b), - surgery_c - ) - - # Merged: c โŠณ ((A โˆ˜ B) โˆ˜ C) - merged = compose_configs( - complete_config, - compose_configs(compose_configs(surgery_a, surgery_b), surgery_c) - ) - - assert seq == merged, "Three-way monoid action should satisfy compatibility" - - class TestCompositionTortureTest: """Comprehensive stress test for config composition. @@ -650,19 +710,29 @@ def test_final_config_structure(self, complete_config, additive_surgery_chain): assert mixer["mixers"]["sliding_window"]["window_size"] == 512 assert mixer["mixers"]["gdn"]["value_heads"] == 16 - def test_no_init_keys_in_result(self, complete_config, additive_surgery_chain): - """Verify no 'init' keys leak through.""" + def test_init_keys_preserved_for_planning(self, complete_config, additive_surgery_chain): + """Verify 'init' keys are preserved for plan_surgery to see. - def check_no_init(d, path=""): - if isinstance(d, dict): - assert "init" not in d, f"Found 'init' key at {path}" - for k, v in d.items(): - check_no_init(v, f"{path}.{k}") + The `init` field is metadata for weight initialization. It's preserved + through composition and only stripped when saving final output. 
+ """ + from fast_llm_external_models.apriel2.conversion.config import strip_init_fields result = complete_config for i, surgery in enumerate(additive_surgery_chain): result = compose_configs(result, surgery) - check_no_init(result, f"step_{i+1}") + + # init should be in the composed config + mixer = result["decoder"]["block"]["mixer"] + if "mixers" in mixer: + has_init = any("init" in m for m in mixer["mixers"].values()) + assert has_init, "init should be preserved in composed config" + + # strip_init_fields removes them + stripped = strip_init_fields(result) + mixer = stripped["decoder"]["block"]["mixer"] + if "mixers" in mixer: + assert all("init" not in m for m in mixer["mixers"].values()) def test_full_torture_chain(self, complete_config, torture_surgery_chain): """Test the full 10-step torture chain produces valid configs.""" diff --git a/fast_llm_external_models/tests/test_apriel2/test_plan_composition_torture.py b/fast_llm_external_models/tests/test_apriel2/test_conversion_e2e.py similarity index 78% rename from fast_llm_external_models/tests/test_apriel2/test_plan_composition_torture.py rename to fast_llm_external_models/tests/test_apriel2/test_conversion_e2e.py index 3b4adc7f5..b91fb7e51 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_plan_composition_torture.py +++ b/fast_llm_external_models/tests/test_apriel2/test_conversion_e2e.py @@ -1,6 +1,6 @@ -"""End-to-end torture test for plan composition. +"""test_conversion_e2e.py - End-to-end conversion integration tests. -This tests the FULL pipeline at every step of a surgery chain: +Tests the FULL pipeline at every step of a surgery chain: 1. Config composition produces valid configs 2. Plan building works for each surgery 3. Plan execution produces valid weights @@ -16,21 +16,12 @@ import pytest import torch -from fast_llm_external_models.tests.test_apriel2.conftest import requires_cuda - from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config -from fast_llm_external_models.apriel2.conversion import ( - compose, - compose_configs, - execute, - plan_surgery, -) -from fast_llm_external_models.apriel2.conversion.llava import ( - convert_config as convert_llava_config, - plan_llava_to_apriel2, -) +from fast_llm_external_models.apriel2.conversion import compose, compose_configs, execute, plan_surgery +from fast_llm_external_models.apriel2.conversion.llava import convert_config as convert_llava_config +from fast_llm_external_models.apriel2.conversion.llava import plan_llava_to_apriel2 from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForConditionalGeneration - +from fast_llm_external_models.tests.test_apriel2.conftest import requires_cuda # ============================================================================= # Cycling Surgery Generation @@ -87,40 +78,20 @@ def generate_cycling_surgeries(config: dict) -> list[tuple[dict, str]]: if sub_name != main_mixer: # Build surgery path based on block_path if block_path == "block": - surgery = { - "decoder": { - "block": {"mixer": {"main_mixer_name": sub_name}} - } - } + surgery = {"decoder": {"block": {"mixer": {"main_mixer_name": sub_name}}}} else: # block_path is "blocks.block_name" block_name = block_path.split(".")[1] - surgery = { - "decoder": { - "blocks": { - block_name: {"mixer": {"main_mixer_name": sub_name}} - } - } - } + surgery = {"decoder": {"blocks": {block_name: {"mixer": {"main_mixer_name": sub_name}}}}} surgeries.append((surgery, f"cycle {block_path} to {sub_name}")) # Restore original main_mixer_name if 
any(sub_name != main_mixer for sub_name in sub_mixer_names): if block_path == "block": - restore = { - "decoder": { - "block": {"mixer": {"main_mixer_name": main_mixer}} - } - } + restore = {"decoder": {"block": {"mixer": {"main_mixer_name": main_mixer}}}} else: block_name = block_path.split(".")[1] - restore = { - "decoder": { - "blocks": { - block_name: {"mixer": {"main_mixer_name": main_mixer}} - } - } - } + restore = {"decoder": {"blocks": {block_name: {"mixer": {"main_mixer_name": main_mixer}}}}} surgeries.append((restore, f"restore {block_path} to {main_mixer}")) return surgeries @@ -194,9 +165,7 @@ def source_config(self, llava_pixtral_checkpoint): with open(llava_pixtral_checkpoint / "config.json") as f: return json.load(f) - def test_initial_conversion_produces_working_model( - self, source_config, source_weights - ): + def test_initial_conversion_produces_working_model(self, source_config, source_weights): """Test that Llava โ†’ Apriel2 conversion produces a working model.""" # Convert config apriel2_config_dict = convert_llava_config(source_config) @@ -219,9 +188,7 @@ def test_initial_conversion_produces_working_model( assert outputs.logits.shape == (1, 8, config.vocab_size) - def test_each_surgery_step_produces_working_model( - self, source_config, source_weights, additive_surgery_chain - ): + def test_each_surgery_step_produces_working_model(self, source_config, source_weights, additive_surgery_chain): """Test that each surgery step produces a model that can forward pass. Key insight: Surgery plans reference Apriel2 keys, so we must COMPOSE @@ -290,9 +257,7 @@ def test_each_surgery_step_produces_working_model( except Exception as e: pytest.fail(f"Step {i+1}: Forward pass failed - {e}") - def test_all_stochastic_submixers_via_cycling( - self, source_config, source_weights, additive_surgery_chain - ): + def test_all_stochastic_submixers_via_cycling(self, source_config, source_weights, additive_surgery_chain): """Test ALL sub-mixers in stochastic blocks, not just the main mixer. Problem: Forward pass only exercises the main_mixer_name. Other sub-mixers @@ -312,9 +277,7 @@ def test_all_stochastic_submixers_via_cycling( conversion_plan = plan_llava_to_apriel2(source_config) # Expand surgery chain with cycling - expanded_chain = expand_surgery_chain_with_cycling( - additive_surgery_chain, apriel2_config - ) + expanded_chain = expand_surgery_chain_with_cycling(additive_surgery_chain, apriel2_config) # Build cumulative plan: conversion | surgery_1 | cycling_1a | ... | restore_1 | surgery_2 | ... current_plan = conversion_plan @@ -359,9 +322,7 @@ def test_all_stochastic_submixers_via_cycling( except Exception as e: pytest.fail(f"{desc}: Forward pass failed - {e}") - def test_composed_plan_equals_sequential_execution( - self, source_config, source_weights, additive_surgery_chain - ): + def test_composed_plan_equals_sequential_execution(self, source_config, source_weights, additive_surgery_chain): """Test that composing plans gives same result as sequential execution. 
This verifies plan composition associativity: @@ -399,13 +360,9 @@ def test_composed_plan_equals_sequential_execution( # Compare weights for key in seq_weights: if key in composed_weights: - assert torch.allclose( - seq_weights[key], composed_weights[key], atol=1e-5 - ), f"Weight mismatch for {key}" + assert torch.allclose(seq_weights[key], composed_weights[key], atol=1e-5), f"Weight mismatch for {key}" - def test_final_model_structure( - self, source_config, source_weights, additive_surgery_chain - ): + def test_final_model_structure(self, source_config, source_weights, additive_surgery_chain): """Verify the final model has the expected structure.""" # Initial conversion current_config = convert_llava_config(source_config) @@ -504,9 +461,7 @@ def base_setup(self, llava_pixtral_checkpoint): """Set up base config and weights after Llava conversion.""" from safetensors.torch import load_file - from fast_llm_external_models.apriel2.conversion.llava import ( - convert_config as convert_llava_config, - ) + from fast_llm_external_models.apriel2.conversion.llava import convert_config as convert_llava_config # Load source config and weights with open(llava_pixtral_checkpoint / "config.json") as f: @@ -534,9 +489,7 @@ def _merge_surgeries(self, surgeries: list[dict]) -> dict: result = _deep_merge(result, s) return result - def _build_incremental_plans( - self, base_config: dict, surgeries: list[dict] - ) -> tuple[list, list[dict]]: + def _build_incremental_plans(self, base_config: dict, surgeries: list[dict]) -> tuple[list, list[dict]]: """Build incremental plans for each surgery step. Returns (plans, configs) where configs[i] is the config after surgery i. @@ -552,9 +505,7 @@ def _build_incremental_plans( config = target_config return plans, configs - def test_incremental_equals_direct_full_chain( - self, base_setup, additive_surgery_chain - ): + def test_incremental_equals_direct_full_chain(self, base_setup, additive_surgery_chain): """Test that composing all incremental plans equals one direct plan. 
compose(P1, P2, ..., Pn) โ‰ก plan_surgery(base, final) @@ -575,9 +526,7 @@ def test_incremental_equals_direct_full_chain( direct_plan = plan_surgery(base_config, final_config) # Verify same target keys - assert set(composed_plan.mappings.keys()) == set( - direct_plan.mappings.keys() - ), "Plan keys should match" + assert set(composed_plan.mappings.keys()) == set(direct_plan.mappings.keys()), "Plan keys should match" # Execute both and compare weights composed_weights = execute(composed_plan, base_weights, seed=0) @@ -611,9 +560,7 @@ def test_every_prefix_consistency(self, base_setup, additive_surgery_chain): direct = plan_surgery(base_config, configs[k]) # Verify keys match - assert set(composed.mappings.keys()) == set( - direct.mappings.keys() - ), f"Prefix {k}: keys don't match" + assert set(composed.mappings.keys()) == set(direct.mappings.keys()), f"Prefix {k}: keys don't match" # Execute and compare composed_weights = execute(composed, base_weights, seed=0) @@ -781,9 +728,7 @@ def torture_setup(self, llava_pixtral_checkpoint): """Set up for comprehensive torture tests.""" from safetensors.torch import load_file - from fast_llm_external_models.apriel2.conversion.llava import ( - convert_config as convert_llava_config, - ) + from fast_llm_external_models.apriel2.conversion.llava import convert_config as convert_llava_config # Load source with open(llava_pixtral_checkpoint / "config.json") as f: @@ -801,9 +746,7 @@ def torture_setup(self, llava_pixtral_checkpoint): return base_config, base_weights - def test_each_step_produces_valid_config( - self, torture_setup, comprehensive_torture_chain - ): + def test_each_step_produces_valid_config(self, torture_setup, comprehensive_torture_chain): """Test that each surgery step produces a valid config.""" base_config, _ = torture_setup @@ -818,9 +761,7 @@ def test_each_step_produces_valid_config( pytest.fail(f"Step {i+1} produced invalid config: {e}") @requires_cuda - def test_each_step_produces_working_model( - self, torture_setup, comprehensive_torture_chain - ): + def test_each_step_produces_working_model(self, torture_setup, comprehensive_torture_chain): """Test that each surgery step produces a model that can forward pass. This is the ultimate integration test - config composition + plan building @@ -875,9 +816,7 @@ def test_each_step_produces_working_model( current_weights = new_weights @requires_cuda - def test_final_supernet_structure( - self, torture_setup, comprehensive_torture_chain - ): + def test_final_supernet_structure(self, torture_setup, comprehensive_torture_chain): """Verify the final architecture has supernet blocks with all 4 mixer types.""" base_config, base_weights = torture_setup @@ -914,9 +853,7 @@ def test_final_supernet_structure( assert outputs.logits.shape == (1, 8, config.vocab_size) @requires_cuda - def test_plan_config_consistency_comprehensive( - self, torture_setup, comprehensive_torture_chain - ): + def test_plan_config_consistency_comprehensive(self, torture_setup, comprehensive_torture_chain): """Test that incremental plan composition works for the comprehensive chain. 
Note: We cannot compare to a "direct plan" because the comprehensive chain @@ -1083,66 +1020,6 @@ def mamba_config(self): }, } - def test_config_composition_identical_regardless_of_init_mode(self, base_config): - """Config composition produces same structure with init: transfer vs init: random.""" - # Surgery with init: transfer - surgery_transfer = { - "decoder": { - "block": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": { - "attention": {"type": "attention", "init": "transfer"}, - "swa": { - "type": "attention", - "init": "transfer", - "sliding_window": 512, - }, - }, - }, - }, - }, - } - - # Surgery with init: random - surgery_random = { - "decoder": { - "block": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": { - "attention": {"type": "attention", "init": "random"}, - "swa": { - "type": "attention", - "init": "random", - "sliding_window": 512, - }, - }, - }, - }, - }, - } - - # Compose configs - result_transfer = compose_configs(base_config, surgery_transfer) - result_random = compose_configs(base_config, surgery_random) - - # Both should produce identical structure (init is stripped) - assert result_transfer == result_random, ( - "Config composition should produce identical structure regardless of init mode" - ) - - # Verify the structure is correct - mixer = result_transfer["decoder"]["block"]["mixer"] - assert mixer["type"] == "stochastic" - assert "attention" in mixer["mixers"] - assert "swa" in mixer["mixers"] - # init should be stripped - assert "init" not in mixer["mixers"]["attention"] - assert "init" not in mixer["mixers"]["swa"] - def test_plan_surgery_random_succeeds_for_any_type_pair(self, mamba_config): """plan_surgery with init: random should succeed even for mamba -> attention.""" # This surgery changes mamba to attention with random init @@ -1166,7 +1043,7 @@ def test_plan_surgery_random_succeeds_for_any_type_pair(self, mamba_config): plan = plan_surgery(mamba_config, surgery) # Verify the plan has the expected target keys - target_keys = set(str(k) for k in plan.mappings.keys()) + target_keys = {str(k) for k in plan.mappings.keys()} assert any("mixer.q_proj" in k for k in target_keys) def test_plan_surgery_transfer_fails_for_unsupported_type_pair(self, mamba_config): @@ -1219,7 +1096,7 @@ def test_plan_surgery_transfer_succeeds_for_supported_type_pair(self, base_confi plan = plan_surgery(base_config, surgery) # Verify the plan has mamba target keys - target_keys = set(str(k) for k in plan.mappings.keys()) + target_keys = {str(k) for k in plan.mappings.keys()} assert any("mixer.in_proj" in k for k in target_keys) def test_stochastic_init_random_succeeds_for_any_submixer_type(self, mamba_config): @@ -1259,7 +1136,7 @@ def test_stochastic_init_random_succeeds_for_any_submixer_type(self, mamba_confi plan = plan_surgery(mamba_config, surgery) # Verify both sub-mixers have target keys - target_keys = set(str(k) for k in plan.mappings.keys()) + target_keys = {str(k) for k in plan.mappings.keys()} assert any("mixers.attention.q_proj" in k for k in target_keys) assert any("mixers.swa.q_proj" in k for k in target_keys) @@ -1294,7 +1171,7 @@ def test_mixed_init_modes_in_stochastic(self, base_config): plan = plan_surgery(base_config, surgery) # Verify both sub-mixers have target keys - target_keys = set(str(k) for k in plan.mappings.keys()) + target_keys = {str(k) for k in plan.mappings.keys()} assert any("mixers.attention.q_proj" in k for k in target_keys) assert any("mixers.gdn.in_proj_qkvz" in k for k 
in target_keys) @@ -1313,8 +1190,8 @@ class TestMarkovianProperty: """ @pytest.fixture - def attention_config(self): - """Base config with attention.""" + def attention_config_dict(self): + """Base config dict with attention mixer for compose_configs tests.""" return { "model_type": "apriel2", "hidden_size": 256, @@ -1335,43 +1212,7 @@ def attention_config(self): }, } - @pytest.fixture - def stochastic_config(self): - """Config with stochastic mixer.""" - return { - "model_type": "apriel2", - "hidden_size": 256, - "vocab_size": 1000, - "decoder": { - "type": "fixed", - "num_blocks": 2, - "block": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": { - "attention": { - "type": "attention", - "heads": 8, - "head_groups": 4, - "head_size": 32, - }, - "swa": { - "type": "sliding_window", - "heads": 8, - "head_groups": 4, - "head_size": 32, - "window_size": 512, - }, - }, - }, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - } - - def test_different_paths_same_config_same_plan(self, attention_config): + def test_different_paths_same_config_same_plan(self, attention_config_dict): """Two different paths to the same config produce identical plans. Path A: attention -> stochastic{att, swa} @@ -1398,7 +1239,7 @@ def test_different_paths_same_config_same_plan(self, attention_config): }, }, } - config_a = compose_configs(attention_config, surgery_a) + config_a = compose_configs(attention_config_dict, surgery_a) # Path B: First add attention only, then add swa surgery_b1 = { @@ -1414,7 +1255,7 @@ def test_different_paths_same_config_same_plan(self, attention_config): }, }, } - intermediate_config = compose_configs(attention_config, surgery_b1) + intermediate_config = compose_configs(attention_config_dict, surgery_b1) surgery_b2 = { "decoder": { @@ -1465,11 +1306,11 @@ def test_different_paths_same_config_same_plan(self, attention_config): plan_from_b = plan_surgery(config_b, final_surgery) # Compare plan mappings - keys_a = set(str(k) for k in plan_from_a.mappings.keys()) - keys_b = set(str(k) for k in plan_from_b.mappings.keys()) + keys_a = {str(k) for k in plan_from_a.mappings.keys()} + keys_b = {str(k) for k in plan_from_b.mappings.keys()} assert keys_a == keys_b, "Plans from same config via different paths should be identical" - def test_init_in_source_config_does_not_affect_plan(self, attention_config): + def test_init_in_source_config_does_not_affect_plan(self, attention_config_dict): """Manually injecting init into source config doesn't change the plan. This tests that plan_surgery reads init from surgery, not source. 
@@ -1479,8 +1320,8 @@ def test_init_in_source_config_does_not_affect_plan(self, attention_config): import copy # Create two copies of the config - config_with_init = copy.deepcopy(attention_config) - config_without_init = copy.deepcopy(attention_config) + config_with_init = copy.deepcopy(attention_config_dict) + config_without_init = copy.deepcopy(attention_config_dict) # Manually inject init into one (bypassing compose_configs) config_with_init["decoder"]["block"]["mixer"]["init"] = "random" @@ -1504,238 +1345,12 @@ def test_init_in_source_config_does_not_affect_plan(self, attention_config): plan_with = plan_surgery(config_with_init, surgery) plan_without = plan_surgery(config_without_init, surgery) - keys_with = set(str(k) for k in plan_with.mappings.keys()) - keys_without = set(str(k) for k in plan_without.mappings.keys()) + keys_with = {str(k) for k in plan_with.mappings.keys()} + keys_without = {str(k) for k in plan_without.mappings.keys()} # Plans should be identical - source's init field is ignored assert keys_with == keys_without, "Plan should not depend on init in source config" - def test_associativity_of_surgery_composition(self, attention_config): - """Verify associativity: (A โˆ˜ B) โˆ˜ C == A โˆ˜ (B โˆ˜ C) for surgery specs. - - This tests that composing surgeries is associative, which is - equivalent to Markovianity for plan creation. - """ - surgery_a = { - "decoder": { - "block": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": { - "attention": {"type": "attention", "init": "transfer"}, - }, - }, - }, - }, - } - - surgery_b = { - "decoder": { - "block": { - "mixer": { - "mixers": { - "swa": { - "type": "sliding_window", - "init": "transfer", - "window_size": 512, - }, - }, - }, - }, - }, - } - - surgery_c = { - "decoder": { - "block": { - "mixer": { - "mixers": { - "gdn": { - "type": "gdn", - "init": "random", - "value_heads": 8, - "key_heads": 4, - "key_head_dim": 32, - "value_head_dim": 32, - "convolution_layer": {"kernel_size": 4}, - }, - }, - }, - }, - }, - } - - # Left association: ((attention_config โˆ˜ A) โˆ˜ B) โˆ˜ C - left_1 = compose_configs(attention_config, surgery_a) - left_2 = compose_configs(left_1, surgery_b) - left_result = compose_configs(left_2, surgery_c) - - # Right association: (attention_config โˆ˜ A) โˆ˜ (B โˆ˜ C) - # Note: B โˆ˜ C is partial โˆ˜ partial = deep merge of surgery specs - bc_merged = compose_configs(surgery_b, surgery_c) - right_1 = compose_configs(attention_config, surgery_a) - right_result = compose_configs(right_1, bc_merged) - - assert left_result == right_result, "Surgery composition should be associative" - - def test_complete_configs_have_no_init_fields(self, attention_config): - """Verify that compose_configs strips init from complete configs. 
- - This is the key invariant that enables Markovianity: - - Complete configs (states) have no init fields - - Surgery specs (transitions) have init fields - - Plans read init from surgery, not state - """ - surgery_with_init = { - "decoder": { - "block": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": { - "attention": {"type": "attention", "init": "transfer"}, - "swa": {"type": "sliding_window", "init": "random", "window_size": 512}, - }, - }, - }, - }, - } - - result = compose_configs(attention_config, surgery_with_init) - - # Recursively check for init fields - def has_init(obj): - if isinstance(obj, dict): - if "init" in obj: - return True - return any(has_init(v) for v in obj.values()) - if isinstance(obj, list): - return any(has_init(v) for v in obj) - return False - - assert not has_init(result), "Complete configs should have no init fields" - - def test_monoid_action_law_additive_surgeries(self): - """Monoid action law HOLDS for additive surgeries. - - Additive surgeries (no type: declaration) support: - apply(apply(s, t1), t2) == apply(s, t1 โˆ˜ t2) - - This is because additive operations commute nicely: - "add {a}" then "add {b}" == "add {a, b}" - """ - # Start with stochastic (additive surgery target) - s = { - "model_type": "apriel2", - "hidden_size": 256, - "vocab_size": 1000, - "decoder": { - "type": "fixed", - "num_blocks": 2, - "block": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": { - "attention": {"type": "attention", "heads": 8, "head_groups": 4, "head_size": 32}, - }, - }, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - } - - # Additive surgeries (no type: declaration) - t1 = {"decoder": {"block": {"mixer": {"mixers": {"swa": {"type": "sliding_window", "window_size": 512}}}}}} - t2 = {"decoder": {"block": {"mixer": {"mixers": {"mamba": {"type": "mamba", "d_inner": 512}}}}}} - - # Path A: Sequential - s_prime = compose_configs(s, t1) - s_double_prime_A = compose_configs(s_prime, t2) - - # Path B: Composed - t1_t2 = compose_configs(t1, t2) - s_double_prime_B = compose_configs(s, t1_t2) - - assert s_double_prime_A == s_double_prime_B, "Monoid action law should hold for additive surgeries" - - def test_monoid_action_law_replacement_surgeries_fails(self): - """Monoid action law FAILS for replacement surgeries (by design). - - Replacement surgeries (type: stochastic declared) have: - apply(apply(s, t1), t2) != apply(s, t1 โˆ˜ t2) - - This is FUNDAMENTAL, not a bug: - - Sequential: "set to {a}" then "set to {b}" โ†’ {b} (second wins) - - Composed: merge({a}, {b}) = {a,b}, then apply โ†’ {a,b} - - These are genuinely different semantics. The failure documents - the distinction between declarative composition (merge) and - operational composition (function application). 
- """ - s = { - "model_type": "apriel2", - "hidden_size": 256, - "vocab_size": 1000, - "decoder": { - "type": "fixed", - "num_blocks": 2, - "block": { - "mixer": {"type": "attention", "heads": 8, "head_groups": 4, "head_size": 32}, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - } - - # Replacement surgeries (both declare type: stochastic) - t1 = { - "decoder": { - "block": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": {"attention": {"type": "attention"}}, - } - } - } - } - t2 = { - "decoder": { - "block": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "swa", - "mixers": {"swa": {"type": "sliding_window", "window_size": 512}}, - } - } - } - } - - # Path A: Sequential (second replacement wins) - s_prime = compose_configs(s, t1) - s_double_prime_A = compose_configs(s_prime, t2) - - # Path B: Composed (declarations merged) - t1_t2 = compose_configs(t1, t2) - s_double_prime_B = compose_configs(s, t1_t2) - - # They should be DIFFERENT (law fails) - assert s_double_prime_A != s_double_prime_B, ( - "Monoid action law should FAIL for replacement surgeries" - ) - - # Verify the specific difference: - # Sequential: only swa (second replacement wins) - # Composed: both attention and swa (merged declarations) - mixers_A = set(s_double_prime_A["decoder"]["block"]["mixer"]["mixers"].keys()) - mixers_B = set(s_double_prime_B["decoder"]["block"]["mixer"]["mixers"].keys()) - - assert mixers_A == {"swa"}, "Sequential: second replacement wins" - assert mixers_B == {"attention", "swa"}, "Composed: declarations merged" - class TestCyclingSurgeryGeneration: """Tests for the cycling surgery generation functions. @@ -1936,7 +1551,7 @@ def test_expand_surgery_chain_adds_cycling(self): # Verify restore flag assert expanded[0][2] is False # surgery - not restore assert expanded[1][2] is False # cycle - not restore - assert expanded[2][2] is True # restore + assert expanded[2][2] is True # restore def test_expand_surgery_chain_preserves_invariant(self): """Test that cycling leaves the chain state invariant.""" @@ -1980,3 +1595,151 @@ def test_expand_surgery_chain_preserves_invariant(self): # After cycling and restore, we should be back to the same state assert current_config == config_after_original + + +class TestBiasSurgeryChain: + """Torture tests for per-layer bias inheritance through surgery operations. 
+ + Uses apriel2_config_with_bias + bias_surgery_chain to test that: + - Qwen-style per-layer attention bias (QKV enabled, O disabled) survives surgery + - Non-gated MLP per-layer bias (layer_1 enabled, layer_2 disabled) survives surgery + - Bias settings are correctly inherited by new sub-mixers + - Bias is correctly tracked in surgery plans + """ + + @pytest.fixture + def bias_source_config(self, apriel2_config_with_bias): + """Convert Apriel2Config to dict for surgery operations.""" + return apriel2_config_with_bias.to_dict() + + def test_bias_survives_stochastic_wrapper(self, bias_source_config, bias_surgery_chain): + """Test that bias settings survive wrapping in stochastic mixer.""" + # Apply first surgery (wrap in stochastic) + result = compose_configs(bias_source_config, bias_surgery_chain[0]) + + # Check attention sub-mixer inherited bias settings + mixer = result["decoder"]["block"]["mixer"] + assert mixer["type"] == "stochastic" + + attn_mixer = mixer["mixers"]["attention"] + assert attn_mixer["query_layer"]["bias"]["enabled"] is True + assert attn_mixer["key_layer"]["bias"]["enabled"] is True + assert attn_mixer["value_layer"]["bias"]["enabled"] is True + assert attn_mixer["dense_layer"]["bias"]["enabled"] is False + + # Check MLP bias survived + mlp = result["decoder"]["block"]["mlp"] + assert mlp["layer_1"]["bias"]["enabled"] is True + assert mlp["layer_2"]["bias"]["enabled"] is False + + def test_new_submixer_inherits_bias(self, bias_source_config, bias_surgery_chain): + """Test that new sub-mixers inherit bias from source attention.""" + # Apply S1 + S2 (wrap in stochastic, add sliding_window) + config = bias_source_config + for surgery in bias_surgery_chain[:2]: + config = compose_configs(config, surgery) + + # sliding_window should inherit bias from source attention + mixer = config["decoder"]["block"]["mixer"] + sw_mixer = mixer["mixers"]["sliding_window"] + + assert sw_mixer["query_layer"]["bias"]["enabled"] is True + assert sw_mixer["key_layer"]["bias"]["enabled"] is True + assert sw_mixer["value_layer"]["bias"]["enabled"] is True + assert sw_mixer["dense_layer"]["bias"]["enabled"] is False + + def test_full_bias_chain_produces_valid_config(self, bias_source_config, bias_surgery_chain): + """Test that full bias surgery chain produces valid config.""" + config = bias_source_config + for surgery in bias_surgery_chain: + config = compose_configs(config, surgery) + + # Verify final config structure + mixer = config["decoder"]["block"]["mixer"] + assert mixer["type"] == "stochastic" + assert "attention" in mixer["mixers"] + assert "sliding_window" in mixer["mixers"] + assert "full_bias_attn" in mixer["mixers"] + + # attention and sliding_window inherit Qwen-style bias + for name in ["attention", "sliding_window"]: + sub = mixer["mixers"][name] + assert sub["query_layer"]["bias"]["enabled"] is True + assert sub["dense_layer"]["bias"]["enabled"] is False + + # full_bias_attn has add_linear_biases=True but per-layer settings inherited from + # source take precedence, so O bias is still disabled + full_bias = mixer["mixers"]["full_bias_attn"] + assert full_bias.get("add_linear_biases") is True + # Per-layer dense_layer.bias.enabled=False inherited from source takes precedence + assert full_bias["dense_layer"]["bias"]["enabled"] is False + + def test_bias_plan_has_correct_mappings(self, bias_source_config, bias_surgery_chain): + """Test that surgery plan correctly includes/excludes bias weight mappings.""" + # Compose config first to get full target config with inherited bias 
settings + target_config = compose_configs(bias_source_config, bias_surgery_chain[0]) + plan = plan_surgery(bias_source_config, target_config) + mapping_strs = [str(k) for k in plan.mappings.keys()] + + # Should have q_proj.bias (enabled) + q_bias = [m for m in mapping_strs if "q_proj.bias" in m] + assert len(q_bias) > 0, "Should have q_proj.bias mappings" + + # Should NOT have o_proj.bias (disabled) + o_bias = [m for m in mapping_strs if "o_proj.bias" in m] + assert len(o_bias) == 0, "Should not have o_proj.bias mappings" + + # Should have up_proj.bias (layer_1 enabled) + up_bias = [m for m in mapping_strs if "up_proj.bias" in m] + assert len(up_bias) > 0, "Should have up_proj.bias mappings" + + # Should NOT have down_proj.bias (layer_2 disabled) + down_bias = [m for m in mapping_strs if "down_proj.bias" in m] + assert len(down_bias) == 0, "Should not have down_proj.bias mappings" + + def test_bias_chain_produces_working_model(self, bias_source_config, bias_surgery_chain): + """Test that bias surgery chain produces a working model.""" + from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM + + # Apply full chain + config = bias_source_config + for surgery in bias_surgery_chain: + config = compose_configs(config, surgery) + + # Create model + apriel_config = Apriel2Config(**config) + model = Apriel2ForCausalLM(apriel_config) + model.eval() + + # Verify model structure has correct biases + block = model.model.decoder.blocks[0] + + # attention sub-mixer should have QKV bias, no O bias + attn = block.mixer.mixers["attention"] + assert attn.q_proj.bias is not None + assert attn.k_proj.bias is not None + assert attn.v_proj.bias is not None + assert attn.o_proj.bias is None + + # sliding_window should also inherit bias settings + sw = block.mixer.mixers["sliding_window"] + assert sw.q_proj.bias is not None + assert sw.o_proj.bias is None + + # full_bias_attn inherits per-layer bias from source (even with add_linear_biases=True, + # per-layer settings take precedence in same-type inheritance) + full_bias = block.mixer.mixers["full_bias_attn"] + assert full_bias.q_proj.bias is not None + # O bias is disabled because inherited per-layer dense_layer.bias.enabled=False + # takes precedence over add_linear_biases=True + assert full_bias.o_proj.bias is None + + # MLP should have layer_1 bias, no layer_2 bias + assert block.mlp.up_proj.bias is not None + assert block.mlp.down_proj.bias is None + + # Forward pass should work + input_ids = torch.randint(0, config["vocab_size"], (1, 10)) + with torch.no_grad(): + outputs = model(input_ids, use_cache=False) + assert outputs.logits.shape == (1, 10, config["vocab_size"]) diff --git a/fast_llm_external_models/tests/test_apriel2/test_convert_from_llava.py b/fast_llm_external_models/tests/test_apriel2/test_convert_from_llava.py index a437f920d..f96f5ac40 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_convert_from_llava.py +++ b/fast_llm_external_models/tests/test_apriel2/test_convert_from_llava.py @@ -14,23 +14,15 @@ """ import json -from pathlib import Path -import pytest import torch from safetensors import safe_open -from safetensors.torch import save_file from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config -from fast_llm_external_models.apriel2.conversion import ( - convert_llava_config as convert_config, - execute, - plan_llava_to_apriel2, - plan_surgery, -) +from fast_llm_external_models.apriel2.conversion import convert_llava_config as convert_config +from 
fast_llm_external_models.apriel2.conversion import execute, plan_llava_to_apriel2, plan_surgery from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForConditionalGeneration - # ============================================================================= # Config Conversion Tests # ============================================================================= @@ -330,9 +322,9 @@ def test_plan_keys_match_model_state_dict(self, llava_pixtral_checkpoint): extra_in_plan = plan_keys - model_keys # Filter out expected missing keys (caches, positions, etc.) - missing_in_plan = {k for k in missing_in_plan if not any( - skip in k.lower() for skip in ["cache", "position", "mask"] - )} + missing_in_plan = { + k for k in missing_in_plan if not any(skip in k.lower() for skip in ["cache", "position", "mask"]) + } assert not missing_in_plan, f"Model keys not in plan: {sorted(missing_in_plan)[:10]}" assert not extra_in_plan, f"Plan keys not in model: {sorted(extra_in_plan)[:10]}" diff --git a/fast_llm_external_models/tests/test_apriel2/test_equivalence.py b/fast_llm_external_models/tests/test_apriel2/test_equivalence.py index c59ed2000..9b3eb4efe 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_equivalence.py +++ b/fast_llm_external_models/tests/test_apriel2/test_equivalence.py @@ -23,9 +23,6 @@ import torch from transformers import LlavaForConditionalGeneration -from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForConditionalGeneration - - # ============================================================================= # Input Configuration # ============================================================================= @@ -487,8 +484,10 @@ def test_batch_processing_behavior(self, model_pair): batch_tgt = target.get_image_features(pixel_values).view(-1, batch_src.shape[-1]) # Sequential processing - singles_src = [get_pixtral_vision_features(source, pixel_values[i:i+1]) for i in range(3)] - singles_tgt = [target.get_image_features(pixel_values[i:i+1]).view(-1, batch_src.shape[-1]) for i in range(3)] + singles_src = [get_pixtral_vision_features(source, pixel_values[i : i + 1]) for i in range(3)] + singles_tgt = [ + target.get_image_features(pixel_values[i : i + 1]).view(-1, batch_src.shape[-1]) for i in range(3) + ] single_concat_src = torch.cat(singles_src, dim=0) single_concat_tgt = torch.cat(singles_tgt, dim=0) @@ -500,9 +499,9 @@ def test_batch_processing_behavior(self, model_pair): print(f"Apriel2 batch vs sequential: {tgt_diff:.6f}") # Both should have the same behavior (within FP tolerance) - assert abs(src_diff - tgt_diff) < 1e-6, ( - f"Batch processing behavior differs: src={src_diff:.6f}, tgt={tgt_diff:.6f}" - ) + assert ( + abs(src_diff - tgt_diff) < 1e-6 + ), f"Batch processing behavior differs: src={src_diff:.6f}, tgt={tgt_diff:.6f}" if __name__ == "__main__": diff --git a/fast_llm_external_models/tests/test_apriel2/test_expr_plan.py b/fast_llm_external_models/tests/test_apriel2/test_expr_plan.py index c487ab3a3..2dccac5ad 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_expr_plan.py +++ b/fast_llm_external_models/tests/test_apriel2/test_expr_plan.py @@ -1,15 +1,13 @@ """Tests for the expression-based plan system.""" import json + import pytest import torch -from fast_llm_external_models.tests.test_apriel2.conftest import requires_cuda - from fast_llm_external_models.apriel2.conversion import ( Concat, EvalKwargs, - Expr, ExprAdapter, ExprPlan, Init, @@ -18,10 +16,9 @@ Slice, StreamingExecutor, W, - compose, execute, - fuse, 
full_slice, + fuse, make_slice, plan_dil_attention_to_gdn, plan_kil_attention_to_kda, @@ -31,6 +28,7 @@ slice_spec, substitute, ) +from fast_llm_external_models.tests.test_apriel2.conftest import requires_cuda def make_eval_kwargs( @@ -219,10 +217,13 @@ def test_substitute_init_unchanged(self): def test_substitute_complex(self): """Substitute handles complex nested expressions.""" # Concat of Slice(Ref) and Init - expr = Concat(exprs=( - Slice(expr=Ref(key=W("a")), slices=((0, 5, None),)), - Init(shape=(5,), init_type="zeros"), - ), dim=0) + expr = Concat( + exprs=( + Slice(expr=Ref(key=W("a")), slices=((0, 5, None),)), + Init(shape=(5,), init_type="zeros"), + ), + dim=0, + ) bindings = {W("a"): Ref(key=W("source"))} result = substitute(expr, bindings) @@ -238,7 +239,13 @@ class TestFuse: def test_fuse_flatten_concat(self): """Fuse flattens nested Concat with same dim.""" inner = Concat(exprs=(Ref(key=W("a")), Ref(key=W("b"))), dim=0) - outer = Concat(exprs=(inner, Ref(key=W("c")),), dim=0) + outer = Concat( + exprs=( + inner, + Ref(key=W("c")), + ), + dim=0, + ) result = fuse(outer) assert isinstance(result, Concat) @@ -250,7 +257,13 @@ def test_fuse_flatten_concat(self): def test_fuse_no_flatten_different_dim(self): """Fuse doesn't flatten Concat with different dim.""" inner = Concat(exprs=(Ref(key=W("a")), Ref(key=W("b"))), dim=1) - outer = Concat(exprs=(inner, Ref(key=W("c")),), dim=0) + outer = Concat( + exprs=( + inner, + Ref(key=W("c")), + ), + dim=0, + ) result = fuse(outer) assert isinstance(result, Concat) @@ -340,28 +353,34 @@ class TestExprPlan: def test_plan_define_and_access(self): """Plan stores and retrieves expressions.""" - plan = ExprPlan(mappings={ - W("target"): Ref(key=W("source")), - }) + plan = ExprPlan( + mappings={ + W("target"): Ref(key=W("source")), + } + ) assert W("target") in plan assert isinstance(plan[W("target")], Ref) def test_plan_source_keys(self): """Plan identifies all source references.""" - plan = ExprPlan(mappings={ - W("a"): Ref(key=W("x")), - W("b"): Concat(exprs=(Ref(key=W("y")), Ref(key=W("z"))), dim=0), - W("c"): Init(shape=(10,), init_type="zeros"), - }) + plan = ExprPlan( + mappings={ + W("a"): Ref(key=W("x")), + W("b"): Concat(exprs=(Ref(key=W("y")), Ref(key=W("z"))), dim=0), + W("c"): Init(shape=(10,), init_type="zeros"), + } + ) assert plan.source_keys() == {W("x"), W("y"), W("z")} def test_plan_target_keys(self): """Plan identifies all target keys.""" - plan = ExprPlan(mappings={ - W("a"): Ref(key=W("x")), - W("b"): Ref(key=W("y")), - }) + plan = ExprPlan( + mappings={ + W("a"): Ref(key=W("x")), + W("b"): Ref(key=W("y")), + } + ) assert plan.target_keys() == {W("a"), W("b")} @@ -386,9 +405,17 @@ def test_plan_summary(self): def test_plan_fuse(self): """Plan fuse applies optimizations.""" inner = Concat(exprs=(Ref(key=W("a")), Ref(key=W("b"))), dim=0) - plan = ExprPlan(mappings={ - W("out"): Concat(exprs=(inner, Ref(key=W("c")),), dim=0), - }) + plan = ExprPlan( + mappings={ + W("out"): Concat( + exprs=( + inner, + Ref(key=W("c")), + ), + dim=0, + ), + } + ) fused = plan.fuse() assert isinstance(fused[W("out")], Concat) @@ -532,9 +559,11 @@ class TestStreamingExecution: def test_execute_simple(self): """Execute simple plan.""" - plan = ExprPlan(mappings={ - W("out"): Ref(key=W("in")), - }) + plan = ExprPlan( + mappings={ + W("out"): Ref(key=W("in")), + } + ) sources = {W("in"): torch.tensor([1.0, 2.0, 3.0])} result = execute(plan, sources, seed=42) @@ -544,9 +573,11 @@ def test_execute_simple(self): def test_execute_concat(self): 
"""Execute plan with Concat.""" - plan = ExprPlan(mappings={ - W("combined"): Concat(exprs=(Ref(key=W("a")), Ref(key=W("b"))), dim=0), - }) + plan = ExprPlan( + mappings={ + W("combined"): Concat(exprs=(Ref(key=W("a")), Ref(key=W("b"))), dim=0), + } + ) sources = { W("a"): torch.ones(2, 3), @@ -559,14 +590,19 @@ def test_execute_concat(self): def test_execute_mil_like(self): """Execute MIL-like Concat of Slices and Init.""" # Simulated MIL: in_proj = [z, x, B, C] - plan = ExprPlan(mappings={ - W("in_proj"): Concat(exprs=( - Init(shape=(4, 8), init_type="zeros"), # z - Slice(expr=Ref(key=W("v")), slices=((0, 2, None), (None, None, None))), # x - Slice(expr=Ref(key=W("k")), slices=((0, 2, None), (None, None, None))), # B - Slice(expr=Ref(key=W("q")), slices=((0, 4, None), (None, None, None))), # C - ), dim=0), - }) + plan = ExprPlan( + mappings={ + W("in_proj"): Concat( + exprs=( + Init(shape=(4, 8), init_type="zeros"), # z + Slice(expr=Ref(key=W("v")), slices=((0, 2, None), (None, None, None))), # x + Slice(expr=Ref(key=W("k")), slices=((0, 2, None), (None, None, None))), # B + Slice(expr=Ref(key=W("q")), slices=((0, 4, None), (None, None, None))), # C + ), + dim=0, + ), + } + ) sources = { W("q"): torch.ones(4, 8), @@ -583,11 +619,13 @@ def test_execute_mil_like(self): def test_streaming_execution(self): """Streaming executor processes all targets.""" - plan = ExprPlan(mappings={ - W("out1"): Ref(key=W("shared")), - W("out2"): Ref(key=W("shared")), - W("out3"): Ref(key=W("unique")), - }) + plan = ExprPlan( + mappings={ + W("out1"): Ref(key=W("shared")), + W("out2"): Ref(key=W("shared")), + W("out3"): Ref(key=W("unique")), + } + ) load_calls = [] @@ -858,25 +896,23 @@ def test_plan_dil_execution(self): key_dim = 64 value_dim = 64 - head_k_dim = 16 - head_v_dim = 16 conv_dim = 2 * key_dim + value_dim # 192 # Create attention weights with per-head distinctive values # Q: each head gets value (head_idx + 1) q_weight = torch.zeros(64, 64) for h in range(4): - q_weight[h*16:(h+1)*16, :] = float(h + 1) + q_weight[h * 16 : (h + 1) * 16, :] = float(h + 1) # K: each head gets value (head_idx + 1) * 10 k_weight = torch.zeros(64, 64) for h in range(4): - k_weight[h*16:(h+1)*16, :] = float((h + 1) * 10) + k_weight[h * 16 : (h + 1) * 16, :] = float((h + 1) * 10) # V: each head gets value (head_idx + 1) * 100 v_weight = torch.zeros(64, 64) for h in range(4): - v_weight[h*16:(h+1)*16, :] = float((h + 1) * 100) + v_weight[h * 16 : (h + 1) * 16, :] = float((h + 1) * 100) sources = { W("attn.q_proj.weight"): q_weight, @@ -894,30 +930,23 @@ def test_plan_dil_execution(self): # Q_all (rows 0-63): heads 0,1,2,3 concatenated for h in range(4): - assert torch.allclose( - in_proj_qkvz[h*16:(h+1)*16], - torch.full((16, 64), float(h + 1)) - ) + assert torch.allclose(in_proj_qkvz[h * 16 : (h + 1) * 16], torch.full((16, 64), float(h + 1))) # K_all (rows 64-127): heads 0,1,2,3 concatenated for h in range(4): assert torch.allclose( - in_proj_qkvz[key_dim + h*16:key_dim + (h+1)*16], - torch.full((16, 64), float((h + 1) * 10)) + in_proj_qkvz[key_dim + h * 16 : key_dim + (h + 1) * 16], torch.full((16, 64), float((h + 1) * 10)) ) # V_all (rows 128-191): heads 0,1,2,3 concatenated for h in range(4): assert torch.allclose( - in_proj_qkvz[2*key_dim + h*16:2*key_dim + (h+1)*16], - torch.full((16, 64), float((h + 1) * 100)) + in_proj_qkvz[2 * key_dim + h * 16 : 2 * key_dim + (h + 1) * 16], + torch.full((16, 64), float((h + 1) * 100)), ) # Z_all (rows 192-255): zeros - assert torch.allclose( - in_proj_qkvz[2*key_dim + 
value_dim:], - torch.zeros(value_dim, 64) - ) + assert torch.allclose(in_proj_qkvz[2 * key_dim + value_dim :], torch.zeros(value_dim, 64)) # in_proj_ba should be zeros in_proj_ba = result[W("in_proj_ba.weight")] @@ -971,17 +1000,17 @@ def test_plan_dil_execution_gqa(self): # Q: 4 heads, each with value (head_idx + 1) q_weight = torch.zeros(64, 64) for h in range(4): - q_weight[h*16:(h+1)*16, :] = float(h + 1) + q_weight[h * 16 : (h + 1) * 16, :] = float(h + 1) # K: 2 kv_heads, each with value (head_idx + 1) * 10 k_weight = torch.zeros(32, 64) for h in range(2): - k_weight[h*16:(h+1)*16, :] = float((h + 1) * 10) + k_weight[h * 16 : (h + 1) * 16, :] = float((h + 1) * 10) # V: 2 kv_heads, each with value (head_idx + 1) * 100 v_weight = torch.zeros(32, 64) for h in range(2): - v_weight[h*16:(h+1)*16, :] = float((h + 1) * 100) + v_weight[h * 16 : (h + 1) * 16, :] = float((h + 1) * 100) sources = { W("attn.q_proj.weight"): q_weight, @@ -1007,22 +1036,22 @@ def test_plan_dil_execution_gqa(self): # K_all (rows 32-63): k_heads 0,1 (maps to source K heads 0,1 via modulo) # k_head 0 โ†’ source K head 0 (value 10) - assert torch.allclose(in_proj_qkvz[key_dim:key_dim+16], torch.full((16, 64), 10.0)) + assert torch.allclose(in_proj_qkvz[key_dim : key_dim + 16], torch.full((16, 64), 10.0)) # k_head 1 โ†’ source K head 1 (value 20) - assert torch.allclose(in_proj_qkvz[key_dim+16:key_dim+32], torch.full((16, 64), 20.0)) + assert torch.allclose(in_proj_qkvz[key_dim + 16 : key_dim + 32], torch.full((16, 64), 20.0)) # V_all (rows 64-127): 4 v_heads, tiled from 2 source KV heads via modulo # v_head 0 โ†’ src_v_head 0 (value 100) - assert torch.allclose(in_proj_qkvz[2*key_dim:2*key_dim+16], torch.full((16, 64), 100.0)) + assert torch.allclose(in_proj_qkvz[2 * key_dim : 2 * key_dim + 16], torch.full((16, 64), 100.0)) # v_head 1 โ†’ src_v_head 1 (value 200) - assert torch.allclose(in_proj_qkvz[2*key_dim+16:2*key_dim+32], torch.full((16, 64), 200.0)) + assert torch.allclose(in_proj_qkvz[2 * key_dim + 16 : 2 * key_dim + 32], torch.full((16, 64), 200.0)) # v_head 2 โ†’ src_v_head 0 (value 100, tiled) - assert torch.allclose(in_proj_qkvz[2*key_dim+32:2*key_dim+48], torch.full((16, 64), 100.0)) + assert torch.allclose(in_proj_qkvz[2 * key_dim + 32 : 2 * key_dim + 48], torch.full((16, 64), 100.0)) # v_head 3 โ†’ src_v_head 1 (value 200, tiled) - assert torch.allclose(in_proj_qkvz[2*key_dim+48:2*key_dim+64], torch.full((16, 64), 200.0)) + assert torch.allclose(in_proj_qkvz[2 * key_dim + 48 : 2 * key_dim + 64], torch.full((16, 64), 200.0)) # Z_all (rows 128-191): zeros - assert torch.allclose(in_proj_qkvz[2*key_dim+value_dim:], torch.zeros(value_dim, 64)) + assert torch.allclose(in_proj_qkvz[2 * key_dim + value_dim :], torch.zeros(value_dim, 64)) def test_plan_kil_attention_to_kda(self): """AIK plan produces correct structure for attention โ†’ KDA conversion.""" @@ -1188,6 +1217,7 @@ def test_compose_llava_to_mamba(self, llava_pixtral_config, apriel2_config_stoch # Build surgery plan (need intermediate config) from fast_llm_external_models.apriel2.conversion.llava import convert_config + intermediate_config = convert_config(llava_pixtral_config) target_config = apriel2_config_stochastic.to_dict() surgery_plan = plan_surgery(intermediate_config, target_config) @@ -1210,6 +1240,7 @@ def test_execute_composed_pipeline(self, llava_pixtral_checkpoint): """ import json from pathlib import Path + from safetensors.torch import load_file # Load config @@ -1448,10 +1479,9 @@ def test_comprehensive_conversion_all_mixer_types(self, 
llava_pixtral_checkpoint the conversion produced correct keys and shapes. """ import json - from pathlib import Path from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config - from fast_llm_external_models.apriel2.convert import build_plan, convert + from fast_llm_external_models.apriel2.convert import convert from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForConditionalGeneration # Load LLaVA config @@ -1477,11 +1507,11 @@ def test_comprehensive_conversion_all_mixer_types(self, llava_pixtral_checkpoint "type": "pattern", "num_blocks": 5, "pattern": [ - "attn", # 0: attention โ†’ attention (passthrough) - "mamba", # 1: attention โ†’ mamba (MIL) - "gdn", # 2: attention โ†’ gated_delta_net (DIL) - "stoch_am", # 3: attention โ†’ stochastic(attention + mamba) - "stoch_sg", # 4: attention โ†’ stochastic(swa + gdn) + "attn", # 0: attention โ†’ attention (passthrough) + "mamba", # 1: attention โ†’ mamba (MIL) + "gdn", # 2: attention โ†’ gated_delta_net (DIL) + "stoch_am", # 3: attention โ†’ stochastic(attention + mamba) + "stoch_sg", # 4: attention โ†’ stochastic(swa + gdn) ], "blocks": { # Pure attention (passthrough from source) @@ -1609,7 +1639,8 @@ def test_comprehensive_conversion_all_mixer_types(self, llava_pixtral_checkpoint "type": "attention", "heads": llava_config["vision_config"]["num_attention_heads"], "head_groups": llava_config["vision_config"]["num_attention_heads"], - "head_size": llava_config["vision_config"]["hidden_size"] // llava_config["vision_config"]["num_attention_heads"], + "head_size": llava_config["vision_config"]["hidden_size"] + // llava_config["vision_config"]["num_attention_heads"], "add_linear_biases": False, "causal": False, "rotary": { @@ -1688,7 +1719,6 @@ def test_conversion_plan_targets_match_model_state_dict(self, llava_pixtral_conf This test validates the plan WITHOUT executing it, by comparing plan target keys against what the model expects. """ - import json from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config from fast_llm_external_models.apriel2.convert import build_plan @@ -1703,7 +1733,7 @@ def test_conversion_plan_targets_match_model_state_dict(self, llava_pixtral_conf expected_keys = set(model.state_dict().keys()) # Get plan target keys - plan_target_keys = set(str(k) for k in plan.target_keys()) + plan_target_keys = {str(k) for k in plan.target_keys()} # Compare missing_from_plan = expected_keys - plan_target_keys @@ -1711,3 +1741,214 @@ def test_conversion_plan_targets_match_model_state_dict(self, llava_pixtral_conf assert not missing_from_plan, f"Plan missing keys that model expects: {sorted(missing_from_plan)}" assert not extra_in_plan, f"Plan has extra keys model doesn't expect: {sorted(extra_in_plan)}" + + +class TestBiasPlanGeneration: + """Test that surgery plans correctly handle per-layer bias configurations. + + These tests verify that plan_surgery correctly includes/excludes bias + weight mappings based on the per-layer bias settings: + - query_layer.bias.enabled, key_layer.bias.enabled, etc. 
for attention + - layer_1.bias.enabled, layer_2.bias.enabled for MLP + """ + + @pytest.fixture + def source_config_with_bias(self): + """Source config with Qwen-style bias (QKV enabled, O disabled).""" + return { + "model_type": "apriel2", + "hidden_size": 256, + "vocab_size": 1000, + "decoder": { + "type": "fixed", + "num_blocks": 2, + "block": { + "mixer": { + "type": "attention", + "heads": 8, + "head_groups": 4, + "head_size": 32, + "rotary": {"type": "mistral_1d", "theta": 10000.0}, + # Qwen-style: QKV bias enabled, O bias disabled + "query_layer": {"bias": {"enabled": True}}, + "key_layer": {"bias": {"enabled": True}}, + "value_layer": {"bias": {"enabled": True}}, + "dense_layer": {"bias": {"enabled": False}}, + }, + "mlp": { + "type": "mlp", + "intermediate_size": 512, + "gated": False, + # Per-layer MLP bias: layer_1 enabled, layer_2 disabled + "layer_1": {"bias": {"enabled": True}}, + "layer_2": {"bias": {"enabled": False}}, + }, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + } + + def test_plan_includes_enabled_attention_biases(self, source_config_with_bias): + """Surgery plan includes bias mappings for enabled attention biases.""" + from fast_llm_external_models.apriel2.conversion.config import compose_configs + from fast_llm_external_models.apriel2.conversion.converters import plan_surgery + + target_config = compose_configs( + source_config_with_bias, + { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + }, + }, + "mlp": {"init": "transfer"}, + }, + }, + }, + ) + + plan = plan_surgery(source_config_with_bias, target_config) + mapping_strs = [str(k) for k in plan.mappings.keys()] + + # Should have q_proj.bias, k_proj.bias, v_proj.bias mappings + q_bias = [m for m in mapping_strs if "q_proj.bias" in m] + k_bias = [m for m in mapping_strs if "k_proj.bias" in m] + v_bias = [m for m in mapping_strs if "v_proj.bias" in m] + + assert len(q_bias) > 0, "Should have q_proj.bias mappings" + assert len(k_bias) > 0, "Should have k_proj.bias mappings" + assert len(v_bias) > 0, "Should have v_proj.bias mappings" + + def test_plan_excludes_disabled_attention_biases(self, source_config_with_bias): + """Surgery plan excludes bias mappings for disabled attention biases.""" + from fast_llm_external_models.apriel2.conversion.config import compose_configs + from fast_llm_external_models.apriel2.conversion.converters import plan_surgery + + target_config = compose_configs( + source_config_with_bias, + { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + }, + }, + "mlp": {"init": "transfer"}, + }, + }, + }, + ) + + plan = plan_surgery(source_config_with_bias, target_config) + mapping_strs = [str(k) for k in plan.mappings.keys()] + + # Should NOT have o_proj.bias mappings (disabled) + o_bias = [m for m in mapping_strs if "o_proj.bias" in m] + assert len(o_bias) == 0, f"Should not have o_proj.bias mappings, found: {o_bias}" + + def test_plan_includes_enabled_mlp_biases(self, source_config_with_bias): + """Surgery plan includes bias mappings for enabled MLP biases.""" + from fast_llm_external_models.apriel2.conversion.config import compose_configs + from fast_llm_external_models.apriel2.conversion.converters import plan_surgery + + target_config = compose_configs( + source_config_with_bias, + { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": 
"attention", + "mixers": { + "attention": {"init": "transfer"}, + }, + }, + "mlp": {"init": "transfer"}, + }, + }, + }, + ) + + plan = plan_surgery(source_config_with_bias, target_config) + mapping_strs = [str(k) for k in plan.mappings.keys()] + + # Should have up_proj.bias (layer_1) mappings + up_bias = [m for m in mapping_strs if "up_proj.bias" in m] + assert len(up_bias) > 0, "Should have up_proj.bias mappings" + + def test_plan_excludes_disabled_mlp_biases(self, source_config_with_bias): + """Surgery plan excludes bias mappings for disabled MLP biases.""" + from fast_llm_external_models.apriel2.conversion.config import compose_configs + from fast_llm_external_models.apriel2.conversion.converters import plan_surgery + + target_config = compose_configs( + source_config_with_bias, + { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + }, + }, + "mlp": {"init": "transfer"}, + }, + }, + }, + ) + + plan = plan_surgery(source_config_with_bias, target_config) + mapping_strs = [str(k) for k in plan.mappings.keys()] + + # Should NOT have down_proj.bias (layer_2) mappings + down_bias = [m for m in mapping_strs if "down_proj.bias" in m] + assert len(down_bias) == 0, f"Should not have down_proj.bias mappings, found: {down_bias}" + + def test_plan_random_init_creates_init_expressions_for_bias(self, source_config_with_bias): + """Random init creates Init expressions for bias weights.""" + from fast_llm_external_models.apriel2.conversion.converters import plan_surgery + + # Surgery spec - pass directly to plan_surgery (NOT composed, to preserve init) + surgery = { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + "new_attention": { + "type": "attention", + "init": "random", # This triggers random init + "heads": 8, + "head_groups": 4, + "head_size": 32, + "rotary": {"type": "mistral_1d", "theta": 10000.0}, + "add_linear_biases": True, # All biases enabled + }, + }, + }, + }, + }, + } + + # Pass surgery spec directly - init fields are preserved + plan = plan_surgery(source_config_with_bias, surgery) + + # Check that new_attention biases use Init expressions + new_mixer_bias_keys = [k for k in plan.mappings.keys() if "new_attention" in str(k) and "bias" in str(k)] + + assert len(new_mixer_bias_keys) > 0, "Should have bias mappings for new_attention" + + for key in new_mixer_bias_keys: + expr = plan.mappings[key] + assert isinstance(expr, Init), f"{key} should be Init, got {type(expr)}" diff --git a/fast_llm_external_models/tests/test_apriel2/test_integration.py b/fast_llm_external_models/tests/test_apriel2/test_integration.py new file mode 100644 index 000000000..e84fa06ef --- /dev/null +++ b/fast_llm_external_models/tests/test_apriel2/test_integration.py @@ -0,0 +1,330 @@ +"""Integration tests for Qwen2 -> Apriel2 -> Fast-LLM conversion pipeline. + +Tests verify the full conversion chain: +1. Qwen2 -> Apriel2 (external module conversion) +2. Apriel2 + Surgery -> Supernet (stochastic mixer creation) +3. 
Supernet -> Fast-LLM -> Supernet (roundtrip through training format) + +Test Strategy: +- Use real HuggingFace model (Qwen2.5-0.5B) for meaningful validation +- Separate config preservation tests from numerical equivalence tests +- Parameterize both conversion stages AND input variations +- Single test implementation applied across all stages +""" + +import json +import tempfile +from pathlib import Path + +import pytest +import torch + +from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config +from fast_llm_external_models.apriel2.conversion import compose, compose_configs, execute, plan_surgery +from fast_llm_external_models.apriel2.conversion.expr import W +from fast_llm_external_models.apriel2.conversion.qwen2.config import convert_config as convert_qwen2_config +from fast_llm_external_models.apriel2.conversion.qwen2.plan import plan_qwen2_to_apriel2 +from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM + +from .conftest import requires_fastllm + +# ============================================================================= +# Test Input Variations +# ============================================================================= + +TEST_INPUTS = pytest.mark.parametrize( + "prompts,max_new_tokens", + [ + pytest.param(["Hello world"], 10, id="single_short"), + pytest.param(["Hi", "The quick brown fox jumps over the lazy dog"], 20, id="batch_varied"), + pytest.param(["Once upon a time"], 50, id="long_generation"), + ], +) + + +# ============================================================================= +# Conversion Fixtures +# ============================================================================= + + +@pytest.fixture(scope="module") +def qwen2_source(): + """Load Qwen2.5-0.5B as the source/reference model.""" + from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer + + model_name = "Qwen/Qwen2.5-0.5B" + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32, trust_remote_code=True) + config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + model.eval() + + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "left" + + return { + "model": model, + "tokenizer": tokenizer, + "config_dict": config.to_dict(), + "state_dict": model.state_dict(), + } + + +@pytest.fixture(scope="module") +def apriel2_converted(qwen2_source): + """Stage 1: Qwen2 -> Apriel2.""" + config_dict = convert_qwen2_config(qwen2_source["config_dict"]) + plan = plan_qwen2_to_apriel2(qwen2_source["config_dict"]) + weights = execute(plan, {W(k): v for k, v in qwen2_source["state_dict"].items()}, seed=42) + + config = Apriel2Config(**config_dict) + model = Apriel2ForCausalLM(config) + model.load_state_dict({str(k): v for k, v in weights.items()}, strict=False) + model.eval() + + return {"model": model, "config_dict": config_dict, "plan": plan, "name": "Apriel2"} + + +@pytest.fixture(scope="module") +def supernet_converted(qwen2_source, apriel2_converted): + """Stage 2: Apriel2 + Surgery -> Supernet.""" + surgery_spec = { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"type": "attention", "init": "transfer"}, + "sliding_window": { + "type": "attention", + "init": "transfer", + "window_size": 4096, + }, + }, + }, + }, + }, + } + + apriel_config = apriel2_converted["config_dict"] + supernet_config 
= compose_configs(apriel_config, surgery_spec) + + full_plan = compose( + apriel2_converted["plan"], + plan_surgery(apriel_config, supernet_config), + ) + + weights = execute(full_plan, {W(k): v for k, v in qwen2_source["state_dict"].items()}, seed=42) + + config = Apriel2Config(**supernet_config) + model = Apriel2ForCausalLM(config) + model.load_state_dict({str(k): v for k, v in weights.items()}, strict=False) + model.eval() + + return {"model": model, "config_dict": supernet_config, "name": "Supernet"} + + +@pytest.fixture(scope="module") +def roundtrip_converted(supernet_converted, qwen2_source): + """Stage 3: Supernet -> Fast-LLM -> Supernet.""" + if not torch.cuda.is_available(): + pytest.skip("Roundtrip conversion requires CUDA (integration tests need realistic hardware)") + + from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, CheckpointSaveConfig, FastLLMCheckpointFormat + from fast_llm.engine.checkpoint.convert import ConvertConfig + from fast_llm.models.gpt.config import GPTModelConfig + from fast_llm.models.gpt.conversion.config import Apriel2TextCheckpointFormat + + with tempfile.TemporaryDirectory() as tmpdir: + tmpdir = Path(tmpdir) + supernet_path = tmpdir / "supernet" + fastllm_path = tmpdir / "fastllm" + roundtrip_path = tmpdir / "roundtrip" + + supernet_converted["model"].save_pretrained(supernet_path) + qwen2_source["tokenizer"].save_pretrained(supernet_path) + + ConvertConfig( + model=GPTModelConfig, + input=CheckpointLoadConfig(path=supernet_path, format=Apriel2TextCheckpointFormat), + output=CheckpointSaveConfig(path=fastllm_path, format=FastLLMCheckpointFormat), + ).run() + + ConvertConfig( + model=GPTModelConfig, + input=CheckpointLoadConfig(path=fastllm_path, format=FastLLMCheckpointFormat), + output=CheckpointSaveConfig(path=roundtrip_path, format=Apriel2TextCheckpointFormat), + ).run() + + model = Apriel2ForCausalLM.from_pretrained(roundtrip_path) + model.eval() + + with open(roundtrip_path / "config.json") as f: + config_dict = json.load(f) + + yield {"model": model, "config_dict": config_dict, "name": "Roundtrip"} + + +# ============================================================================= +# Parameterized Fixture: All Conversion Stages +# ============================================================================= + + +@pytest.fixture(params=["apriel2", "supernet", "roundtrip"]) +def converted_model(request, apriel2_converted, supernet_converted): + """Parameterized fixture providing each conversion stage for testing. + + This allows a single test to run against all stages automatically. 
+ """ + if request.param == "roundtrip": + pytest.importorskip("fast_llm") + if not torch.cuda.is_available(): + pytest.skip("Roundtrip tests require CUDA (integration tests need realistic hardware)") + # Lazy-load to avoid fixture evaluation when CUDA unavailable + roundtrip_converted = request.getfixturevalue("roundtrip_converted") + return roundtrip_converted + + return { + "apriel2": apriel2_converted, + "supernet": supernet_converted, + }[request.param] + + +# ============================================================================= +# Config Preservation Tests +# ============================================================================= + + +@pytest.mark.slow +class TestConfigPreservation: + """Verify configs are correctly preserved through the conversion chain.""" + + def test_apriel2_structure(self, qwen2_source, apriel2_converted): + """Qwen2 -> Apriel2 preserves model dimensions.""" + qwen = qwen2_source["config_dict"] + apriel = apriel2_converted["config_dict"] + + assert apriel["hidden_size"] == qwen["hidden_size"] + assert apriel["vocab_size"] == qwen["vocab_size"] + assert apriel["decoder"]["num_blocks"] == qwen["num_hidden_layers"] + + def test_apriel2_bias_pattern(self, apriel2_converted): + """Qwen2 -> Apriel2 preserves Qwen-style bias (QKV yes, O no).""" + mixer = apriel2_converted["config_dict"]["decoder"]["block"]["mixer"] + + assert mixer["query_layer"]["bias"]["enabled"] is True + assert mixer["key_layer"]["bias"]["enabled"] is True + assert mixer["value_layer"]["bias"]["enabled"] is True + assert mixer["dense_layer"]["bias"]["enabled"] is False + + def test_supernet_structure(self, supernet_converted): + """Surgery creates correct stochastic mixer structure.""" + mixer = supernet_converted["config_dict"]["decoder"]["block"]["mixer"] + + assert mixer["type"] == "stochastic" + assert mixer["main_mixer_name"] == "attention" + assert set(mixer["mixers"].keys()) == {"attention", "sliding_window"} + + def test_supernet_bias_inheritance(self, supernet_converted): + """Submixers inherit bias settings from source.""" + mixer = supernet_converted["config_dict"]["decoder"]["block"]["mixer"] + + for name in ["attention", "sliding_window"]: + assert mixer["mixers"][name]["query_layer"]["bias"]["enabled"] is True + assert mixer["mixers"][name]["dense_layer"]["bias"]["enabled"] is False + + @requires_fastllm + def test_roundtrip_structure(self, roundtrip_converted): + """Fast-LLM roundtrip preserves stochastic mixer structure.""" + mixer = roundtrip_converted["config_dict"]["decoder"]["block"]["mixer"] + + assert mixer["type"] == "stochastic" + assert mixer["main_mixer_name"] == "attention" + assert set(mixer["mixers"].keys()) == {"attention", "sliding_window"} + + @requires_fastllm + def test_roundtrip_bias_preservation(self, roundtrip_converted): + """Fast-LLM roundtrip preserves per-layer bias settings.""" + mixer = roundtrip_converted["config_dict"]["decoder"]["block"]["mixer"] + + for name in ["attention", "sliding_window"]: + assert mixer["mixers"][name]["query_layer"]["bias"]["enabled"] is True + assert mixer["mixers"][name]["dense_layer"]["bias"]["enabled"] is False + + +# ============================================================================= +# Numerical Equivalence Tests +# ============================================================================= + + +@pytest.mark.slow +class TestNumericalEquivalence: + """Verify all conversion stages produce numerically identical outputs. 
+ + Uses parameterized fixtures to test all stages with all input variations, + giving us 3 stages ร— 3 inputs = 9 test cases from a single test function. + """ + + @TEST_INPUTS + def test_logits_match(self, qwen2_source, converted_model, prompts, max_new_tokens): + """Converted model produces identical logits to source.""" + tokenizer = qwen2_source["tokenizer"] + ref_model = qwen2_source["model"] + test_model = converted_model["model"] + stage = converted_model["name"] + + inputs = tokenizer(prompts, return_tensors="pt", padding=True) + ref_device = next(ref_model.parameters()).device + test_device = next(test_model.parameters()).device + + with torch.no_grad(): + ref_logits = ref_model( + input_ids=inputs.input_ids.to(ref_device), + attention_mask=inputs.attention_mask.to(ref_device), + ).logits.cpu() + + test_logits = test_model( + input_ids=inputs.input_ids.to(test_device), + attention_mask=inputs.attention_mask.to(test_device), + ).logits.cpu() + + max_diff = (ref_logits - test_logits).abs().max().item() + assert torch.allclose( + ref_logits, test_logits, rtol=1e-4, atol=1e-4 + ), f"{stage} logits mismatch: max diff = {max_diff:.6f}" + + @TEST_INPUTS + def test_generation_match(self, qwen2_source, converted_model, prompts, max_new_tokens): + """Converted model produces identical generation to source.""" + tokenizer = qwen2_source["tokenizer"] + ref_model = qwen2_source["model"] + test_model = converted_model["model"] + stage = converted_model["name"] + + inputs = tokenizer(prompts, return_tensors="pt", padding=True) + ref_device = next(ref_model.parameters()).device + test_device = next(test_model.parameters()).device + + with torch.no_grad(): + ref_gen = ref_model.generate( + input_ids=inputs.input_ids.to(ref_device), + attention_mask=inputs.attention_mask.to(ref_device), + max_new_tokens=max_new_tokens, + do_sample=False, + pad_token_id=tokenizer.pad_token_id, + ).cpu() + + test_gen = test_model.generate( + input_ids=inputs.input_ids.to(test_device), + attention_mask=inputs.attention_mask.to(test_device), + max_new_tokens=max_new_tokens, + do_sample=False, + pad_token_id=tokenizer.pad_token_id, + ).cpu() + + assert torch.equal(ref_gen, test_gen), ( + f"{stage} generation mismatch:\n" + f" Reference: {tokenizer.batch_decode(ref_gen, skip_special_tokens=True)}\n" + f" Test: {tokenizer.batch_decode(test_gen, skip_special_tokens=True)}" + ) diff --git a/fast_llm_external_models/tests/test_apriel2/test_mixer_equivalence.py b/fast_llm_external_models/tests/test_apriel2/test_mixer_equivalence.py index 1aa8a56d9..c6f3337e8 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_mixer_equivalence.py +++ b/fast_llm_external_models/tests/test_apriel2/test_mixer_equivalence.py @@ -28,15 +28,7 @@ import torch import torch.nn as nn -from fast_llm_external_models.apriel2.conversion import ( - Concat, - ExprPlan, - Ref, - Slice, - W, - execute, -) - +from fast_llm_external_models.apriel2.conversion import Concat, ExprPlan, Ref, Slice, W, execute # ============================================================================= # Shared Fixtures @@ -69,10 +61,10 @@ def hidden_size(request): @pytest.fixture( params=[ - pytest.param((8, 8, 32), id="mha-8h-32d"), # MHA: 8 heads, 8 kv heads, 32 head_dim - pytest.param((8, 4, 32), id="gqa-8h4kv-32d"), # GQA: 8 heads, 4 kv heads, 32 head_dim - pytest.param((8, 2, 64), id="gqa-8h2kv-64d"), # GQA: 8 heads, 2 kv heads, 64 head_dim - pytest.param((4, 1, 64), id="mqa-4h1kv-64d"), # MQA: 4 heads, 1 kv head, 64 head_dim + pytest.param((8, 8, 32), 
id="mha-8h-32d"), # MHA: 8 heads, 8 kv heads, 32 head_dim + pytest.param((8, 4, 32), id="gqa-8h4kv-32d"), # GQA: 8 heads, 4 kv heads, 32 head_dim + pytest.param((8, 2, 64), id="gqa-8h2kv-64d"), # GQA: 8 heads, 2 kv heads, 64 head_dim + pytest.param((4, 1, 64), id="mqa-4h1kv-64d"), # MQA: 4 heads, 1 kv head, 64 head_dim ] ) def attention_config(request): @@ -90,7 +82,7 @@ def attention_config(request): params=[ pytest.param((8, 4, 32, 32), id="8v-4k-32d"), # 8 value heads, 4 key heads, symmetric dims pytest.param((8, 2, 64, 64), id="8v-2k-64d"), # 8 value heads, 2 key heads, larger dims - pytest.param((4, 2, 32, 64), id="4v-2k-asym"), # Asymmetric key/value dims + pytest.param((4, 2, 32, 64), id="4v-2k-asym"), # Asymmetric key/value dims ] ) def gdn_config(request): @@ -100,9 +92,9 @@ def gdn_config(request): @pytest.fixture( params=[ - pytest.param((4, 8), id="4h-8d"), # 4 heads, 8 head_dim (small) - pytest.param((8, 16), id="8h-16d"), # 8 heads, 16 head_dim (medium) - pytest.param((4, 32), id="4h-32d"), # 4 heads, 32 head_dim (large head_dim) + pytest.param((4, 8), id="4h-8d"), # 4 heads, 8 head_dim (small) + pytest.param((8, 16), id="8h-16d"), # 8 heads, 16 head_dim (medium) + pytest.param((4, 32), id="4h-32d"), # 4 heads, 32 head_dim (large head_dim) ] ) def kda_config(request): @@ -283,9 +275,21 @@ def plan_qwen3next_gdn_to_apriel2( for g in range(num_k_heads): base = g * group_size q_slices.append(Slice(expr=qkvz_ref, slices=((base, base + head_k_dim, None), (None, None, None)))) - k_slices.append(Slice(expr=qkvz_ref, slices=((base + head_k_dim, base + 2 * head_k_dim, None), (None, None, None)))) - v_slices.append(Slice(expr=qkvz_ref, slices=((base + 2 * head_k_dim, base + 2 * head_k_dim + v_per_group, None), (None, None, None)))) - z_slices.append(Slice(expr=qkvz_ref, slices=((base + 2 * head_k_dim + v_per_group, base + group_size, None), (None, None, None)))) + k_slices.append( + Slice(expr=qkvz_ref, slices=((base + head_k_dim, base + 2 * head_k_dim, None), (None, None, None))) + ) + v_slices.append( + Slice( + expr=qkvz_ref, + slices=((base + 2 * head_k_dim, base + 2 * head_k_dim + v_per_group, None), (None, None, None)), + ) + ) + z_slices.append( + Slice( + expr=qkvz_ref, + slices=((base + 2 * head_k_dim + v_per_group, base + group_size, None), (None, None, None)), + ) + ) in_proj_qkvz_expr = Concat( exprs=( @@ -304,8 +308,15 @@ def plan_qwen3next_gdn_to_apriel2( b_slices, a_slices = [], [] for g in range(num_k_heads): base = g * ba_per_group - b_slices.append(Slice(expr=ba_ref, slices=((base, base + num_v_heads // num_k_heads, None), (None, None, None)))) - a_slices.append(Slice(expr=ba_ref, slices=((base + num_v_heads // num_k_heads, base + ba_per_group, None), (None, None, None)))) + b_slices.append( + Slice(expr=ba_ref, slices=((base, base + num_v_heads // num_k_heads, None), (None, None, None))) + ) + a_slices.append( + Slice( + expr=ba_ref, + slices=((base + num_v_heads // num_k_heads, base + ba_per_group, None), (None, None, None)), + ) + ) in_proj_ba_expr = Concat( exprs=(Concat(exprs=tuple(b_slices), dim=0), Concat(exprs=tuple(a_slices), dim=0)), @@ -565,6 +576,7 @@ def test_causal_vs_mistral( ): """Verify Apriel2Attention (causal) matches MistralAttention output.""" from transformers.models.mistral.modeling_mistral import MistralAttention, MistralRotaryEmbedding + from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2Attention mixer_config = apriel2_config.decoder["block"]["mixer"] @@ -593,13 +605,20 @@ def test_causal_vs_mistral( apriel2_attn.eval() with 
torch.no_grad(): - mistral_out = mistral_attn(hidden_states, position_embeddings=position_embeddings, attention_mask=causal_mask)[0] - apriel2_out = apriel2_attn(hidden_states, attention_mask=causal_mask, position_embeddings=position_embeddings)[0] + mistral_out = mistral_attn( + hidden_states, position_embeddings=position_embeddings, attention_mask=causal_mask + )[0] + apriel2_out = apriel2_attn( + hidden_states, attention_mask=causal_mask, position_embeddings=position_embeddings + )[0] rtol, atol = tolerance assert_close( - apriel2_out, mistral_out, rtol=rtol, atol=atol, - msg=f"Apriel2Attention vs MistralAttention (batch={batch_size}, seq={seq_len}, hidden={hidden_size})" + apriel2_out, + mistral_out, + rtol=rtol, + atol=atol, + msg=f"Apriel2Attention vs MistralAttention (batch={batch_size}, seq={seq_len}, hidden={hidden_size})", ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA") @@ -613,8 +632,9 @@ def test_noncausal_vs_pixtral( tolerance, ): """Verify Apriel2Attention (non-causal) matches PixtralAttention output.""" - from transformers.models.pixtral.modeling_pixtral import PixtralAttention, PixtralRotaryEmbedding from transformers.models.pixtral.configuration_pixtral import PixtralVisionConfig + from transformers.models.pixtral.modeling_pixtral import PixtralAttention, PixtralRotaryEmbedding + from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2TextConfig from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2Attention @@ -689,8 +709,11 @@ def test_noncausal_vs_pixtral( rtol, atol = tolerance assert_close( - apriel2_out, pixtral_out, rtol=rtol, atol=atol, - msg=f"Apriel2Attention (non-causal) vs PixtralAttention (batch={batch_size}, seq={seq_len})" + apriel2_out, + pixtral_out, + rtol=rtol, + atol=atol, + msg=f"Apriel2Attention (non-causal) vs PixtralAttention (batch={batch_size}, seq={seq_len})", ) @@ -737,6 +760,7 @@ def test_vs_qwen3next( ): """Verify Apriel2GatedDeltaNet matches Qwen3NextGatedDeltaNet output.""" from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextGatedDeltaNet + from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2GatedDeltaNet value_heads, key_heads, key_head_dim, value_head_dim = gdn_config @@ -758,8 +782,10 @@ def test_vs_qwen3next( # Transfer weights plan = plan_qwen3next_gdn_to_apriel2( - num_k_heads=key_heads, num_v_heads=value_heads, - head_k_dim=key_head_dim, head_v_dim=value_head_dim, + num_k_heads=key_heads, + num_v_heads=value_heads, + head_k_dim=key_head_dim, + head_v_dim=value_head_dim, ) source_weights = extract_module_weights(qwen_gdn) target_weights = execute(plan, source_weights, seed=seed) @@ -778,8 +804,11 @@ def test_vs_qwen3next( rtol, atol = tolerance assert_close( - apriel2_out, qwen_out, rtol=rtol, atol=atol, - msg=f"Apriel2GatedDeltaNet vs Qwen3NextGatedDeltaNet (batch={batch_size}, seq={seq_len})" + apriel2_out, + qwen_out, + rtol=rtol, + atol=atol, + msg=f"Apriel2GatedDeltaNet vs Qwen3NextGatedDeltaNet (batch={batch_size}, seq={seq_len})", ) @@ -803,6 +832,7 @@ def test_vs_fla( ): """Verify Apriel2 KimiDeltaAttention matches FLA KimiDeltaAttention output.""" from fla.layers.kda import KimiDeltaAttention as FLA_KDA + from fast_llm_external_models.apriel2.modeling_apriel2 import KimiDeltaAttention as Apriel2_KDA num_heads, head_dim = kda_config @@ -853,8 +883,11 @@ def test_vs_fla( rtol, atol = tolerance assert_close( - apriel2_out, fla_out, rtol=rtol, atol=atol, - msg=f"Apriel2 KDA vs FLA KDA (batch={batch_size}, seq={seq_len}, 
hidden={hidden_size})" + apriel2_out, + fla_out, + rtol=rtol, + atol=atol, + msg=f"Apriel2 KDA vs FLA KDA (batch={batch_size}, seq={seq_len}, hidden={hidden_size})", ) @@ -913,7 +946,4 @@ def test_gdn_fast_vs_slow(self, gdn_config, batch_size): slow_out = model(hidden_states)[0].clone() # Looser tolerance for kernel vs reference comparison - assert_close( - fast_out, slow_out, rtol=1e-3, atol=1e-3, - msg="GDN fast path (CUDA) vs slow path (PyTorch)" - ) + assert_close(fast_out, slow_out, rtol=1e-3, atol=1e-3, msg="GDN fast path (CUDA) vs slow path (PyTorch)") diff --git a/fast_llm_external_models/tests/test_apriel2/test_model_structure.py b/fast_llm_external_models/tests/test_apriel2/test_model_structure.py index 23856be30..56d2bc6a6 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_model_structure.py +++ b/fast_llm_external_models/tests/test_apriel2/test_model_structure.py @@ -1,9 +1,9 @@ """Tests for Apriel2 model structure and architecture validation.""" -import pytest import torch -from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM + from fast_llm_external_models.apriel2.cache import Apriel2Cache, _AttentionCache, _SSMCache +from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM class TestStochasticMixerStructure: @@ -14,20 +14,27 @@ def test_all_submixers_present(self, apriel2_config_all_mixers): model = Apriel2ForCausalLM(apriel2_config_all_mixers) stochastic_layer = model.model.decoder.blocks[1] # Layer 1 is the "all_mixers" layer - assert hasattr(stochastic_layer.mixer, 'mixers'), "Stochastic mixer should have 'mixers' attribute" + assert hasattr(stochastic_layer.mixer, "mixers"), "Stochastic mixer should have 'mixers' attribute" assert set(stochastic_layer.mixer.mixers.keys()) == { - 'attention', 'swa', 'mamba', 'gdn' + "attention", + "swa", + "mamba", + "gdn", }, "Stochastic mixer should contain all 4 configured mixer types" # Verify each mixer is the correct type from fast_llm_external_models.apriel2.modeling_apriel2 import ( - Apriel2Attention, Apriel2Mamba, Apriel2GatedDeltaNet + Apriel2Attention, + Apriel2GatedDeltaNet, + Apriel2Mamba, ) - assert isinstance(stochastic_layer.mixer.mixers['attention'], Apriel2Attention) - assert isinstance(stochastic_layer.mixer.mixers['swa'], Apriel2Attention) # SWA is Apriel2Attention with sliding_window - assert isinstance(stochastic_layer.mixer.mixers['mamba'], Apriel2Mamba) - assert isinstance(stochastic_layer.mixer.mixers['gdn'], Apriel2GatedDeltaNet) + assert isinstance(stochastic_layer.mixer.mixers["attention"], Apriel2Attention) + assert isinstance( + stochastic_layer.mixer.mixers["swa"], Apriel2Attention + ) # SWA is Apriel2Attention with sliding_window + assert isinstance(stochastic_layer.mixer.mixers["mamba"], Apriel2Mamba) + assert isinstance(stochastic_layer.mixer.mixers["gdn"], Apriel2GatedDeltaNet) def test_main_mixer_is_configured(self, apriel2_config_all_mixers): """Verify main_mixer_name is set correctly.""" @@ -44,7 +51,10 @@ def test_cache_has_all_submixer_slots(self, apriel2_config_all_mixers): assert isinstance(layer_cache, dict), "Stochastic layer cache should be a dict" assert set(layer_cache.keys()) == { - 'attention', 'swa', 'mamba', 'gdn' + "attention", + "swa", + "mamba", + "gdn", }, "Cache should have slots for all 4 mixers" def test_attention_mixers_use_attention_cache(self, apriel2_config_all_mixers): @@ -53,12 +63,12 @@ def test_attention_mixers_use_attention_cache(self, apriel2_config_all_mixers): layer_cache = cache.layers[1] # Attention-based 
mixers use AttentionCache - assert isinstance(layer_cache['attention'], _AttentionCache) - assert isinstance(layer_cache['swa'], _AttentionCache) + assert isinstance(layer_cache["attention"], _AttentionCache) + assert isinstance(layer_cache["swa"], _AttentionCache) # SSM-based mixers use SSMCache - assert isinstance(layer_cache['mamba'], _SSMCache) - assert isinstance(layer_cache['gdn'], _SSMCache) + assert isinstance(layer_cache["mamba"], _SSMCache) + assert isinstance(layer_cache["gdn"], _SSMCache) def test_parameter_counts_differ_by_config(self): """Different configs create models with different parameter counts.""" @@ -74,8 +84,10 @@ def test_parameter_counts_differ_by_config(self): } config_tiny = Apriel2Config( - vocab_size=100, hidden_size=64, - num_attention_heads=4, num_key_value_heads=2, + vocab_size=100, + hidden_size=64, + num_attention_heads=4, + num_key_value_heads=2, decoder={ "type": "fixed", "num_blocks": 2, @@ -88,8 +100,10 @@ def test_parameter_counts_differ_by_config(self): ) config_stochastic = Apriel2Config( - vocab_size=100, hidden_size=64, - num_attention_heads=4, num_key_value_heads=2, + vocab_size=100, + hidden_size=64, + num_attention_heads=4, + num_key_value_heads=2, decoder={ "type": "pattern", "num_blocks": 2, @@ -106,14 +120,14 @@ def test_parameter_counts_differ_by_config(self): "main_mixer_name": "attention", "mixers": { "attention": attn_config, - "mamba": {"type": "mamba", "conv_bias": True, "dt_proj_bias": True} - } + "mamba": {"type": "mamba", "conv_bias": True, "dt_proj_bias": True}, + }, }, "mlp": {"type": "mlp", "intermediate_size": 256, "gated": True}, "normalization": {"type": "rms_norm"}, - } - } - } + }, + }, + }, ) model_tiny = Apriel2ForCausalLM(config_tiny) @@ -122,8 +136,9 @@ def test_parameter_counts_differ_by_config(self): params_tiny = sum(p.numel() for p in model_tiny.parameters()) params_stochastic = sum(p.numel() for p in model_stochastic.parameters()) - assert params_stochastic > params_tiny, \ - "Stochastic mixer should have more parameters (has both attention and mamba)" + assert ( + params_stochastic > params_tiny + ), "Stochastic mixer should have more parameters (has both attention and mamba)" def test_weights_are_initialized(self, apriel2_config_all_mixers): """Verify model weights are initialized (not all zeros/constant).""" @@ -136,9 +151,7 @@ def test_weights_are_initialized(self, apriel2_config_all_mixers): # Basic sanity: at least some parameters should be non-zero non_zero_params = sum( - not torch.all(p == 0) - for mixer in stochastic_layer.mixer.mixers.values() - for p in mixer.parameters() + not torch.all(p == 0) for mixer in stochastic_layer.mixer.mixers.values() for p in mixer.parameters() ) assert non_zero_params > 0, "At least some mixer parameters should be non-zero" diff --git a/fast_llm_external_models/tests/test_apriel2/test_modeling.py b/fast_llm_external_models/tests/test_apriel2/test_modeling.py index 5dbd36159..8e2f610bb 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_modeling.py +++ b/fast_llm_external_models/tests/test_apriel2/test_modeling.py @@ -2,18 +2,23 @@ import pytest import torch + from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM class TestApriel2Modeling: """End-to-end tests for Apriel2 model with different configurations.""" - @pytest.mark.parametrize("config_name", [ - "apriel2_config_tiny", - "apriel2_config_stochastic", - "apriel2_config_multi_mixer", - "apriel2_config_all_mixers" # Tests all 4 mixer types - ]) + @pytest.mark.parametrize( + 
"config_name", + [ + "apriel2_config_tiny", + "apriel2_config_stochastic", + "apriel2_config_multi_mixer", + "apriel2_config_all_mixers", # Tests all 4 mixer types + "apriel2_config_with_bias", # Tests per-layer bias and non-gated MLP + ], + ) def test_model_end_to_end(self, config_name, request): """Test instantiation, forward pass, cache correctness, and generation. @@ -42,7 +47,7 @@ def test_model_end_to_end(self, config_name, request): # 2. Forward pass - basic shape validation outputs = model(input_ids, use_cache=False) assert outputs.logits.shape == (2, seq_len, config.vocab_size) - assert hasattr(outputs, 'logits') + assert hasattr(outputs, "logits") # 3. Verify cache is actually being used (not dormant) split_pos = 30 @@ -52,28 +57,23 @@ def test_model_end_to_end(self, config_name, request): assert outputs_part1.past_key_values is not None outputs_correct_cache = model( - input_ids[:, split_pos:split_pos+1], - past_key_values=outputs_part1.past_key_values, - use_cache=True + input_ids[:, split_pos : split_pos + 1], past_key_values=outputs_part1.past_key_values, use_cache=True ) # Test 1: Empty cache should give different results than filled cache # This verifies cache is being used at all from fast_llm_external_models.apriel2.cache import Apriel2Cache, _AttentionCache + empty_cache = Apriel2Cache(config) outputs_empty_cache = model( - input_ids[:, split_pos:split_pos+1], - past_key_values=empty_cache, - use_cache=True + input_ids[:, split_pos : split_pos + 1], past_key_values=empty_cache, use_cache=True ) - cache_affects_output = not torch.allclose( - outputs_correct_cache.logits, - outputs_empty_cache.logits, - atol=1e-3 - ) - assert cache_affects_output, f"Cache appears dormant for {config_name} - empty cache gives same results as filled cache" + cache_affects_output = not torch.allclose(outputs_correct_cache.logits, outputs_empty_cache.logits, atol=1e-3) + assert ( + cache_affects_output + ), f"Cache appears dormant for {config_name} - empty cache gives same results as filled cache" # Test 2: Corrupted cache (zeros) should give different results than correct cache # This verifies the actual cache VALUES are being used @@ -98,17 +98,15 @@ def test_model_end_to_end(self, config_name, request): corrupted_layer[name].value = torch.zeros_like(correct_sub.value) outputs_corrupted_cache = model( - input_ids[:, split_pos:split_pos+1], - past_key_values=corrupted_cache, - use_cache=True + input_ids[:, split_pos : split_pos + 1], past_key_values=corrupted_cache, use_cache=True ) cache_values_matter = not torch.allclose( - outputs_correct_cache.logits, - outputs_corrupted_cache.logits, - atol=1e-3 + outputs_correct_cache.logits, outputs_corrupted_cache.logits, atol=1e-3 ) - assert cache_values_matter, f"Cache values not used for {config_name} - zeroed cache gives same results as correct cache" + assert ( + cache_values_matter + ), f"Cache values not used for {config_name} - zeroed cache gives same results as correct cache" # 4. 
Cache correctness - validate cache produces same results as no-cache # Compute full sequence without cache @@ -117,18 +115,14 @@ def test_model_end_to_end(self, config_name, request): # Compute in two steps with cache outputs_part1 = model(input_ids[:, :split_pos], use_cache=True) outputs_part2 = model( - input_ids[:, split_pos:split_pos+1], - past_key_values=outputs_part1.past_key_values, - use_cache=True + input_ids[:, split_pos : split_pos + 1], past_key_values=outputs_part1.past_key_values, use_cache=True ) # Logits should match between cached and non-cached # Note: GPU execution with bfloat16/float16 has lower precision than CPU float32, # so we use a looser tolerance here. assert torch.allclose( - outputs_full.logits[:, split_pos, :], - outputs_part2.logits[:, 0, :], - atol=1e-3 + outputs_full.logits[:, split_pos, :], outputs_part2.logits[:, 0, :], atol=1e-3 ), f"Cache correctness failed for {config_name}: cached and non-cached logits differ" # 5. Generation - end-to-end validation diff --git a/fast_llm_external_models/tests/test_apriel2/test_plan_execution.py b/fast_llm_external_models/tests/test_apriel2/test_plan_execution.py new file mode 100644 index 000000000..ca0c8739f --- /dev/null +++ b/fast_llm_external_models/tests/test_apriel2/test_plan_execution.py @@ -0,0 +1,598 @@ +"""test_plan_execution.py - Plan execution and algebraic composition laws. + +This module provides rigorous, parameterized tests for the mathematical properties +that the conversion system must satisfy. Each test class corresponds to one +algebraic structure, and each test method verifies one specific law. + +Conceptual Types +================ + +The conversion system operates on three conceptual types (all ``dict`` at runtime): + +- **S (State)**: Complete config without ``init`` fields +- **P (Partial Surgery)**: Incomplete config, may have ``init`` fields +- **T (Transition Spec)**: Complete config WITH ``init`` fields + +Algebraic Structures +==================== + +1. **Partial Surgeries (P)** form a **Monoid** under deep merge:: + + compose_configs : P ร— P โ†’ P + Identity: {} + Associativity: (p1 โˆ˜ p2) โˆ˜ p3 = p1 โˆ˜ (p2 โˆ˜ p3) + +2. **Surgeries act on States** to produce Transition Specs:: + + compose_configs : S ร— P โ†’ T + compose_configs : T ร— P โ†’ T + + Action law (additive surgeries): (s ยท p1) ยท p2 = s ยท (p1 โˆ˜ p2) + +3. **Plans** form a **Category** with composition:: + + compose : Plan(Aโ†’B) ร— Plan(Bโ†’C) โ†’ Plan(Aโ†’C) + Associativity: (P1 โˆ˜ P2) โˆ˜ P3 = P1 โˆ˜ (P2 โˆ˜ P3) + +4. **plan_surgery is a Functor** from config pairs to plans:: + + plan_surgery : S ร— T โ†’ Plan + Functoriality: compose(plan(S,T1), plan(T1,T2)) โ‰ก plan(S,T2) + + This is semantic equivalence: both produce identical weights when executed. + +Important Behaviors Tested +========================== + +- **init stripping**: Between surgery iterations, T โ†’ S conversion via + ``strip_init_fields()`` ensures ``init: random`` applies only to the surgery + that introduces a component. + +- **Bias inheritance**: Per-layer bias settings propagate through surgery chains. + +- **Plan composition**: Composed plans produce identical weights to direct plans. 
+ +Design Principles +================= + +- Each law gets ONE parameterized test, not multiple similar tests +- Fixtures provide diverse configs (with/without biases) +- Corner cases are covered via parameterization, not test proliferation +- Tests document the laws they verify in their docstrings +""" + +from functools import reduce + +import pytest +import torch + +from fast_llm_external_models.apriel2.conversion import ( + Concat, + ExprPlan, + Init, + Ref, + Slice, + W, + compose, + compose_configs, + execute, + plan_surgery, +) + +# Import shared helper from conftest +from fast_llm_external_models.tests.test_apriel2.conftest import make_weights_for_config + +# ============================================================================= +# Fixtures: Use shared fixtures from conftest.py where possible +# ============================================================================= +# - base_config_dict: Complete config without biases (Llama-style) +# - base_config_with_bias_dict: Complete config with QKV biases +# - additive_surgery_chain: [wrap_stochastic, add_sliding_window, add_gdn] +# ============================================================================= + + +# ============================================================================= +# Test: Plan Composition Associativity +# ============================================================================= + + +class TestPlanCompositionAssociativity: + """ + LAW: Plan composition is associative. + + (Pโ‚ โˆ˜ Pโ‚‚) โˆ˜ Pโ‚ƒ = Pโ‚ โˆ˜ (Pโ‚‚ โˆ˜ Pโ‚ƒ) + + where โˆ˜ denotes compose(P1, P2). + + This must hold for the AST structure, not just semantic equivalence. + """ + + @pytest.mark.parametrize("expr_type", ["ref_chain", "with_concat", "with_slice", "with_init"]) + def test_associativity(self, expr_type): + """Plan composition is associative for various expression types.""" + # Build three plans that can be composed + if expr_type == "ref_chain": + p1 = ExprPlan(mappings={W("b"): Ref(key=W("a"))}) + p2 = ExprPlan(mappings={W("c"): Ref(key=W("b"))}) + p3 = ExprPlan(mappings={W("d"): Ref(key=W("c"))}) + elif expr_type == "with_concat": + p1 = ExprPlan(mappings={W("x"): Ref(key=W("a")), W("y"): Ref(key=W("b"))}) + p2 = ExprPlan(mappings={W("xy"): Concat(exprs=(Ref(key=W("x")), Ref(key=W("y"))), dim=0)}) + p3 = ExprPlan(mappings={W("final"): Ref(key=W("xy"))}) + elif expr_type == "with_slice": + p1 = ExprPlan(mappings={W("full"): Ref(key=W("src"))}) + p2 = ExprPlan(mappings={W("part"): Slice(expr=Ref(key=W("full")), slices=((0, 5, None),))}) + p3 = ExprPlan(mappings={W("out"): Ref(key=W("part"))}) + elif expr_type == "with_init": + p1 = ExprPlan(mappings={W("x"): Ref(key=W("a"))}) + p2 = ExprPlan( + mappings={W("y"): Concat(exprs=(Ref(key=W("x")), Init(shape=(5,), init_type="zeros")), dim=0)} + ) + p3 = ExprPlan(mappings={W("z"): Ref(key=W("y"))}) + + left = compose(compose(p1, p2), p3) + right = compose(p1, compose(p2, p3)) + + assert left.mappings == right.mappings, f"Associativity failed for {expr_type}" + + +# ============================================================================= +# Test: Functoriality of plan_surgery (THE CRITICAL PROPERTY) +# ============================================================================= + + +class TestPlanSurgeryFunctoriality: + """ + LAW: plan_surgery is functorial with respect to config composition. 
+ + For a surgery chain Pโ‚, Pโ‚‚, ..., Pโ‚™ applied to base state Sโ‚€:: + + Tโ‚ = compose_configs(Sโ‚€, Pโ‚) # S ร— P โ†’ T + Tโ‚‚ = compose_configs(Tโ‚, Pโ‚‚) # T ร— P โ†’ T (no stripping!) + ... + Tโ‚™ = compose_configs(Tโ‚™โ‚‹โ‚, Pโ‚™) + + Plan functoriality says:: + + compose(plan(Sโ‚€,Tโ‚), plan(Tโ‚,Tโ‚‚), ...) โ‰ก plan(Sโ‚€, Tโ‚™) + + where โ‰ก denotes semantic equivalence (identical weights when executed). + + NOTE: This tests T ร— P composition WITHOUT stripping between steps. + This differs from build_plan which strips (T โ†’ S) between iterations. + Both patterns are valid: + + - Without stripping: init fields accumulate, testing plan composition purity + - With stripping: init consumed per-step, testing real usage (see + test_build_plan_strips_init_between_iterations) + + The functoriality law holds in both cases because plan composition + correctly substitutes Ref expressions with their definitions. + """ + + @pytest.mark.parametrize("chain_length", [1, 2, 3]) + @pytest.mark.parametrize("use_bias", [True, False]) + def test_functoriality( + self, + chain_length, + use_bias, + base_config_dict, + base_config_with_bias_dict, + additive_surgery_chain, + ): + """ + Composed incremental plans produce same weights as direct plan. + + Parameterized over: + - chain_length: Number of surgeries (1, 2, or 3) + - use_bias: Whether base config has biases + """ + base_config = base_config_with_bias_dict if use_bias else base_config_dict + surgeries = additive_surgery_chain[:chain_length] + + # Build config chain: Cโ‚€ โ†’ Cโ‚ โ†’ ... โ†’ Cโ‚™ + configs = [base_config] + for s in surgeries: + configs.append(compose_configs(configs[-1], s)) + + # Build incremental plans: Pโ‚– = plan_surgery(Cโ‚–โ‚‹โ‚, Cโ‚–) + plans = [plan_surgery(configs[i], configs[i + 1]) for i in range(len(surgeries))] + + # Compose all incremental plans + composed_plan = reduce(compose, plans) + + # Build direct plan: plan_surgery(Cโ‚€, Cโ‚™) + direct_plan = plan_surgery(configs[0], configs[-1]) + + # Execute both on same weights + weights = make_weights_for_config(base_config) + composed_weights = execute(composed_plan, weights, seed=42) + direct_weights = execute(direct_plan, weights, seed=42) + + # Verify semantic equivalence + assert set(composed_weights.keys()) == set( + direct_weights.keys() + ), f"Key sets differ for chain_length={chain_length}, use_bias={use_bias}" + + for key in composed_weights: + assert torch.allclose( + composed_weights[key], direct_weights[key], atol=1e-6 + ), f"Weight mismatch for {key} with chain_length={chain_length}, use_bias={use_bias}" + + @pytest.mark.parametrize("split_point", [1, 2]) + def test_arbitrary_grouping( + self, + split_point, + base_config_with_bias_dict, + additive_surgery_chain, + ): + """ + Any grouping of surgery chain produces same result. + + For surgeries [Sโ‚, Sโ‚‚, Sโ‚ƒ], tests that: + - compose(Pโ‚, compose(Pโ‚‚, Pโ‚ƒ)) + - compose(compose(Pโ‚, Pโ‚‚), Pโ‚ƒ) + - plan_surgery(Cโ‚€, Cโ‚ƒ) + + all produce identical weights. 
+ """ + surgeries = additive_surgery_chain + + # Build config chain + configs = [base_config_with_bias_dict] + for s in surgeries: + configs.append(compose_configs(configs[-1], s)) + + # Build incremental plans + plans = [plan_surgery(configs[i], configs[i + 1]) for i in range(3)] + + # Different groupings + left_grouped = compose(compose(plans[0], plans[1]), plans[2]) + right_grouped = compose(plans[0], compose(plans[1], plans[2])) + direct = plan_surgery(configs[0], configs[-1]) + + # Execute all + weights = make_weights_for_config(base_config_with_bias_dict) + results = { + "left": execute(left_grouped, weights, seed=42), + "right": execute(right_grouped, weights, seed=42), + "direct": execute(direct, weights, seed=42), + } + + # All must match + keys = set(results["left"].keys()) + assert keys == set(results["right"].keys()) == set(results["direct"].keys()) + + for key in keys: + assert torch.allclose(results["left"][key], results["right"][key], atol=1e-6) + assert torch.allclose(results["left"][key], results["direct"][key], atol=1e-6) + + +# ============================================================================= +# Test: Bias Inheritance Preservation (Regression for the specific bug) +# ============================================================================= + + +class TestBiasInheritancePreservation: + """ + PROPERTY: Per-layer bias settings must be preserved through surgery chains. + + When a surgery spec does not mention bias settings, they must be inherited + from the source config. This is the specific failure mode of the build_plan + bug: passing partial surgery specs to plan_surgery lost inherited fields. + + This test verifies the SYMPTOM (missing biases) rather than the LAW + (functoriality). It's kept as a focused regression test. 
+ """ + + @pytest.mark.parametrize("num_surgeries", [1, 2, 3]) + def test_qkv_biases_preserved_through_chain( + self, + num_surgeries, + base_config_with_bias_dict, + additive_surgery_chain, + ): + """QKV biases (enabled in source) appear in plan after N surgeries.""" + surgeries = additive_surgery_chain[:num_surgeries] + + # Build config and plan chain + configs = [base_config_with_bias_dict] + for s in surgeries: + configs.append(compose_configs(configs[-1], s)) + + plans = [plan_surgery(configs[i], configs[i + 1]) for i in range(num_surgeries)] + final_plan = reduce(compose, plans) if len(plans) > 1 else plans[0] + + # Check bias keys present + target_keys = {str(k) for k in final_plan.target_keys()} + + assert any("q_proj.bias" in k for k in target_keys), f"q_proj.bias missing after {num_surgeries} surgeries" + assert any("k_proj.bias" in k for k in target_keys), f"k_proj.bias missing after {num_surgeries} surgeries" + assert any("v_proj.bias" in k for k in target_keys), f"v_proj.bias missing after {num_surgeries} surgeries" + # O bias should NOT be present (disabled in source) + assert not any( + "o_proj.bias" in k for k in target_keys + ), f"o_proj.bias should not be present (disabled in source)" + + def test_bias_values_preserved( + self, + base_config_with_bias_dict, + additive_surgery_chain, + ): + """Bias tensor values are correctly transferred, not just keys.""" + surgery = additive_surgery_chain[0] # wrap_stochastic + c1 = compose_configs(base_config_with_bias_dict, surgery) + plan = plan_surgery(base_config_with_bias_dict, c1) + + weights = make_weights_for_config(base_config_with_bias_dict) + result = execute(plan, weights, seed=42) + + # Verify values match (not just that keys exist) + for i in range(base_config_with_bias_dict["decoder"]["num_blocks"]): + src_key = W(f"model.decoder.blocks.{i}.mixer.q_proj.bias") + dst_key = W(f"model.decoder.blocks.{i}.mixer.mixers.attention.q_proj.bias") + + assert dst_key in result, f"Missing {dst_key}" + assert torch.allclose(weights[src_key], result[dst_key]), f"Bias values differ for block {i}" + + +# ============================================================================= +# Test: build_plan Integration (Regression test for convert.py) +# ============================================================================= + + +class TestBuildPlanIntegration: + """ + REGRESSION: build_plan must compose configs before calling plan_surgery. + + The bug was: + plan_surgery(current_config, surgery_config) # WRONG: partial + + Should be: + target = compose_configs(current_config, surgery_config) + plan_surgery(current_config, target) # CORRECT: complete + + This test verifies the fix in convert.py's build_plan function. 
+ """ + + @pytest.mark.parametrize("num_surgeries", [1, 2]) + def test_build_plan_preserves_inherited_fields( + self, + num_surgeries, + base_config_with_bias_dict, + additive_surgery_chain, + ): + """build_plan produces plans with inherited bias mappings.""" + from fast_llm_external_models.apriel2.convert import build_plan + + surgeries = additive_surgery_chain[:num_surgeries] + + plan, final_config = build_plan( + base_config_with_bias_dict, + surgeries, + source_format="apriel2", + ) + + # Verify inherited biases in config + if num_surgeries >= 1: + attn = final_config["decoder"]["block"]["mixer"]["mixers"]["attention"] + assert attn.get("query_layer", {}).get("bias", {}).get("enabled") is True + + # Verify bias mappings in plan + target_keys = {str(k) for k in plan.target_keys()} + assert any( + "q_proj.bias" in k for k in target_keys + ), f"build_plan with {num_surgeries} surgeries missing q_proj.bias" + + +# ============================================================================= +# Test: init Field Preservation (Critical for random initialization) +# ============================================================================= + + +class TestInitFieldPreservation: + """ + PROPERTY: The `init` field must be visible to plan_surgery. + + The `init` field controls weight initialization mode: + - `init: transfer` โ†’ use weight transfer/conversion + - `init: random` โ†’ use random initialization + + compose_configs must preserve `init` so plan_surgery can see it. + Stripping happens only at final output (when saving to disk). + """ + + def test_init_random_produces_init_expression(self, base_config_with_bias_dict): + """Surgery with init: random produces Init expressions in plan.""" + surgery = { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + "gdn": {"type": "gdn", "init": "random", "convolution_layer": {"kernel_size": 4}}, + }, + }, + }, + }, + } + + target = compose_configs(base_config_with_bias_dict, surgery) + plan = plan_surgery(base_config_with_bias_dict, target) + + # Check that GDN weights use Init expressions (random init) + target_keys = {str(k) for k in plan.target_keys()} + gdn_keys = [k for k in target_keys if "gdn" in k.lower()] + + assert len(gdn_keys) > 0, "No GDN keys in plan" + + # Verify at least one GDN weight uses Init (random initialization) + has_init_expr = False + for key in plan.target_keys(): + if "gdn" in str(key).lower(): + expr = plan.mappings[key] + if isinstance(expr, Init): + has_init_expr = True + break + # Also check inside Concat/other composite expressions + if hasattr(expr, "exprs"): + for sub in expr.exprs: + if isinstance(sub, Init): + has_init_expr = True + break + + assert has_init_expr, "init: random should produce Init expressions for GDN weights" + + def test_init_transfer_produces_ref_expression(self, base_config_with_bias_dict): + """Surgery with init: transfer produces Ref expressions (weight transfer).""" + surgery = { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + }, + }, + }, + }, + } + + target = compose_configs(base_config_with_bias_dict, surgery) + plan = plan_surgery(base_config_with_bias_dict, target) + + # Check that attention weights use Ref expressions (transfer) + has_ref = False + for key in plan.target_keys(): + if "attention" in str(key) and "q_proj.weight" in str(key): + expr = plan.mappings[key] + if 
isinstance(expr, Ref): + has_ref = True + break + + assert has_ref, "init: transfer should produce Ref expressions for attention weights" + + def test_build_plan_respects_init_random(self, base_config_with_bias_dict): + """build_plan correctly uses init: random for weight initialization.""" + from fast_llm_external_models.apriel2.convert import build_plan + + # Mamba requires many config fields for random init + surgery = { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + "mamba": { + "type": "mamba", + "init": "random", + "d_inner": 512, + "d_state": 16, + "dt_rank": 16, + "d_xb": 64, + "d_conv": 4, + "repeat_kv_before_conv": False, + "conv_bias": True, + "dt_proj_bias": True, + "dt_min": 0.001, + "dt_max": 0.1, + "dt_init_floor": 1e-4, + }, + }, + }, + }, + }, + } + + plan, final_config = build_plan( + base_config_with_bias_dict, + [surgery], + source_format="apriel2", + ) + + # Verify mamba weights use Init (random init) + has_mamba_init = False + for key in plan.target_keys(): + key_str = str(key) + if "mamba" in key_str: + expr = plan.mappings[key] + if isinstance(expr, Init): + has_mamba_init = True + break + + assert has_mamba_init, "build_plan should use Init for init: random components" + + def test_build_plan_strips_init_between_iterations(self, base_config_with_bias_dict): + """build_plan strips init between iterations (T โ†’ S conversion). + + This tests that the intermediate state between surgeries has no init fields. + The composed plan will show Init expressions because plan composition + substitutes Ref โ†’ Init, but the semantics are correct: GDN is initialized + once (in surgery 1), not re-randomized in surgery 2. + """ + from fast_llm_external_models.apriel2.conversion import compose_configs, plan_surgery, strip_init_fields + + # Surgery 1: Add GDN with random init + surgery1 = { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + "gdn": { + "type": "gdn", + "init": "random", + "convolution_layer": {"kernel_size": 4}, + }, + }, + }, + }, + }, + } + + # Surgery 2: Add sliding window (doesn't mention GDN) + surgery2 = { + "decoder": { + "block": { + "mixer": { + "mixers": { + "sliding_window": {"init": "transfer", "window_size": 512}, + }, + }, + }, + }, + } + + # Simulate build_plan's iteration loop + s0 = base_config_with_bias_dict + + # Iteration 1 + t1 = compose_configs(s0, surgery1) + assert t1["decoder"]["block"]["mixer"]["mixers"]["gdn"].get("init") == "random" + s1 = strip_init_fields(t1) + assert s1["decoder"]["block"]["mixer"]["mixers"]["gdn"].get("init") is None + + # Iteration 2: s1 has no init for GDN + t2 = compose_configs(s1, surgery2) + assert ( + t2["decoder"]["block"]["mixer"]["mixers"]["gdn"].get("init") is None + ), "GDN should have no init in T2 (wasn't in surgery2, stripped from s1)" + + # plan_surgery(s1, t2) should use Ref for GDN (transfer, not random) + plan2 = plan_surgery(s1, t2) + gdn_uses_ref = False + for key in plan2.target_keys(): + if "gdn" in str(key): + expr = plan2.mappings[key] + if isinstance(expr, Ref): + gdn_uses_ref = True + break + + assert gdn_uses_ref, "plan_surgery(s1, t2) should use Ref for GDN (transfer from s1)" diff --git a/setup.py b/setup.py index b273e077e..5c4d0def6 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,6 @@ -import sys -import re import pathlib +import re +import sys try: import pybind11 @@ -18,6 +18,7 @@ 
print(f"Error: setuptools version {_SETUPTOOLS_MIN_VERSION} " "or greater is required") sys.exit(1) + def get_version(): """Read version from fast_llm/__init__.py""" init_file = pathlib.Path(__file__).parent.joinpath("fast_llm", "__init__.py").read_text() @@ -26,6 +27,7 @@ def get_version(): return version_match.group(1) raise RuntimeError("Unable to find version string in fast_llm/__init__.py") + cpp_extension = setuptools.Extension( "fast_llm.csrc.data", sources=["fast_llm/csrc/data.cpp"], diff --git a/tests/data/test_tokenizer.py b/tests/data/test_tokenizer.py index c7fdef9ca..4e9e2fdd5 100644 --- a/tests/data/test_tokenizer.py +++ b/tests/data/test_tokenizer.py @@ -40,3 +40,263 @@ def test_tokenize_with_spans(common_tokenizer, spans, expected_token_spans, expe expected_token_spans = [(begin + 1, end + 1) for begin, end in expected_token_spans] Assert.eq(tokens.tolist(), expected_tokens) Assert.eq(token_spans, expected_token_spans) + + +def test_validate_chat_template_no_template(common_tokenizer): + """Tokenizer without chat template raises.""" + with pytest.raises(ValueError, match="does not have a chat template"): + common_tokenizer.validate_chat_template() + + +def test_validate_chat_template_no_markers(common_tokenizer): + """Tokenizer with chat template but no markers raises.""" + common_tokenizer.tokenizer.chat_template = "{{ messages }}" + with pytest.raises(ValueError, match="does not contain.*generation"): + common_tokenizer.validate_chat_template() + + +def test_validate_chat_template_with_markers(common_tokenizer): + """Tokenizer with generation markers validates.""" + common_tokenizer.tokenizer.chat_template = "{% generation %}{{ m }}{% endgeneration %}" + common_tokenizer.validate_chat_template() + + +# Realistic chat template following HF conventions (e.g., SmolLM3): +# The generation block includes the full assistant turn: opening tag, content, and closing tag. +# This ensures the model learns to emit the closing tag. 
+CHAT_TEMPLATE = ( + "{% for message in messages %}" + "{% if message.role == 'assistant' %}" + "{% generation %}{{ message.content }}{% endgeneration %}" + "{% else %}" + "<{{ message.role }}>{{ message.content }}" + "{% endif %}" + "{% endfor %}" +) + + +@pytest.mark.parametrize( + ("messages", "expected_tokens", "expected_loss_masking_spans"), + ( + # Single turn: full assistant turn (Hello) is trainable + # 15 tokens, trainable indices 7-13, loss mask spans cover 0-6 and 14 + ( + [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello"}], + [49152, 27, 789, 29, 16946, 750, 789, 2293, 17822, 29, 7371, 750, 17822, 29, 49152], + [(0, 7), (14, 15)], + ), + # Multi-turn: both assistant turns are fully trainable + # 27 tokens, trainable indices 7-13 and 19-25 + ( + [ + {"role": "user", "content": "A"}, + {"role": "assistant", "content": "B"}, + {"role": "user", "content": "C"}, + {"role": "assistant", "content": "D"}, + ], + [ + 49152, + 27, + 789, + 29, + 32, + 750, + 789, + 2293, + 17822, + 29, + 33, + 750, + 17822, + 2293, + 789, + 29, + 34, + 750, + 789, + 2293, + 17822, + 29, + 35, + 750, + 17822, + 29, + 49152, + ], + [(0, 7), (14, 19), (26, 27)], + ), + # System + user + assistant: full assistant turn trainable + # 23 tokens, trainable indices 15-21 + ( + [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello"}, + ], + [ + 49152, + 27, + 3144, + 29, + 5815, + 1139, + 44569, + 6928, + 3144, + 2293, + 789, + 29, + 16946, + 750, + 789, + 2293, + 17822, + 29, + 7371, + 750, + 17822, + 29, + 49152, + ], + [(0, 15), (22, 23)], + ), + # User only: no trainable tokens + # 9 tokens, no trainable indices + ( + [{"role": "user", "content": "Hi"}], + [49152, 27, 789, 29, 16946, 750, 789, 29, 49152], + [(0, 9)], + ), + # Long multi-turn (85 tokens, 3 assistant responses with tags, tests span machinery) + # Trainable: indices 27-40, 49-62, 70-83 + ( + [ + {"role": "system", "content": "You are a helpful assistant that answers questions."}, + {"role": "user", "content": "What is the capital of France?"}, + {"role": "assistant", "content": "The capital of France is Paris."}, + {"role": "user", "content": "What about Germany?"}, + {"role": "assistant", "content": "The capital of Germany is Berlin."}, + {"role": "user", "content": "And Italy?"}, + {"role": "assistant", "content": "The capital of Italy is Rome."}, + ], + [ + 49152, + 27, + 3144, + 29, + 5815, + 1139, + 373, + 44569, + 2424, + 11886, + 954, + 15737, + 14516, + 6928, + 3144, + 2293, + 789, + 29, + 13938, + 438, + 331, + 25016, + 457, + 12409, + 562, + 35838, + 789, + 2293, + 17822, + 29, + 2111, + 25016, + 457, + 12409, + 562, + 438, + 4235, + 280, + 6928, + 17822, + 2293, + 789, + 29, + 13938, + 5028, + 759, + 42226, + 35838, + 789, + 2293, + 17822, + 29, + 2111, + 25016, + 457, + 759, + 42226, + 438, + 29784, + 3556, + 6928, + 17822, + 2293, + 789, + 29, + 1996, + 4413, + 3326, + 35838, + 789, + 2293, + 17822, + 29, + 2111, + 25016, + 457, + 4413, + 3326, + 438, + 613, + 1361, + 6928, + 17822, + 29, + 49152, + ], + [(0, 27), (41, 49), (63, 70), (84, 85)], + ), + ), +) +def test_tokenize_chat(common_tokenizer, messages, expected_tokens, expected_loss_masking_spans): + common_tokenizer.tokenizer.chat_template = CHAT_TEMPLATE + tokens, loss_masking_spans = common_tokenizer.tokenize_chat(messages) + Assert.eq(tokens.tolist(), expected_tokens) + Assert.eq(loss_masking_spans, expected_loss_masking_spans) + + +@pytest.mark.parametrize( + 
("train_mask", "expected_loss_spans"), + ( + # All masked (no trainable tokens) + ([False, False, False], [(0, 3)]), + # All trainable (no spans) + ([True, True, True], []), + # Single trainable at start + ([True, False, False], [(1, 3)]), + # Single trainable at end + ([False, False, True], [(0, 2)]), + # Single trainable in middle + ([False, True, False], [(0, 1), (2, 3)]), + # Multiple trainable regions (simulates multi-turn conversation) + ([False, False, True, True, False, False, True, True, True, False], [(0, 2), (4, 6), (9, 10)]), + # Alternating + ([False, True, False, True, False], [(0, 1), (2, 3), (4, 5)]), + ), +) +def test_train_mask_to_loss_spans(train_mask, expected_loss_spans): + from fast_llm.data.preprocessing.tokenizer import _train_mask_to_loss_spans + + Assert.eq(_train_mask_to_loss_spans(train_mask), expected_loss_spans)