diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md
index 50c5a2c1c..a09f78c6c 100644
--- a/.github/ISSUE_TEMPLATE/feature_request.md
+++ b/.github/ISSUE_TEMPLATE/feature_request.md
@@ -8,26 +8,26 @@ assignees: ''
---
# 🎯 **Goal (What & Why)**
-> **Clearly state the purpose of this feature.**
+> **Clearly state the purpose of this feature.**
> _(Example: Add FP8 support using torchao to improve training throughput by 1.5x.)_
# 🚀 **Execution Plan**
-> _(This section may start as an incomplete draft but must be defined before implementation begins.)_
+> _(This section may start as an incomplete draft but must be defined before implementation begins.)_
### **Step 1: What is the smallest working version?**
-> _(Describe the simplest way to implement this feature with minimal effort.)_
+> _(Describe the simplest way to implement this feature with minimal effort.)_
-### **Step 2: What additional optimizations are possible (but optional)?**
-> _(List potential refinements that can be added in later PRs if needed.)_
+### **Step 2: What additional optimizations are possible (but optional)?**
+> _(List potential refinements that can be added in later PRs if needed.)_
# 📌 **Acceptance Criteria** (Must-Haves for Completion)
-* The feature must be **functional and tested**.
-* The implementation must be **documented in practical terms**.
-* The PR must include a **performance/impact summary**.
-* **No refactors unless directly necessary** for feature completion.
+* The feature must be **functional and tested**.
+* The implementation must be **documented in practical terms**.
+* The PR must include a **performance/impact summary**.
+* **No refactors unless directly necessary** for feature completion.
# 🛠️ **Project Management**
- [ ] **Assign the project to the Fast-LLM project.**
- [ ] **Set the `Estimate` field (in days) in the GitHub project.**
- [ ] **Use the `Size` field to categorize the PR size (Small/Medium/Large).**
-- [ ] **Assign an owner when opening the issue.**
+- [ ] **Assign an owner when opening the issue.**
diff --git a/fast_llm/core/distributed.py b/fast_llm/core/distributed.py
index 2a200688a..4dcc53d55 100644
--- a/fast_llm/core/distributed.py
+++ b/fast_llm/core/distributed.py
@@ -29,10 +29,15 @@
logger = logging.getLogger(__name__)
-def add_ephemeral_timeout(group: ProcessGroup, timeout: float | None = None) -> None:
+@contextlib.contextmanager
+def set_timeout(group: ProcessGroup | None, timeout: float | None = None):
if group is not None and timeout is not None:
- # TODO: Only works for nccl?
- group._add_ephemeral_timeout(datetime.timedelta(seconds=timeout))
+ timeout_ = group.options._timeout
+ group.set_timeout(datetime.timedelta(seconds=timeout))
+ yield
+ group.set_timeout(timeout_)
+ else:
+ yield
def broadcast(
@@ -43,8 +48,8 @@ def broadcast(
opts = torch.distributed.BroadcastOptions()
opts.rootRank = src
opts.rootTensor = 0
- add_ephemeral_timeout(group, timeout)
- work = group.broadcast([tensor], opts)
+ with set_timeout(group, timeout):
+ work = group.broadcast([tensor], opts)
if async_op:
return work
else:
@@ -55,7 +60,7 @@ def broadcast(
def check_parallel_match(tensor: torch.Tensor, group: ProcessGroup | None, name: str) -> None:
# A utility function to check for tensor-parallel (or other) mismatches.
all_tensors = tensor.new_empty((group.size(),) + tensor.shape)
- all_gather_into_tensor(all_tensors, tensor, group)
+ all_gather_into_tensor(all_tensors, tensor.unsqueeze(0), group)
mismatches = (all_tensors != tensor).any(dim=0)
num_mismatches = mismatches.sum().item()
@@ -84,8 +89,8 @@ def allreduce_scalar(
) -> float | int:
if group:
value = torch.full([1], value, dtype=dtype, device=torch.cuda.current_device())
- add_ephemeral_timeout(group, timeout)
- torch.distributed.all_reduce(value, op=op, group=group)
+ with set_timeout(group, timeout):
+ torch.distributed.all_reduce(value, op=op, group=group)
return value.item()
else:
return value
@@ -99,9 +104,9 @@ def all_gather_scalar(
):
if group:
value = torch.full([1], value, dtype=dtype, device=torch.cuda.current_device())
- add_ephemeral_timeout(group, timeout)
output_tensor = value.new_empty((group.size(),))
- torch.distributed.all_gather_into_tensor(output_tensor, value, group=group)
+ with set_timeout(group, timeout):
+ torch.distributed.all_gather_into_tensor(output_tensor, value, group=group)
return output_tensor.tolist()
else:
return value
@@ -147,6 +152,12 @@ def broadcast_object(input_object: typing.Any | None, group: ProcessGroup | None
def send(tensor: torch.Tensor, dst: int, group: ProcessGroup, async_op=False, tag: int = 0) -> Work | None:
assert group is not None
+ if isinstance(group, torch.distributed.ProcessGroupGloo) and tensor.device.type != "cpu":
+ # send not supported for gloo on GPU.
+ tensor_cpu = tensor.cpu()
+ group.send([tensor_cpu], dst, tag).wait()
+ tensor.copy_(tensor_cpu)
+ return None
work = group.send([tensor], dst, tag)
if async_op:
return work
@@ -157,6 +168,12 @@ def send(tensor: torch.Tensor, dst: int, group: ProcessGroup, async_op=False, ta
def recv(tensor: torch.Tensor, src: int, group: ProcessGroup, async_op=False, tag: int = 0) -> Work | None:
assert group is not None
+ if isinstance(group, torch.distributed.ProcessGroupGloo) and tensor.device.type != "cpu":
+ # recv not supported for gloo on GPU.
+ tensor_cpu = tensor.cpu()
+ group.recv([tensor_cpu], src, tag).wait()
+ tensor.copy_(tensor_cpu)
+ return None
work = group.recv([tensor], src, tag)
if async_op:
return work
diff --git a/fast_llm/data/data/data_loader.py b/fast_llm/data/data/data_loader.py
new file mode 100644
index 000000000..ba7e5e612
--- /dev/null
+++ b/fast_llm/data/data/data_loader.py
@@ -0,0 +1,72 @@
+import itertools
+import typing
+
+import torch.utils.data
+
+from fast_llm.core.distributed import broadcast_object
+
+
+class SampledDatasetIterator(torch.utils.data.Sampler):
+ """
+ A distributed sampler generating indices for a `SampledDataset` (i.e., the natural numbers).
+ To be used as the `batch_sampler` of a `torch.utils.data.DataLoader`.
+ """
+
+ def __init__(self, total_samples, begin_index, micro_batch_size, data_rank, data_parallel):
+ super().__init__()
+ self._total_samples = total_samples
+ self._begin_index = begin_index
+ self._batch_size = micro_batch_size * data_parallel
+ self._start_idx = data_rank * micro_batch_size
+ self._end_idx = (data_rank + 1) * micro_batch_size
+
+ def __len__(self) -> int:
+ return self._total_samples
+
+ def __iter__(self) -> typing.Iterator[list[int]]:
+ for idx in range(self._begin_index, self._total_samples - self._batch_size + 1, self._batch_size):
+ yield list(range(idx + self._start_idx, idx + self._end_idx))
+
+
+class DistributedDataLoaderWrapper:
+ """
+ Wraps a regular dataloader so that only the process group leader
+ loads data, and then broadcasts the batch to other ranks in the group.
+ """
+
+ def __init__(
+ self,
+ data_loader: torch.utils.data.dataloader.DataLoader,
+ process_group: torch.distributed.ProcessGroup | None,
+ ):
+ self._data_loader = data_loader
+ self._rank = 0 if process_group is None else process_group.rank()
+ self._process_group = process_group
+
+ def __iter__(self):
+ if self._rank == 0:
+ self._iterator = iter(self._data_loader)
+ else:
+ self._iterator = itertools.repeat(None)
+ if self._process_group is None:
+ return self._iterator
+ return self
+
+ def __next__(self):
+ # TODO:
+ # Instead of broadcasting a general object, make this iterator yield an actual Batch class.
+ # Implement `get_state_dict` and `from_state_dict` in the Batch class so that we can
+ # efficiently broadcast tensors directly. This avoids using `broadcast_object` on the
+ # entire Batch object, which is inefficient for tensors because it serializes
+ # (pickles) them before sending.
+
+ try:
+ data = next(self._iterator) # may raise StopIteration
+ except Exception as e:
+ data = e
+ data = broadcast_object(data, self._process_group, 0)
+
+ if isinstance(data, Exception):
+ raise data
+
+ return data
diff --git a/fast_llm/data/data/data_loader_wrapper.py b/fast_llm/data/data/data_loader_wrapper.py
deleted file mode 100644
index f9e517248..000000000
--- a/fast_llm/data/data/data_loader_wrapper.py
+++ /dev/null
@@ -1,52 +0,0 @@
-import torch.distributed
-import torch.utils.data.dataloader
-
-from fast_llm.core.distributed import broadcast_object
-
-
-class DistributedDataLoaderWrapper:
- """
- Wraps a regular dataloader so that only the process group leader
- loads data, and then broadcasts the batch to other ranks in the group.
- """
-
- def __init__(
- self,
- dataloader: torch.utils.data.dataloader.DataLoader | None,
- rank: int,
- process_group: torch.distributed.ProcessGroup | None,
- ):
- self.dataloader = dataloader
- self.rank = rank
- self.process_group = process_group
-
- assert (self.rank == 0 and self.dataloader is not None) or (self.rank > 0 and self.dataloader is None)
-
- def __iter__(self):
- if self.rank == 0:
- self.iterator = iter(self.dataloader)
- if self.process_group is None:
- return self.iterator
- return self
-
- def __next__(self):
- # TODO:
- # Instead of broadcasting a general object, make this iterator yield an actual Batch class.
- # Implement `get_state_dict` and `from_state_dict` in the Batch class so that we can
- # efficiently broadcast tensors directly. This avoids using `broadcast_object` on the
- # entire Batch object, which is inefficient for tensors because it serializes
- # (pickles) them before sending.
-
- if self.rank == 0:
- try:
- data = next(self.iterator) # may raise StopIteration
- except Exception as e:
- data = e
- data = broadcast_object(data, self.process_group, 0)
- else:
- data = broadcast_object(None, self.process_group, 0)
-
- if isinstance(data, Exception):
- raise data
-
- return data
diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py
index 70966a051..3af86652a 100644
--- a/fast_llm/data/data/gpt/data.py
+++ b/fast_llm/data/data/gpt/data.py
@@ -8,14 +8,12 @@
from fast_llm.core.distributed import safe_barrier
from fast_llm.data.data.abstract import Data
-from fast_llm.data.data.data_loader_wrapper import DistributedDataLoaderWrapper
+from fast_llm.data.data.data_loader import DistributedDataLoaderWrapper, SampledDatasetIterator
from fast_llm.data.data.gpt.config import GPTDataConfig
from fast_llm.data.dataset.abstract import SampledDataset
-from fast_llm.data.dataset.abstract_iterable import SampledIterableDataset
from fast_llm.data.dataset.config import SamplingParameters
from fast_llm.data.dataset.gpt.config import GPTSamplingData
from fast_llm.data.dataset.monitor import DatasetMonitor
-from fast_llm.data.iterator import SampledDatasetIterator
from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig
from fast_llm.data.sample.language_model import LanguageModelBatch
from fast_llm.engine.config_utils.run import log_main_rank
@@ -92,12 +90,7 @@ def setup(
dataset_name=dataset_name,
)
dataset = self._config.datasets[dataset_name].build_and_sample(sampling)
- if isinstance(dataset, SampledDataset):
- self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms)
- else:
- # Do not set monitor for iterable dataset as monitor only works with map style datasets
- assert isinstance(dataset, SampledIterableDataset)
- self._datasets[dataset_name] = dataset
+ self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms)
safe_barrier(self._distributed.world_group, "data_preparation", timeout)
self._is_setup = True
@@ -123,45 +116,23 @@ def get_iterator(
Assert.in_range_incl(batch_config.sequence_length, 1, sampling_parameters.sequence_length)
log_main_rank(f"Initializing {dataset_name} dataset iterator from sample {consumed_samples}...")
- dataset = self._datasets[dataset_name]
-
- if isinstance(dataset, SampledDataset):
- data_loader = torch.utils.data.DataLoader(
- dataset, # noqa
- batch_sampler=SampledDatasetIterator(
- total_samples=len(self._datasets[dataset_name]),
- begin_index=consumed_samples,
- micro_batch_size=batch_config.micro_batch_size,
- data_rank=self._distributed.config.batch_data_rank,
- data_parallel=self._distributed.config.batch_data_parallel,
- ),
- num_workers=num_workers,
- prefetch_factor=prefetch_factor,
- pin_memory=True,
- collate_fn=LanguageModelBatch.from_samples,
- multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None,
- )
-
- elif isinstance(dataset, SampledIterableDataset):
- if (
- self.distributed.model_and_sequence_data_group is None
- or self.distributed.model_and_sequence_data_group.rank() == 0
- ):
- rank = 0
- data_loader = torch.utils.data.DataLoader(
- dataset, # noqa
- batch_size=batch_config.micro_batch_size,
- num_workers=0 if num_workers == 0 else 1,
- prefetch_factor=prefetch_factor,
- pin_memory=True,
- collate_fn=LanguageModelBatch.from_samples,
- multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None,
- )
- else:
- rank = self.distributed.model_and_sequence_data_group.rank()
- data_loader = None
- data_loader = DistributedDataLoaderWrapper(
- data_loader, rank, self.distributed.model_and_sequence_data_group
- )
+ data_loader = torch.utils.data.DataLoader(
+ self._datasets[dataset_name], # noqa
+ batch_sampler=SampledDatasetIterator(
+ total_samples=len(self._datasets[dataset_name]),
+ begin_index=consumed_samples,
+ micro_batch_size=batch_config.micro_batch_size,
+ data_rank=self._distributed.config.batch_data_rank,
+ data_parallel=self._distributed.config.batch_data_parallel,
+ ),
+ num_workers=num_workers,
+ prefetch_factor=prefetch_factor,
+ pin_memory=True,
+ collate_fn=LanguageModelBatch.from_samples,
+ multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None,
+ )
+
+ if self._datasets[dataset_name].requires_broadcast:
+ data_loader = DistributedDataLoaderWrapper(data_loader, self.distributed.model_and_sequence_data_group)
return iter(data_loader)
diff --git a/fast_llm/data/dataset/abstract.py b/fast_llm/data/dataset/abstract.py
index 2efdf3841..1df24e92b 100644
--- a/fast_llm/data/dataset/abstract.py
+++ b/fast_llm/data/dataset/abstract.py
@@ -5,6 +5,7 @@
if typing.TYPE_CHECKING:
from fast_llm.data.dataset.config import SamplingData
+ from fast_llm.data.dataset.sampled import SampledIterableDataset
class Dataset[SampleType: Sample](abc.ABC):
@@ -27,6 +28,14 @@ def __getstate__(self):
del state["__orig_class__"]
return state
+ @property
+ def requires_broadcast(self) -> bool:
+ """
+ Some dataset schemes load the dataset on a batch-data-parallel group leaders,
+ then broadcast to the other devices.
+ """
+ return False
+
class SampledDataset[SampleType: Sample](Dataset[SampleType]):
"""
@@ -44,6 +53,18 @@ def __len__(self) -> int:
class SamplableDataset[SampleType: Sample](Dataset[SampleType]):
+
@abc.abstractmethod
def sample(self, config: "SamplingData") -> SampledDataset[SampleType]:
pass
+
+
+class SamplableIterableDataset[SampleType: Sample](SamplableDataset[SampleType]):
+ @abc.abstractmethod
+ def __iter__(self) -> typing.Iterator[SampleType]:
+ pass
+
+ def sample(self, config: "SamplingData") -> "SampledIterableDataset[SampleType]":
+ from fast_llm.data.dataset.sampled import SampledIterableDataset
+
+ return SampledIterableDataset(self, config)
diff --git a/fast_llm/data/dataset/abstract_iterable.py b/fast_llm/data/dataset/abstract_iterable.py
deleted file mode 100644
index 770f4f97c..000000000
--- a/fast_llm/data/dataset/abstract_iterable.py
+++ /dev/null
@@ -1,30 +0,0 @@
-import abc
-import typing
-
-import torch.utils.data
-
-from fast_llm.data.sample.abstract import Sample
-
-if typing.TYPE_CHECKING:
- from fast_llm.data.dataset.config import SamplingData
-
-
-# NOTE: We need to inherit from IterableDataset otherwise torch data loader can not detect it properly
-class SampledIterableDataset[SampleType: Sample](torch.utils.data.IterableDataset[SampleType]):
- """
- A sampled dataset class that provides an iterator over samples.
- """
-
- # NOTE: We add name here so it is compatible with Fast-LLM Dataset
- @property
- @abc.abstractmethod
- def name(self) -> str:
- """
- A name for the dataset to facilitate identification and debugging.
- """
-
-
-class SamplableIterableDataset[SampleType: Sample](SampledIterableDataset[SampleType]):
- @abc.abstractmethod
- def sample(self, config: "SamplingData") -> SampledIterableDataset[SampleType]:
- pass
diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py
index f2a24c48b..b94f0e5f4 100644
--- a/fast_llm/data/dataset/config.py
+++ b/fast_llm/data/dataset/config.py
@@ -7,16 +7,15 @@
import pathlib
import typing
-from fast_llm.config import Config, Field, FieldHint, FieldUpdate, UpdateType, check_field, config_class
+from fast_llm.config import Config, Field, FieldHint, UpdateType, check_field, config_class
from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset
from fast_llm.data.preprocessing.abstract import PreprocessingConfig
from fast_llm.data.sample.abstract import Sample
-from fast_llm.redis.config import RedisConfig
from fast_llm.utils import Assert, normalize_probabilities
if typing.TYPE_CHECKING:
- from fast_llm.data.dataset.abstract_iterable import SamplableIterableDataset, SampledIterableDataset
from fast_llm.data.dataset.indexed import ConcatenatedDataset, DatasetSlice, IndexedDataset
+ from fast_llm.data.sample.language_model import LanguageModelSample
from fast_llm.engine.distributed.distributed import Distributed
logger = logging.getLogger(__name__)
@@ -108,25 +107,19 @@ class DatasetConfig[SampleType: Sample](Config):
@config_class(registry=True)
class SampledDatasetConfig[SampleType: Sample](DatasetConfig[SampleType]):
"""
- A sampled dataset containing a prepared list or iterable of samples to be indexed sequentially (as-is) during training.
+ A sampled dataset containing a prepared list of samples to be indexed sequentially (as-is) during training.
"""
- def build_and_sample(
- self, sampling: SamplingData
- ) -> "SampledDataset[SampleType] | SampledIterableDataset[SampleType]":
+ def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]:
raise NotImplementedError()
@config_class()
class SamplableDatasetConfig[SampleType: Sample](SampledDatasetConfig[SampleType]):
- def build(
- self, preprocessing: PreprocessingConfig
- ) -> "SamplableDataset[SampleType] | SamplableIterableDataset[SampleType]":
+ def build(self, preprocessing: PreprocessingConfig) -> SamplableDataset[SampleType]:
raise NotImplementedError()
- def build_and_sample(
- self, sampling: SamplingData
- ) -> "SampledDataset[SampleType] | SampledIterableDataset[SampleType]":
+ def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]:
return self.build(sampling.preprocessing).sample(sampling)
@@ -308,89 +301,54 @@ def build(self, preprocessing: PreprocessingConfig) -> "IndexedDataset[SampleTyp
raise FileNotFoundError(self.path)
-@config_class()
-class StreamingDatasetRedisConfig(RedisConfig):
- stream_key: str = FieldUpdate(default="fast_llm_streaming")
-
- payload_key: str = FieldUpdate(
- default="data",
- )
+REDIS_DATA_STREAM = "fast_llm_streaming"
+REDIS_FIELD = "data"
+REDIS_GROUP_NAME = "fast_llm_group"
-class IngestionType(str, enum.Enum):
- CONSUMER_GROUP = "consumer_group"
- ONE_STREAM = "one_stream"
- N_STREAMS = "n_streams"
-
-
-class HashType(str, enum.Enum):
- MESSAGE_INDEX = "message_index"
- """Use the index of the received message for hashing. Provides precise distribution but may not be well shuffled."""
+@config_class()
+class RedisConfig(Config):
+ REDIS_FIELD: typing.ClassVar[str] = "data"
+ REDIS_FIELD_B: typing.ClassVar[bytes] = REDIS_FIELD.encode()
+ REDIS_GROUP_NAME: typing.ClassVar[str] = "fast_llm_group"
+ REDIS_GROUP_NAME_B: typing.ClassVar[bytes] = REDIS_GROUP_NAME.encode()
+
+ # TODO: Move elsewhere? (Also used in trainer) Get it from the trainer in sampling config?
+ host: str = Field(
+ default="localhost",
+ desc="Hostname or IP address of the Redis server.",
+ hint=FieldHint.core,
+ )
- MESSAGE_ID = "message_id"
- """Hash messages based on their unique message ID. Good for probabilistic distribution.
- Redis message IDs are regenerated each time, so this is not reproducible.
- """
+ port: int = Field(
+ default=6379,
+ desc="Port number on which the Redis server is running.",
+ hint=FieldHint.core,
+ )
- MESSAGE_BODY = "message_body"
- """Hash messages based on their payload content (bytes). Distributes messages roughly evenly.
- Deterministic based on message content, but not perfectly balanced across ranks.
- """
+ def get_client(self):
+ import redis
- PRODUCER_PROVIDED = "producer_provided"
- """Use the hash or index provided by the producer. Allows deterministic splitting and perfect balance."""
+ return redis.Redis(self.host, self.port)
@config_class(dynamic_type={SampledDatasetConfig: "streaming"})
-class StreamingDatasetConfig[SampleType: LanguageModelSample](SamplableDatasetConfig[SampleType]):
+class StreamingDatasetConfig[SampleType: LanguageModelSample](RedisConfig, SamplableDatasetConfig[SampleType]):
"""
Configuration for a streaming dataset that reads training data from a Redis stream.
"""
_abstract = False
- redis: StreamingDatasetRedisConfig = Field(
- desc="Redis connection and stream settings used to fetch incoming training data.",
- hint=FieldHint.core,
- )
-
- group_name: str = Field(
- default="fast_llm_dp_group",
- desc="Name of the Redis consumer group used for data-parallel reading.",
- hint=FieldHint.core,
- )
-
- consumer_name_prefix: str = Field(
- default="fast_llm_dp_group_consumer",
- desc="Prefix used to generate unique consumer names for each rank in Redis consumer group.",
- hint=FieldHint.core,
- )
-
- ingestion_type: IngestionType = Field(
- default=IngestionType.CONSUMER_GROUP,
- desc="Strategy used to ingest data from Redis streams (consumer group, single stream, or multiple streams).",
- hint=FieldHint.core,
- )
-
- hash_type: HashType = Field(
- default=HashType.MESSAGE_ID,
- desc="How to compute hash for assigning messages to ranks.",
- hint=FieldHint.core,
- )
-
- hash_key: str = Field(
- default="hash",
- desc="Key in the message dict containing the hash or index provided by the producer.",
- hint=FieldHint.core,
- )
-
- ack_period_per_consumer: int = Field(
+ acknowledge_interval: int = Field(
default=10,
desc="Number of messages after which the consumer acknowledges received IDs back to the Redis hash.",
hint=FieldHint.core,
)
- def build_and_sample(self, sampling: SamplingData) -> "SampledIterableDataset[SampleType]":
- from fast_llm.data.dataset.streaming import StreamingDataset
+ def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]:
+ from fast_llm.data.dataset.streaming import RedisStreamingDataset
- return StreamingDataset[SampleType](self, sampling.distributed).sample(sampling)
+ return RedisStreamingDataset[StreamingDatasetConfig, SampleType](self, sampling.distributed.config).sample(
+ sampling
+ )
diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py
index fc326d366..5e978ac2b 100644
--- a/fast_llm/data/dataset/gpt/config.py
+++ b/fast_llm/data/dataset/gpt/config.py
@@ -64,7 +64,11 @@ def build(self, preprocessing: PreprocessingConfig) -> SamplableDataset[SampleTy
def _load_config(self) -> SampledDatasetConfig[SampleType]:
assert self.path.is_file(), f"File {self.path} does not exist."
- return SampledDatasetConfig[SampleType].from_dict(self._convert_paths(yaml.safe_load(self.path.open("r"))))
+ config = yaml.safe_load(self.path.open("r"))
+ if config.keys() == {"config", "metadata"}:
+ # Newer format with metadata
+ config = config["config"]
+ return SampledDatasetConfig[SampleType].from_dict(self._convert_paths(config))
def _convert_paths(self, config):
# Recursively convert paths relative to `self.path.parent` to make them relative to cwd.
diff --git a/fast_llm/data/dataset/memmap.py b/fast_llm/data/dataset/memmap.py
index f80a48b0a..9831f81ba 100644
--- a/fast_llm/data/dataset/memmap.py
+++ b/fast_llm/data/dataset/memmap.py
@@ -8,7 +8,12 @@
from fast_llm.data.dataset.config import SamplingParameters
from fast_llm.data.dataset.indexed import IndexedDataset
from fast_llm.data.preprocessing.abstract import PreprocessingConfig
-from fast_llm.data.sample.abstract import MemmapIndexDatasetReaderConfig, MemmapWriter, Sample
+from fast_llm.data.sample.abstract import (
+ MemmapIndexDatasetReaderConfig,
+ MemmapIndexedDatasetReader,
+ MemmapWriter,
+ Sample,
+)
FILE_HEADER = b"fast_llm_prepared_dataset"
@@ -82,6 +87,10 @@ def get_document_sizes(self) -> torch.Tensor:
def get_document_size(self, index: int) -> int:
return self._reader.get_document_size(index)
+ @property
+ def reader(self) -> MemmapIndexedDatasetReader:
+ return self._reader
+
@classmethod
def write_dataset(
cls,
diff --git a/fast_llm/data/dataset/monitor.py b/fast_llm/data/dataset/monitor.py
index 01f3195e4..27070f674 100644
--- a/fast_llm/data/dataset/monitor.py
+++ b/fast_llm/data/dataset/monitor.py
@@ -51,3 +51,11 @@ def __getitem__(self, index: int) -> SampleType:
@property
def name(self) -> str:
return self._dataset.name
+
+ @property
+ def requires_broadcast(self) -> bool:
+ """
+ Some dataset schemes load the dataset on a batch-data-parallel group leaders,
+ then broadcast to the other devices.
+ """
+ return self._dataset.requires_broadcast
diff --git a/fast_llm/data/dataset/sampled.py b/fast_llm/data/dataset/sampled.py
index 979fd7a60..8cf7d938a 100644
--- a/fast_llm/data/dataset/sampled.py
+++ b/fast_llm/data/dataset/sampled.py
@@ -8,8 +8,7 @@
import torch
import yaml
-from fast_llm.data.dataset.abstract import SampledDataset
-from fast_llm.data.dataset.abstract_iterable import SamplableIterableDataset, SampledIterableDataset
+from fast_llm.data.dataset.abstract import SamplableIterableDataset, SampledDataset
from fast_llm.data.dataset.config import SamplingData, ShufflingType
from fast_llm.data.dataset.indexed import IndexedDataset
from fast_llm.data.sample.abstract import Sample
@@ -112,6 +111,10 @@ def __init__(
# No barrier yet to allow running in parallel.
# There needs to be one before calling `__getitem__`, normally handled through `Data`.
+ @property
+ def requires_broadcast(self) -> bool:
+ return self._indexed_dataset.requires_broadcast
+
def _sample(self) -> None:
"""
Create a `SampledDataset` with the requested parameters.
@@ -432,52 +435,58 @@ def _load_yaml_data(self, data: dict[str, typing.Any]) -> None:
self._unshuffled_documents = data["unshuffled_epochs"] * self._documents_per_epoch
-class NaiveSampledIterableDataset[SampleType: Sample](SampledIterableDataset[SampleType]):
+class SampledIterableDataset[SampleType: Sample](SampledDataset[SampleType]):
def __init__(
self,
- iterable_dataset: SamplableIterableDataset[SampleType],
+ dataset: SamplableIterableDataset[SampleType],
sampling: SamplingData,
):
- self._dataset = iterable_dataset
+ self._dataset = dataset
self._config = sampling.config
self._parameters = sampling.parameters
+ self._documents: list[SampleType] = []
+ self._current_length = 0
+ self._sample_length = self._parameters.sequence_length + self._parameters.extra_tokens
+ # Delay iterator creation to avoid pickling issues.
+ self._iterator: typing.Iterator[SampleType] | None = None
- assert self._parameters.truncate_documents == False
- assert self._config.shuffle == ShufflingType.disabled
+ @property
+ def requires_broadcast(self) -> bool:
+ # TODO: ====== fix ======
+ # return self._iterator.requires_broadcast
+ return True
- def __iter__(self) -> typing.Iterator[SampleType]:
- sample_length = self._parameters.sequence_length + self._parameters.extra_tokens
- current_sample_length = 0
- documents: list[SampleType] = []
- for doc in self._dataset:
- if len(doc) > sample_length:
- logging.warning(f"Dropping doc with length {len(doc)} higher then sample_length {sample_length}")
+ def __getitem__(self, index: int) -> SampleType:
+ if self._iterator is None:
+ self._iterator = iter(self._dataset)
+ while self._current_length < self._sample_length:
+ document = next(self._iterator)
+ if len(document) > self._sample_length:
+ logging.warning(f"Dropping document with length {len(document)} > {self._sample_length}.")
continue
- if current_sample_length + len(doc) > sample_length:
- padding_length = sample_length - current_sample_length
- assert padding_length > 0
- documents.append(documents[-1].get_padding(padding_length))
-
- yield documents[0].from_documents(documents)
+ self._documents.append(document)
+ self._current_length += len(document)
- documents = [doc]
- current_sample_length = len(doc)
+ if self._current_length == self._sample_length:
+ documents = self._documents
+ self._documents = []
+ self._current_length = 0
+ else:
+ last_length = len(self._documents[-1])
+ remaining_length = last_length - (self._current_length - self._sample_length)
+ if self._parameters.truncate_documents:
+ documents = self._documents[:-1] + [self._documents[-1].crop(0, remaining_length)]
+ self._documents = [self._documents[-1].crop(remaining_length, last_length)]
else:
- documents.append(doc)
- current_sample_length += len(doc)
-
- if current_sample_length == sample_length:
- yield documents[0].from_documents(documents)
-
- documents = []
- current_sample_length = 0
+ documents = self._documents[:-1] + [self._documents[0].get_padding(remaining_length)]
+ self._documents = [self._documents[-1]]
+ self._current_length = len(self._documents[0])
+ sample = documents[0].from_documents(documents)
+ Assert.eq(len(sample), self._sample_length)
+ return sample
- if current_sample_length > 0:
- padding_length = sample_length - current_sample_length
- assert padding_length > 0
- documents.append(documents[-1].get_padding(padding_length))
-
- yield documents[0].from_documents(documents)
+ def __len__(self) -> int:
+ return self._parameters.num_samples
@property
def name(self) -> str:
diff --git a/fast_llm/data/dataset/streaming.py b/fast_llm/data/dataset/streaming.py
index 1aabf60cc..9f47395a2 100644
--- a/fast_llm/data/dataset/streaming.py
+++ b/fast_llm/data/dataset/streaming.py
@@ -1,19 +1,16 @@
+import json
import typing
+import redis
import torch.utils.data
-import xxhash
-from fast_llm.data.dataset.abstract_iterable import SamplableIterableDataset, SampledIterableDataset
-from fast_llm.data.dataset.config import HashType, IngestionType, SamplingData, StreamingDatasetConfig
-from fast_llm.data.dataset.sampled import NaiveSampledIterableDataset
+from fast_llm.config import Configurable
+from fast_llm.data.dataset.abstract import SamplableIterableDataset
+from fast_llm.data.dataset.config import REDIS_DATA_STREAM, REDIS_FIELD, REDIS_GROUP_NAME, StreamingDatasetConfig
from fast_llm.data.sample.language_model import LanguageModelSample
from fast_llm.data.sample.range import RangeSample
from fast_llm.data.sample.token import TokenSample
-from fast_llm.engine.config_utils.run import is_main_rank
-from fast_llm.engine.distributed.distributed import Distributed
-
-if typing.TYPE_CHECKING:
- import redis
+from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames
def dtype_from_string(name: str) -> torch.dtype:
@@ -23,70 +20,39 @@ def dtype_from_string(name: str) -> torch.dtype:
raise ValueError(f"Unknown torch dtype: {name}")
-class StreamingDataset[SampleType: LanguageModelSample](SamplableIterableDataset[SampleType]):
- def __init__(self, config: StreamingDatasetConfig, distributed: Distributed):
- super().__init__()
- if distributed.config.pipeline_parallel > 1:
- # NOTE: It is not yet clear whether the issue comes from the streaming dataset
- # itself or from the distributed data-loader wrappers, but currently it
- # interferes with pipeline-parallel training and causes a timeout during
- # the training step.
- raise NotImplementedError("Streaming dataset support is not implemented for pipeline-parallel training.")
+class RedisStreamingDataset[ConfigType: StreamingDatasetConfig, SampleType: LanguageModelSample](
+ Configurable[ConfigType], SamplableIterableDataset[SampleType]
+):
+ def __init__(self, config: ConfigType, distributed_config: DistributedConfig):
+ super().__init__(config)
+ # if distributed_config.pipeline_parallel > 1:
+ # NOTE: It is not yet clear whether the issue comes from the streaming dataset
+ # itself or from the distributed data-loader wrappers, but currently it
+ # interferes with pipeline-parallel training and causes a timeout during
+ # the training step.
+ # raise NotImplementedError("Streaming dataset support is not implemented for pipeline-parallel training.")
- self._name = f"redis[{config.redis.host}:{config.redis.port}]({config.redis.stream_key}|{config.group_name})[{config.redis.payload_key}]"
+ self._name = f"redis[{config.host}:{config.port}]({REDIS_DATA_STREAM}|{REDIS_GROUP_NAME})[data]"
self._config = config
- self.batch_data_rank = distributed.config.batch_data_rank
- self.batch_data_parallel = distributed.config.batch_data_parallel
+ self._rank = distributed_config.batch_data_rank
self.is_batch_data_group_leader = (
- distributed.model_and_sequence_data_group is None or distributed.model_and_sequence_data_group.rank() == 0
+ distributed_config.get_distributed_dim(DistributedDimNames.model_and_sequence_data).rank == 0
)
- self.payload_key_b = self._config.redis.payload_key.encode()
- self.hash_key_b = self._config.hash_key.encode()
- self._set_consumer_count()
+ # TODO: Not needed?
+ # if distributed_config.rank == 0:
+ # redis_client = redis.Redis(host=self._config.host, port=self._config.port)
+ # redis_client.hset(f"{REDIS_DATA_KEY}:consumer_count", "0", str(distributed_config.batch_data_parallel))
+
+ @property
+ def requires_broadcast(self) -> bool:
+ return True
@property
def name(self) -> str:
return self._name
- def sample(self, config: SamplingData) -> SampledIterableDataset[LanguageModelSample]:
- # TODO: actually sample the dataset and not return docs
- return NaiveSampledIterableDataset(self, config)
-
- def _set_consumer_count(self):
- import redis
-
- if is_main_rank():
- redis_client = redis.Redis(host=self._config.redis.host, port=self._config.redis.port)
- redis_client.hset(f"{self._config.redis.stream_key}:consumer_count", "0", self.batch_data_parallel)
-
- def __getstate__(self) -> tuple[str, StreamingDatasetConfig, int, int, bool, bytes, bytes]:
- return (
- self._name,
- self._config,
- self.batch_data_parallel,
- self.batch_data_rank,
- self.is_batch_data_group_leader,
- self.payload_key_b,
- self.hash_key_b,
- )
-
- def __setstate__(self, state: tuple[str, StreamingDatasetConfig, int, bool, bytes, bytes]):
- name, config, batch_data_parallel, batch_data_rank, is_batch_data_group_leader, payload_key_b, hash_key_b = (
- state
- )
- self._name = name
- self._config = config
- self.batch_data_parallel = batch_data_parallel
- self.batch_data_rank = batch_data_rank
- self.is_batch_data_group_leader = is_batch_data_group_leader
- self.payload_key_b = payload_key_b
- self.hash_key_b = hash_key_b
-
def __iter__(self) -> typing.Iterator[LanguageModelSample]:
- import orjson
- import redis
-
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None and worker_info.num_workers > 1:
raise RuntimeError("StreamingDataset can work only with one instance per rank")
@@ -94,41 +60,12 @@ def __iter__(self) -> typing.Iterator[LanguageModelSample]:
if not self.is_batch_data_group_leader:
raise RuntimeError("Must be only called on the batch data group leader")
- redis_client = redis.Redis(host=self._config.redis.host, port=self._config.redis.port)
-
- match self._config.ingestion_type:
- case IngestionType.CONSUMER_GROUP:
- messages_iter = self._iter_consumer_group
- case IngestionType.ONE_STREAM:
- messages_iter = self._iter_stream
- case IngestionType.N_STREAMS:
- messages_iter = self._iter_stream
- case _:
- raise ValueError(f"Unknown ingestion type {self._config.ingestion_type}")
-
- ack_hash = f"{self._config.redis.stream_key}:ack"
- consumer_id = str(self.batch_data_rank)
- # If one stream each groups receives all data otherwise each groups receives just its data
- if self._config.ingestion_type == IngestionType.ONE_STREAM:
- ack_period = self._config.ack_period_per_consumer * self.batch_data_parallel
- else:
- ack_period = self._config.ack_period_per_consumer
-
- for msg_data in messages_iter(redis_client, ack_hash=ack_hash, consumer_id=consumer_id, ack_period=ack_period):
- data = orjson.loads(msg_data[self.payload_key_b])
- yield self._sample_from_msg_data(data)
-
- def _iter_consumer_group(
- self, redis_client: "redis.Redis", ack_hash: str, consumer_id: str, ack_period: int
- ) -> typing.Iterator[LanguageModelSample]:
- import redis.exceptions
+ client = redis.Redis(host=self._config.host, port=self._config.port)
# Create the consumer group at the start of the stream ("0")
# If the stream already exists, XGROUP CREATE will fail unless we add mkstream=True
try:
- redis_client.xgroup_create(
- name=self._config.redis.stream_key, groupname=self._config.group_name, id="0", mkstream=True
- )
+ client.xgroup_create(name=REDIS_DATA_STREAM, groupname=REDIS_GROUP_NAME, id="0", mkstream=True)
except redis.exceptions.ResponseError as e:
if "BUSYGROUP" in str(e):
# Consumer group already exists
@@ -141,11 +78,11 @@ def _iter_consumer_group(
# XREADGROUP reads from the consumer group
# COUNT: max number of messages to fetch at once
# BLOCK: wait for new messages (milliseconds)
- messages = redis_client.xreadgroup(
- groupname=self._config.group_name,
- consumername=f"{self._config.consumer_name_prefix}_{self.batch_data_rank}",
+ messages = client.xreadgroup(
+ groupname=REDIS_GROUP_NAME,
+ consumername=f"fast_llm_consumer_{self._rank}",
# ">" reads only new messages that have not been delivered to any consumer
- streams={self._config.redis.stream_key: ">"},
+ streams={REDIS_DATA_STREAM: ">"},
count=1,
block=1000,
# No explicit ACK: messages are processed immediately; on rank failure the job restarts,
@@ -154,106 +91,33 @@ def _iter_consumer_group(
)
if messages:
for stream_key, msgs in messages:
- assert stream_key == self._config.redis.stream_key.encode()
+ assert stream_key == REDIS_DATA_STREAM.encode()
for msg_id, msg_data in msgs:
processed += 1
# TODO: or do it after processing all received messaged then count > 1?
- if processed % ack_period == 0:
- redis_client.hset(ack_hash, consumer_id, msg_id)
-
- yield msg_data
-
- def _iter_stream(
- self, redis_client: "redis.Redis", ack_hash: str, consumer_id: str, ack_period: int
- ) -> typing.Iterator[LanguageModelSample]:
- last_id = "0-0"
- stream_key = self._config.redis.stream_key
- if self._config.ingestion_type == IngestionType.N_STREAMS:
- stream_key += f"_{self.batch_data_rank}"
- stream_key_b = stream_key.encode()
- processed = 0
- while True:
- messages = redis_client.xread(
- streams={stream_key: last_id},
- count=1,
- block=1000,
- )
- if not messages:
- continue
- for this_stream_key_b, msgs in messages:
- assert this_stream_key_b == stream_key_b
- for msg_id, msg_data in msgs:
- last_id = msg_id
-
- processed += 1
- # TODO: or do it after processing all received messaged then count > 1?
- if processed % ack_period == 0:
- redis_client.hset(ack_hash, consumer_id, last_id)
+ if processed % self._config.acknowledge_interval == 0:
+ client.hset(f"{REDIS_DATA_STREAM}:ack", str(self._rank), msg_id)
- if self._config.ingestion_type == IngestionType.N_STREAMS or self._is_for_this_rank(
- msg_id, msg_data, processed - 1
- ):
- yield msg_data
-
- def _is_for_this_rank(self, msg_id: bytes, msg_data: dict, msg_index: int) -> bool:
- hash_type = self._config.hash_type
-
- if hash_type is HashType.MESSAGE_INDEX:
- h = msg_index
- elif hash_type is HashType.MESSAGE_ID:
- h = xxhash.xxh64(msg_id).intdigest()
- elif hash_type is HashType.MESSAGE_BODY:
- h = xxhash.xxh64(msg_data[self.payload_key_b]).intdigest()
- elif hash_type is HashType.PRODUCER_PROVIDED:
- h = self._get_hash_key_value(msg_data[self.hash_key_b])
- else:
- raise ValueError(f"Unknown hash_type: {hash_type}")
-
- return (h % self.batch_data_parallel) == self.batch_data_rank
-
- def _get_hash_key_value(self, value: bytes | int | str):
- if isinstance(value, int):
- # already an integer
- msg_hash = value
- elif isinstance(value, bytes):
- try:
- # try decoding as UTF-8 string and converting to int
- msg_hash = int(value.decode("utf-8"))
- except ValueError:
- # not an integer, treat as a hash string
- import xxhash
-
- msg_hash = xxhash.xxh64(value).intdigest()
- elif isinstance(value, str):
- try:
- msg_hash = int(value)
- except ValueError:
- # not an integer, treat as a hash string
- import xxhash
-
- msg_hash = xxhash.xxh64(value.encode("utf-8")).intdigest()
- else:
- raise TypeError(f"Unexpected type for hash key: {type(value)}")
- return msg_hash
+ yield self._read_document(json.loads(msg_data[REDIS_FIELD.encode()]))
- def _sample_from_msg_data(self, data: dict) -> LanguageModelSample:
+ def _read_document(self, data: dict) -> LanguageModelSample:
tokens = torch.tensor(data["tokens"], dtype=dtype_from_string(data["tokens_dtype"]))
sample_size = len(tokens)
if "loss_masking_spans" in data:
- loss_masking_spans = [tuple(el) for el in data["loss_masking_spans"]]
+ loss_masking_spans = RangeSample([(begin, end) for begin, end in data["loss_masking_spans"]], sample_size)
else:
loss_masking_spans = None
if "chosen_spans" in data:
- chosen_spans = [tuple(el) for el in data["chosen_spans"]]
+ chosen_spans = RangeSample([(begin, end) for begin, end in data["chosen_spans"]], sample_size)
else:
chosen_spans = None
if "rejected_spans" in data:
- rejected_spans = [tuple(el) for el in data["rejected_spans"]]
+ rejected_spans = RangeSample([(begin, end) for begin, end in data["rejected_spans"]], sample_size)
else:
rejected_spans = None
return LanguageModelSample(
TokenSample(tokens, [sample_size]),
- RangeSample(loss_masking_spans, sample_size) if loss_masking_spans is not None else None,
- RangeSample(chosen_spans, sample_size) if chosen_spans is not None else None,
- RangeSample(rejected_spans, sample_size) if rejected_spans is not None else None,
+ loss_masking_spans,
+ chosen_spans,
+ rejected_spans,
)
diff --git a/fast_llm/data/iterator.py b/fast_llm/data/iterator.py
deleted file mode 100644
index a407c0258..000000000
--- a/fast_llm/data/iterator.py
+++ /dev/null
@@ -1,25 +0,0 @@
-import typing
-
-import torch.utils.data
-
-
-class SampledDatasetIterator(torch.utils.data.Sampler):
- """
- A distributed sampler generating indices for a `SampledDataset` (i.e., the natural numbers).
- To be used as the `batch_sampler` of a `torch.utils.data.DataLoader`.
- """
-
- def __init__(self, total_samples, begin_index, micro_batch_size, data_rank, data_parallel):
- super().__init__()
- self._total_samples = total_samples
- self._begin_index = begin_index
- self._batch_size = micro_batch_size * data_parallel
- self._start_idx = data_rank * micro_batch_size
- self._end_idx = (data_rank + 1) * micro_batch_size
-
- def __len__(self) -> int:
- return self._total_samples
-
- def __iter__(self) -> typing.Iterator[list[int]]:
- for idx in range(self._begin_index, self._total_samples - self._batch_size + 1, self._batch_size):
- yield list(range(idx + self._start_idx, idx + self._end_idx))
diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py
index 503b400c3..a1aadf40a 100644
--- a/fast_llm/data/preparator/gpt_memmap/config.py
+++ b/fast_llm/data/preparator/gpt_memmap/config.py
@@ -15,30 +15,81 @@
from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator
-@config_class()
+@config_class(registry=True)
class LanguageModelSourceConfig(Config):
"""
- A schema holding the name of each relevant column in the dataset.
- Setting optional entries will enable the associated feature.
+ Abstract base class for data source schemas.
+
+ Use `type: document` (default) for documents with text, optional span annotations, and optional images.
+ Use `type: conversation` for structured chat/conversation datasets.
+ """
+
+ @classmethod
+ def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self:
+ if cls is LanguageModelSourceConfig and cls.get_subclass(default.get("type")) is None:
+ # Default to DocumentSourceConfig when type is not specified
+ return DocumentSourceConfig._from_dict(default, strict)
+ return super()._from_dict(default, strict=strict)
+
+ @functools.cached_property
+ def columns(self) -> list[str]:
+ """Columns to read from the dataset."""
+ raise NotImplementedError
+
+ @functools.cached_property
+ def has_loss_masking_span(self) -> bool:
+ return False
+
+ @functools.cached_property
+ def has_preference_spans(self) -> bool:
+ return False
+
+ @functools.cached_property
+ def has_images(self) -> bool:
+ return False
+
+
+@config_class(dynamic_type={LanguageModelSourceConfig: "document"})
+class DocumentSourceConfig(LanguageModelSourceConfig):
+ """
+ Source schema for document datasets with text, optional span annotations, and optional images.
+
+ The dataset should have a text column containing the document text.
+ Optionally, it can have additional columns for:
+ - Loss masking spans: character ranges to mask from loss computation
+ - Preference spans: chosen/rejected text for DPO training
+ - Images: image data with character positions for multimodal training
"""
text: str = Field(
default="text",
- desc="Field of the dataset to use.",
+ desc="Field containing the document text.",
+ hint=FieldHint.optional,
+ )
+ loss_masking_spans: str | None = Field(
+ default=None,
+ desc="Field containing character spans to mask for loss computation.",
hint=FieldHint.optional,
)
- loss_masking_spans: None | str = Field(
- default=None, desc="Field containing character spans to mask for loss computation", hint=FieldHint.optional
+ chosen_span: str | None = Field(
+ default=None,
+ desc="Field containing chosen text for preference optimization.",
+ hint=FieldHint.optional,
)
- chosen_span: None | str = Field(
- default=None, desc="Field containing chosen text for preference optimization", hint=FieldHint.optional
+ rejected_span: str | None = Field(
+ default=None,
+ desc="Field containing rejected text for preference optimization.",
+ hint=FieldHint.optional,
)
- rejected_span: None | str = Field(
- default=None, desc="Field containing rejected text for preference optimization", hint=FieldHint.optional
+ images: str | None = Field(
+ default=None,
+ desc="Field containing images.",
+ hint=FieldHint.optional,
)
- images: None | str = Field(default=None, desc="Field containing images", hint=FieldHint.optional)
- image_positions: None | str = Field(
- default=None, desc="Field containing image positions in the text.", hint=FieldHint.optional
+ image_positions: str | None = Field(
+ default=None,
+ desc="Field containing image positions in the text.",
+ hint=FieldHint.optional,
)
@functools.cached_property
@@ -48,6 +99,8 @@ def columns(self) -> list[str]:
columns.append(self.loss_masking_spans)
if self.has_preference_spans:
columns.extend([self.chosen_span, self.rejected_span])
+ if self.has_images:
+ columns.extend([self.images, self.image_positions])
return columns
@functools.cached_property
@@ -67,7 +120,50 @@ def has_images(self) -> bool:
def _validate(self):
super()._validate()
if self.has_preference_spans and self.has_loss_masking_span:
- raise ValueError(f"Can not enable both loss masking and preference spans.")
+ raise ValueError("Cannot enable both loss masking and preference spans.")
+
+
+@config_class(dynamic_type={LanguageModelSourceConfig: "conversation"})
+class ConversationSourceConfig(LanguageModelSourceConfig):
+ """
+ Source schema for chat/conversation datasets (e.g., Tulu 3, ShareGPT, OpenAI format).
+
+ The dataset should have a messages column containing a list of message dicts,
+ where each message has 'role' and 'content' keys. Common roles include:
+ - 'system': System prompt
+ - 'user': User input
+ - 'assistant': Model response (trained on by default)
+ - 'tool': Tool/function results
+ - 'ipython': Code execution results
+
+ The conversation is formatted using the tokenizer's chat template, which must
+ contain {% generation %}...{% endgeneration %} markers to define which content
+ to train on. Loss masking spans are automatically computed from these markers.
+
+ Example dataset format:
+ {
+ "messages": [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "Hello!"},
+ {"role": "assistant", "content": "Hi there!"},
+ ]
+ }
+ """
+
+ messages: str = Field(
+ default="messages",
+ desc="Field containing the conversation messages list. Each message should have 'role' and 'content' keys.",
+ hint=FieldHint.core,
+ )
+
+ @functools.cached_property
+ def columns(self) -> list[str]:
+ return [self.messages]
+
+ @functools.cached_property
+ def has_loss_masking_span(self) -> bool:
+ # Conversation format always generates loss masking spans from chat template markers
+ return True
@config_class()
diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py
index 2ea81d8a6..325d33c43 100644
--- a/fast_llm/data/preparator/gpt_memmap/prepare.py
+++ b/fast_llm/data/preparator/gpt_memmap/prepare.py
@@ -28,7 +28,12 @@
)
from fast_llm.data.dataset.memmap import MemmapDataset
from fast_llm.data.preparator.config import DatasetPreparator
-from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig, LanguageModelSourceConfig
+from fast_llm.data.preparator.gpt_memmap.config import (
+ ConversationSourceConfig,
+ DocumentSourceConfig,
+ GPTMemmapDatasetPreparatorConfig,
+ LanguageModelSourceConfig,
+)
from fast_llm.data.preprocessing.abstract import NullPreprocessingConfig
from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig
from fast_llm.data.preprocessing.tokenizer import Tokenizer
@@ -39,7 +44,7 @@
from fast_llm.data.sample.token import TokenSample
from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type
from fast_llm.engine.config_utils.run import log_main_rank
-from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum
+from fast_llm.utils import normalize_probabilities, padded_cumsum
logger = logging.getLogger(__name__)
@@ -132,6 +137,10 @@ def run(self) -> None:
# Load tokenizer
self._tokenizer = self._config.tokenizer.get_tokenizer()
+ # Validate chat template for conversation format
+ if isinstance(self._source_schema, ConversationSourceConfig):
+ self._tokenizer.validate_chat_template()
+
# Decide the datatype based on the tokenizer vocabulary size
self._data_type = (
get_unsigned_integer_type(self._tokenizer.vocab_size)
@@ -216,92 +225,112 @@ def _preprocessing_config(self) -> LanguageModelPreprocessingConfig:
)
def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample:
- text = sample[self._source_schema.text]
- all_spans = []
- if self._source_schema.has_loss_masking_span:
- # Spans are typically stored in the (begin, last) format. We convert to (begin, end) range format.
- loss_masking_spans = _sort_spans(
- (SpanType.loss_masking, (begin, last + 1))
- for begin, last in np.array(sample[self._source_schema.loss_masking_spans], dtype=np.int32)
- .reshape(-1, 2)
- .tolist()
+ token_spans_by_type = collections.defaultdict(list)
+ image_patches = image_token_maps = image_position_ids = patch_counts = None
+
+ if isinstance(self._source_schema, ConversationSourceConfig):
+ # Conversation format: tokenize messages and get loss masking spans from chat template
+ tokens, loss_masking_spans = self._tokenizer.tokenize_chat(
+ sample[self._source_schema.messages],
+ True,
+ True,
+ data_type=self._data_type,
)
- all_spans.extend(loss_masking_spans)
-
- if self._source_schema.has_preference_spans:
- full_chosen_text = text + sample[self._source_schema.chosen_span] + self._tokenizer.tokenizer.eos_token
- full_rejected_text = self._tokenizer.tokenizer.bos_token + text + sample[self._source_schema.rejected_span]
- # compute chosen span
- chosen_spans = [(SpanType.chosen, (len(text), len(full_chosen_text)))]
-
- # compute rejected span
- rejected_span = [
- (
- SpanType.rejected,
- (
- len(full_chosen_text) + len(self._tokenizer.tokenizer.bos_token) + len(text),
- len(full_chosen_text) + len(full_rejected_text),
- ),
+ token_spans_by_type[SpanType.loss_masking] = loss_masking_spans
+ elif isinstance(self._source_schema, DocumentSourceConfig):
+ # Document format: use the text-spans pipeline
+ text = sample[self._source_schema.text]
+ all_spans = []
+
+ if self._source_schema.has_loss_masking_span:
+ # Spans are typically stored in the (begin, last) format. We convert to (begin, end) range format.
+ loss_masking_spans = _sort_spans(
+ (SpanType.loss_masking, (begin, last + 1))
+ for begin, last in np.array(sample[self._source_schema.loss_masking_spans], dtype=np.int32)
+ .reshape(-1, 2)
+ .tolist()
+ )
+ all_spans.extend(loss_masking_spans)
+
+ if self._source_schema.has_preference_spans:
+ full_chosen_text = text + sample[self._source_schema.chosen_span] + self._tokenizer.tokenizer.eos_token
+ full_rejected_text = (
+ self._tokenizer.tokenizer.bos_token + text + sample[self._source_schema.rejected_span]
)
- ]
- # pack texts
- text = full_chosen_text + full_rejected_text
- all_spans.extend(chosen_spans + rejected_span)
-
- if self._source_schema.has_images:
- # Get the images and positions, sorted by position.
- images, image_positions = (
- zip(
- *sorted(
- zip(
- sample[self._source_schema.images],
- sample[self._source_schema.image_positions],
- strict=True,
+ # compute chosen span
+ chosen_spans = [(SpanType.chosen, (len(text), len(full_chosen_text)))]
+
+ # compute rejected span
+ rejected_span = [
+ (
+ SpanType.rejected,
+ (
+ len(full_chosen_text) + len(self._tokenizer.tokenizer.bos_token) + len(text),
+ len(full_chosen_text) + len(full_rejected_text),
),
- key=lambda x: x[1],
)
+ ]
+ # pack texts
+ text = full_chosen_text + full_rejected_text
+ all_spans.extend(chosen_spans + rejected_span)
+
+ if self._source_schema.has_images:
+ # Get the images and positions, sorted by position.
+ images, image_positions = (
+ zip(
+ *sorted(
+ zip(
+ sample[self._source_schema.images],
+ sample[self._source_schema.image_positions],
+ strict=True,
+ ),
+ key=lambda x: x[1],
+ )
+ )
+ if len(sample[self._source_schema.images]) > 0
+ else ([], [])
)
- if len(sample[self._source_schema.images]) > 0
- else ([], [])
- )
- # Get the image patches and associated data.
- image_patches, image_position_ids, image_token_maps, image_token_ids, patch_counts = (
- self._config.image_patches.get_patches_from_images(images, self._data_type)
+ # Get the image patches and associated data.
+ image_patches, image_position_ids, image_token_maps, image_token_ids, patch_counts = (
+ self._config.image_patches.get_patches_from_images(images, self._data_type)
+ )
+ patch_count_cumsum = padded_cumsum(patch_counts).tolist()
+ # Add an empty "span" at each image position so we know where to insert them in the tokenized sequence.
+ all_spans.extend([(SpanType.image, (position, position)) for position in image_positions])
+
+ # Sort the spans by location (begin), keeping track of their type.
+ # Note: overlapping spans are not supported (explicit assertion in the tokenizer).
+ span_types, spans = zip(*_sort_spans(all_spans)) if all_spans else ([], [])
+ # Tokenize the text, and determine the span locations in the tokenized text.
+ tokens, token_spans = self._tokenizer.tokenize_with_spans(
+ text, True, True, text_spans=spans, data_type=self._data_type
)
- patch_count_cumsum = padded_cumsum(patch_counts).tolist()
- # Add an empty "span" at each image position so we know where to insert them in the tokenized sequence.
- all_spans.extend([(SpanType.image, (position, position)) for position in image_positions])
-
- # Sort the spans by location (begin), keeping track of their type.
- # Note: overlapping spans are not supported (explicit assertion in the tokenizer).
- span_types, spans = zip(*_sort_spans(all_spans)) if all_spans else ([], [])
- # Tokenize the text, and determine the span locations in the tokenized text.
- tokens, token_spans = self._tokenizer.tokenize_with_spans(
- text, True, True, text_spans=spans, data_type=self._data_type
- )
- # Gather token spans by type.
- token_spans_by_type = collections.defaultdict(list)
- if self._source_schema.has_images:
- # Insert the image token ids in the token sequence and shift the spans accordingly.
- tokens_shift = 0
- image_index = 0
- for span_type, (begin, end) in zip(span_types, token_spans, strict=True):
- # Account for the tokens already inserted.
- begin = begin + tokens_shift
- end = end + tokens_shift
- if span_type == SpanType.image:
- # Shift the token map to the image location.
- image_token_maps[patch_count_cumsum[image_index] : patch_count_cumsum[image_index + 1]] += begin
- # Insert the placeholder and image break tokens.
- tokens = torch.cat([tokens[:begin], image_token_ids[image_index], tokens[begin:]])
- tokens_shift += len(image_token_ids[image_index])
- image_index += 1
- else:
- token_spans_by_type[span_type].append((begin, end))
+ # Gather token spans by type.
+ if self._source_schema.has_images:
+ # Insert the image token ids in the token sequence and shift the spans accordingly.
+ tokens_shift = 0
+ image_index = 0
+ for span_type, (begin, end) in zip(span_types, token_spans, strict=True):
+ # Account for the tokens already inserted.
+ begin = begin + tokens_shift
+ end = end + tokens_shift
+ if span_type == SpanType.image:
+ # Shift the token map to the image location.
+ image_token_maps[
+ patch_count_cumsum[image_index] : patch_count_cumsum[image_index + 1]
+ ] += begin
+ # Insert the placeholder and image break tokens.
+ tokens = torch.cat([tokens[:begin], image_token_ids[image_index], tokens[begin:]])
+ tokens_shift += len(image_token_ids[image_index])
+ image_index += 1
+ else:
+ token_spans_by_type[span_type].append((begin, end))
+ else:
+ for span_type, token_span in zip(span_types, token_spans, strict=True):
+ token_spans_by_type[span_type].append(token_span)
else:
- for span_type, token_span in zip(span_types, token_spans, strict=True):
- token_spans_by_type[span_type].append(token_span)
+ raise NotImplementedError(f"Unsupported source schema type: {type(self._source_schema)}")
sample_size = len(tokens)
@@ -346,16 +375,18 @@ def generate_config_yaml_for_sharded_dst(
# Create the config file(s) on rank 0
dataset_configs, reader_configs = zip(*dataset_and_reader_configs)
if self._config.splits:
- for split_name, split_config in self._split_and_blend_dataset_configs(
+ for split_name, (split_config, metadata) in self._split_and_blend_dataset_configs(
dataset_configs, reader_configs, self._config.splits, self._config.output_path
).items():
self._save_dataset_config(
- split_config, self._config.output_path / f"fast_llm_config_{split_name}.yaml"
+ split_config,
+ metadata,
+ output_path=self._config.output_path / f"fast_llm_config_{split_name}.yaml",
)
else:
self._save_dataset_config(
- self._blend_dataset_configs(dataset_configs, reader_configs),
- self._config.output_path / f"fast_llm_config.yaml",
+ *self._blend_dataset_configs(dataset_configs, reader_configs),
+ output_path=self._config.output_path / f"fast_llm_config.yaml",
)
# Save metadata on rank 0
@@ -368,29 +399,30 @@ def generate_config_yaml_for_sharded_dst(
@classmethod
def _save_dataset_config(
- cls, dataset_config: IndexedDatasetConfig[_sample_type], output_path: pathlib.Path
+ cls,
+ dataset_config: IndexedDatasetConfig[_sample_type],
+ metadata: dict[str, typing.Any],
+ output_path: pathlib.Path,
) -> None:
logger.info(f"Saving config to {output_path}")
- yaml.safe_dump(
- dataset_config.to_dict(),
- output_path.open("w"),
- )
+ yaml.safe_dump({"config": dataset_config.to_dict(), "metadata": metadata}, output_path.open("w"))
@classmethod
def _blend_dataset_configs(
cls,
dataset_configs: list[MemmapDatasetConfig[_sample_type]],
reader_configs: list[MemmapIndexDatasetReaderConfig],
- ) -> IndexedDatasetConfig[_sample_type]:
+ ) -> tuple[IndexedDatasetConfig[_sample_type], dict[str, typing.Any]]:
+ datasets_metadata = [reader_config.get_metadata() for reader_config in reader_configs]
if len(dataset_configs) == 1:
- return dataset_configs[0]
+ return dataset_configs[0], datasets_metadata[0]
return SampledDatasetConfig[cls._sample_type].from_dict(
{
"type": "blended",
"datasets": dataset_configs,
"weights": [reader_config.num_tokens for reader_config in reader_configs],
}
- )
+ ), reader_configs[0].blend_metadata(datasets_metadata)
def _split_and_blend_dataset_configs(
self,
@@ -398,7 +430,7 @@ def _split_and_blend_dataset_configs(
reader_configs: list[MemmapIndexDatasetReaderConfig],
splits: dict[str, int | float],
output_path: pathlib.Path,
- ) -> dict[str, SampledDatasetConfig[_sample_type]]:
+ ) -> dict[str, tuple[SampledDatasetConfig[_sample_type], dict[str, typing.Any]]]:
split_cumsum = padded_cumsum(normalize_probabilities(list(splits.values()), return_array=True)).tolist()
dataset_sizes = [reader_config.num_tokens for reader_config in reader_configs]
dataset_probabilities = normalize_probabilities(dataset_sizes)
@@ -407,7 +439,7 @@ def _split_and_blend_dataset_configs(
for split_index, split_name in enumerate(splits):
datasets_in_split = []
- dataset_tokens_in_split = []
+ datasets_metadata = []
for dataset_index, (dataset_config, reader_config) in enumerate(
zip(dataset_configs, reader_configs, strict=True)
):
@@ -424,17 +456,17 @@ def _split_and_blend_dataset_configs(
if split_begin_in_dataset == 0 and split_end_in_dataset == 1:
# All the dataset belongs to the split.
datasets_in_split.append(dataset_configs[dataset_index])
- dataset_tokens_in_split.append(dataset_sizes[dataset_index])
+ datasets_metadata.append(reader_config.get_metadata())
+
elif split_end_in_dataset > split_begin_in_dataset:
# Part of the dataset belongs to the split.
# TODO: Somehow getting a segfault when merging two lines below (numpy bug?).
dataset = dataset_config.to_copy({"path": output_path / dataset_config.path}).build(
self._preprocessing_config
)
- sizes_cumsum = dataset.get_document_sizes().numpy().cumsum()
- Assert.eq(sizes_cumsum[-1], reader_config.num_tokens)
- begin_index = _get_nearest_split(sizes_cumsum, split_begin_in_dataset * reader_config.num_tokens)
- end_index = _get_nearest_split(sizes_cumsum, split_end_in_dataset * reader_config.num_tokens)
+ begin_index, end_index, metadata = dataset.reader.get_split(
+ split_begin_in_dataset, split_end_in_dataset
+ )
if end_index > begin_index:
datasets_in_split.append(
DatasetSliceConfig[self._sample_type].from_dict(
@@ -446,10 +478,7 @@ def _split_and_blend_dataset_configs(
}
)
)
- dataset_tokens_in_split.append(
- sizes_cumsum[end_index - 1].item()
- - (sizes_cumsum[begin_index - 1].item() if begin_index > 0 else 0)
- )
+ datasets_metadata.append(metadata)
# [else] None of the dataset belongs to the split.
@@ -457,14 +486,17 @@ def _split_and_blend_dataset_configs(
# This is a big problem, but we don't want to crash the whole run.
logger.error(f"Datasets split {split_name} is empty!")
elif len(datasets_in_split) == 1:
- dataset_splits[split_name] = datasets_in_split[0]
+ dataset_splits[split_name] = (datasets_in_split[0], datasets_metadata[0])
else:
- dataset_splits[split_name] = BlendedDatasetConfig[self._sample_type].from_dict(
- {
- "type": "blended",
- "datasets": datasets_in_split,
- "weights": dataset_tokens_in_split,
- }
+ dataset_splits[split_name] = (
+ BlendedDatasetConfig[self._sample_type].from_dict(
+ {
+ "type": "blended",
+ "datasets": datasets_in_split,
+ "weights": [dataset_metadata["num_tokens"] for dataset_metadata in datasets_metadata],
+ }
+ ),
+ reader_configs[0].blend_metadata(datasets_metadata),
)
return dataset_splits
diff --git a/fast_llm/data/preprocessing/tokenizer.py b/fast_llm/data/preprocessing/tokenizer.py
index abfb5b3d2..157744f51 100644
--- a/fast_llm/data/preprocessing/tokenizer.py
+++ b/fast_llm/data/preprocessing/tokenizer.py
@@ -213,3 +213,105 @@ def _remove_delimiters(
@property
def eod(self):
return self.eod_id
+
+ @staticmethod
+ def _has_generation_markers(template: str | None) -> bool:
+ """Check if a template has generation markers."""
+ return template is not None and "{% generation %}" in template
+
+ def validate_chat_template(self) -> None:
+ """
+ Validate the tokenizer's chat template has generation markers.
+
+ Raises:
+ ValueError: If the tokenizer lacks a chat template or generation markers.
+ """
+ template = self.tokenizer.chat_template
+
+ if template is None:
+ raise ValueError(
+ "Tokenizer does not have a chat template. "
+ "Conversation format requires a tokenizer with a built-in chat template "
+ "containing {% generation %}...{% endgeneration %} markers."
+ )
+
+ if not self._has_generation_markers(template):
+ raise ValueError(
+ "Tokenizer's chat template does not contain {% generation %}...{% endgeneration %} markers. "
+ "These markers are required to determine which tokens to train on. "
+ "Please use a tokenizer with generation markers in its chat template."
+ )
+
+ def tokenize_chat(
+ self,
+ messages: list[dict[str, str]],
+ begin: bool = True,
+ end: bool = True,
+ data_type: DataType = DataType.int64,
+ ) -> tuple["torch.Tensor", list[tuple[int, int]]]:
+ """
+ Apply chat template and return (tokens, loss_masking_spans).
+
+ The loss_masking_spans mark token ranges to EXCLUDE from training (where the model
+ should not learn). These are derived from the chat template's generation markers -
+ tokens outside {% generation %}...{% endgeneration %} blocks are masked.
+ """
+ import torch
+
+ result = self.tokenizer.apply_chat_template(
+ messages,
+ tokenize=True,
+ return_assistant_tokens_mask=True,
+ return_dict=True,
+ add_generation_prompt=False,
+ )
+ tokens = result["input_ids"]
+ train_mask = result["assistant_masks"]
+
+ # Prepend BOS / append EOS if not already present anywhere in the sequence.
+ # We check anywhere (not just first/last) because some chat templates add trailing
+ # whitespace after the final EOS token, e.g. "<|im_end|>\n".
+ prepend_bos = begin and self.bod_id not in tokens
+ append_eos = end and self.eod_id not in tokens
+ tokens = [self.bod_id] * prepend_bos + list(tokens) + [self.eod_id] * append_eos
+ train_mask = [False] * prepend_bos + [bool(m) for m in train_mask] + [False] * append_eos
+
+ # Convert boolean train mask to loss masking spans (spans where train_mask[i] == False)
+ loss_masking_spans = _train_mask_to_loss_spans(train_mask)
+
+ if self._config.max_vocab_size is not None:
+ tokens = (
+ torch.tensor(
+ tokens,
+ dtype=torch.int64 if len(self.tokenizer) > torch.iinfo(data_type.torch).max else data_type.torch,
+ )
+ % self._config.max_vocab_size
+ ).to(data_type.torch)
+ else:
+ tokens = torch.tensor(tokens, dtype=data_type.torch)
+ return tokens, loss_masking_spans
+
+
+def _train_mask_to_loss_spans(train_mask: list[bool]) -> list[tuple[int, int]]:
+ """
+ Convert a boolean train mask to loss masking spans.
+
+ Args:
+ train_mask: Boolean list where True = train on this token, False = don't train
+
+ Returns:
+ List of (begin, end) spans marking token ranges to EXCLUDE from training
+ (i.e., where train_mask[i] == False).
+ """
+ spans = []
+ start = None
+ for i, should_train in enumerate(train_mask):
+ if not should_train:
+ if start is None:
+ start = i
+ elif start is not None:
+ spans.append((start, i))
+ start = None
+ if start is not None:
+ spans.append((start, len(train_mask)))
+ return spans
diff --git a/fast_llm/data/sample/abstract.py b/fast_llm/data/sample/abstract.py
index 1d71363b7..494a5c4a5 100644
--- a/fast_llm/data/sample/abstract.py
+++ b/fast_llm/data/sample/abstract.py
@@ -73,6 +73,13 @@ def expected_buffer_size(self) -> int:
"""
raise NotImplementedError()
+ def get_metadata(self) -> dict[str, typing.Any]:
+ raise NotImplementedError()
+
+ @classmethod
+ def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typing.Any]:
+ raise NotImplementedError()
+
@config_class(dynamic_type={MemmapReaderBaseConfig: "none"})
class NullReaderConfig(MemmapReaderBaseConfig):
@@ -159,6 +166,13 @@ def reader_class(self) -> "type[MemmapIndexedDatasetReader]":
def get_reader(self, buffer: memoryview, model_preprocessing: PreprocessingConfig) -> "MemmapIndexedDatasetReader":
return self.reader_class(self, buffer, model_preprocessing)
+ def get_metadata(self) -> dict[str, typing.Any]:
+ return {"num_tokens": self.num_tokens}
+
+ @classmethod
+ def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typing.Any]:
+ return {"num_tokens": sum(metadata_["num_tokens"] for metadata_ in metadata)}
+
class MemmapReaderBase[ConfigType: MemmapReaderBaseConfig](Configurable[ConfigType]):
@abc.abstractmethod
@@ -196,6 +210,9 @@ def get_document_sizes(self) -> "torch.Tensor":
def get_document_size(self, index: int) -> int:
pass
+ def get_split(self, begin_ratio: float, end_ratio: float) -> tuple[int, int, dict[str, typing.Any]]:
+ raise NotImplementedError()
+
class MemmapWriter(abc.ABC):
def __init__(
diff --git a/fast_llm/data/sample/language_model.py b/fast_llm/data/sample/language_model.py
index 3183a9ec1..22b89acf1 100644
--- a/fast_llm/data/sample/language_model.py
+++ b/fast_llm/data/sample/language_model.py
@@ -23,6 +23,7 @@
from fast_llm.data.sample.patch import (
EmptyPatchReader,
PatchBatch,
+ PatchReader,
PatchReaderBaseConfig,
PatchReaderConfig,
PatchSample,
@@ -31,6 +32,7 @@
from fast_llm.data.sample.range import (
EmptyRangeReader,
RangeBatch,
+ RangeReader,
RangeReaderBaseConfig,
RangeReaderConfig,
RangeSample,
@@ -222,6 +224,41 @@ def _expected_buffer_size(self) -> int:
+ self.image_patches.expected_buffer_size
)
+ def get_metadata(self) -> dict[str, typing.Any]:
+ out = super().get_metadata()
+ out["tokens"] = self.tokens.get_metadata()
+ if not isinstance(self.loss_masking_spans, NullReaderConfig):
+ out["loss_masking_spans"] = self.loss_masking_spans.get_metadata()
+ if not isinstance(self.chosen_spans, NullReaderConfig):
+ out["chosen_spans"] = self.chosen_spans.get_metadata()
+ if not isinstance(self.rejected_spans, NullReaderConfig):
+ out["rejected_spans"] = self.rejected_spans.get_metadata()
+ if not isinstance(self.image_patches, NullReaderConfig):
+ out["image_patches"] = self.image_patches.get_metadata()
+ return out
+
+ @classmethod
+ def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typing.Any]:
+ out = super().blend_metadata(metadata)
+ out["tokens"] = TokenReaderConfig.blend_metadata([metadata_["tokens"] for metadata_ in metadata])
+ if "loss_masking_spans" in metadata[0]:
+ out["loss_masking_spans"] = RangeReaderConfig.blend_metadata(
+ [metadata_["loss_masking_spans"] for metadata_ in metadata]
+ )
+ if "chosen_spans" in metadata[0]:
+ out["chosen_spans"] = RangeReaderConfig.blend_metadata(
+ [metadata_["chosen_spans"] for metadata_ in metadata]
+ )
+ if "rejected_spans" in metadata[0]:
+ out["image_patches"] = RangeReaderConfig.blend_metadata(
+ [metadata_["image_patches"] for metadata_ in metadata]
+ )
+ if "image_patches" in metadata[0]:
+ out["image_patches"] = PatchReaderConfig.blend_metadata(
+ [metadata_["image_patches"] for metadata_ in metadata]
+ )
+ return out
+
class LanguageModelReader[ConfigType: LanguageModelReaderConfig](MemmapIndexedDatasetReader[ConfigType]):
_model_preprocessing: LanguageModelPreprocessingConfig
@@ -305,6 +342,23 @@ def get_document_sizes(self) -> torch.Tensor:
def get_document_size(self, index: int) -> int:
return self._tokens.get_document_size(index)
+ def get_split(self, begin_ratio: float, end_ratio: float) -> tuple[int, int, dict[str, typing.Any]]:
+ begin_index, end_index, token_metadata = self._tokens.get_split(begin_ratio, end_ratio)
+ metadata = {
+ "num_tokens": token_metadata["num_tokens"],
+ "tokens": token_metadata,
+ }
+ if hasattr(self, "_loss_masking_spans") and isinstance(self._loss_masking_spans, RangeReader):
+ metadata["loss_masking_spans"] = self._loss_masking_spans.get_split(begin_index, end_index)
+ if hasattr(self, "_chosen_spans") and isinstance(self._chosen_spans, RangeReader):
+ metadata["chosen_spans"] = self._chosen_spans.get_split(begin_index, end_index)
+ if hasattr(self, "_rejected_spans") and isinstance(self._rejected_spans, RangeReader):
+ metadata["rejected_spans"] = self._rejected_spans.get_split(begin_index, end_index)
+ if hasattr(self, "_image_patches") and isinstance(self._image_patches, PatchReader):
+ metadata["image_patches"] = self._image_patches.get_split(begin_index, end_index)
+
+ return begin_index, end_index, metadata
+
class LanguageModelWriter(MemmapWriter):
_preprocessing_config: LanguageModelPreprocessingConfig
diff --git a/fast_llm/data/sample/patch.py b/fast_llm/data/sample/patch.py
index 221746752..7ae537104 100644
--- a/fast_llm/data/sample/patch.py
+++ b/fast_llm/data/sample/patch.py
@@ -192,6 +192,27 @@ def _expected_buffer_size(self) -> int:
* torch.int32.itemsize
)
+ def get_metadata(self) -> dict[str, typing.Any]:
+ return {
+ "num_documents": self.num_documents,
+ "num_patches": self.num_patches,
+ "num_patch_groups": self.num_patch_groups,
+ "num_pixels": self.patch_size * self.num_patches,
+ "patch_shape": self.patch_shape,
+ "data_type": str(self.data_type),
+ }
+
+ @classmethod
+ def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typing.Any]:
+ return {
+ "num_documents": sum(metadata_["num_documents"] for metadata_ in metadata),
+ "num_patches": sum(metadata_["num_patches"] for metadata_ in metadata),
+ "num_patch_groups": sum(metadata_["num_patch_groups"] for metadata_ in metadata),
+ "num_pixels": sum(metadata_["num_pixels"] for metadata_ in metadata),
+ "patch_shape": get_unique(metadata_["patch_shape"] for metadata_ in metadata),
+ "data_type": get_unique(metadata_["data_type"] for metadata_ in metadata),
+ }
+
class PatchReader[ConfigType: PatchReaderConfig](MemmapReader[ConfigType]):
def __init__(self, config: ConfigType, buffer: memoryview, model_preprocessing: PreprocessingConfig | None = None):
@@ -253,6 +274,19 @@ def get_document(self, index: int, begin: int, end: int) -> Sample:
),
)
+ def get_split(self, begin_index: int, end_index: int) -> dict[str, typing.Any]:
+ Assert.custom(lambda x: x == sorted(x), [0, begin_index, end_index, self._config.num_documents])
+ num_patches = self._patch_count_cumsums[end_index].item() - self._patch_count_cumsums[begin_index].item()
+ return {
+ "num_documents": end_index - begin_index,
+ "num_patches": num_patches,
+ "num_patch_groups": self._group_count_cumsums[end_index].item()
+ - self._group_count_cumsums[begin_index].item(),
+ "num_pixels": self._config.patch_size * num_patches,
+ "patch_shape": self._config.patch_shape,
+ "data_type": str(self._config.data_type),
+ }
+
class EmptyPatchReader[ConfigType: PatchReaderBaseConfig](MemmapReaderBase[ConfigType]):
def get_document(self, index: int, begin: int, end: int) -> Sample:
diff --git a/fast_llm/data/sample/range.py b/fast_llm/data/sample/range.py
index a77846725..53683342a 100644
--- a/fast_llm/data/sample/range.py
+++ b/fast_llm/data/sample/range.py
@@ -92,6 +92,19 @@ def writer_class(self) -> "type[RangeWriter]":
def _expected_buffer_size(self) -> int:
return self.num_ranges * torch.int32.itemsize * 2 + (self.num_documents + 1) * torch.int32.itemsize
+ def get_metadata(self) -> dict[str, typing.Any]:
+ return {
+ "num_documents": self.num_documents,
+ "num_ranges": self.num_ranges,
+ }
+
+ @classmethod
+ def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typing.Any]:
+ return {
+ "num_documents": sum(metadata_["num_documents"] for metadata_ in metadata),
+ "num_ranges": sum(metadata_["num_ranges"] for metadata_ in metadata),
+ }
+
class RangeReader[ConfigType: RangeReaderConfig](MemmapReader[ConfigType]):
def __init__(self, config: ConfigType, buffer: memoryview, model_preprocessing: PreprocessingConfig | None = None):
@@ -116,6 +129,13 @@ def get_document(self, index: int, begin: int, end: int) -> Sample:
)
return RangeSample([(begin_, end_) for begin_, end_ in cropped_ranges if end_ > begin_], sample_size)
+ def get_split(self, begin_index: int, end_index: int) -> dict[str, typing.Any]:
+ Assert.custom(lambda x: x == sorted(x), [0, begin_index, end_index, self._config.num_documents])
+ return {
+ "num_documents": end_index - begin_index,
+ "num_ranges": self._count_cumsums[end_index].item() - self._count_cumsums[begin_index].item(),
+ }
+
class EmptyRangeReader[ConfigType: RangeReaderBaseConfig](MemmapReaderBase[ConfigType]):
def get_document(self, index: int, begin: int, end: int) -> Sample:
diff --git a/fast_llm/data/sample/token.py b/fast_llm/data/sample/token.py
index 1bc9ef1a1..cd4d7fa02 100644
--- a/fast_llm/data/sample/token.py
+++ b/fast_llm/data/sample/token.py
@@ -14,7 +14,7 @@
Sample,
)
from fast_llm.engine.config_utils.data_type import DataType
-from fast_llm.utils import Assert
+from fast_llm.utils import Assert, get_unique
def crop_lengths(lengths: list[int], begin: int, end: int) -> list[int]:
@@ -110,6 +110,21 @@ def writer_class(self) -> "type[TokenWriter]":
def _expected_buffer_size(self) -> int:
return self.num_tokens * self.data_type.torch.itemsize + (self.num_documents + 1) * torch.int64.itemsize
+ def get_metadata(self) -> dict[str, typing.Any]:
+ return {
+ "num_tokens": self.num_tokens,
+ "num_documents": self.num_documents,
+ "data_type": str(self.data_type),
+ }
+
+ @classmethod
+ def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typing.Any]:
+ return {
+ "num_tokens": sum(metadata_["num_tokens"] for metadata_ in metadata),
+ "num_documents": sum(metadata_["num_documents"] for metadata_ in metadata),
+ "data_type": get_unique(metadata_["data_type"] for metadata_ in metadata),
+ }
+
class TokenReader[ConfigType: TokenReaderConfig](MemmapIndexedDatasetReader[ConfigType]):
def __init__(self, config: ConfigType, buffer: memoryview, model_preprocessing: PreprocessingConfig | None = None):
@@ -135,6 +150,28 @@ def get_document_sizes(self) -> torch.Tensor:
def get_document_size(self, index: int) -> int:
return self._size_cumsums[index + 1].item() - self._size_cumsums[index].item()
+ def get_split(self, begin_ratio: float, end_ratio: float) -> tuple[int, int, dict[str, typing.Any]]:
+ Assert.custom(lambda x: x == sorted(x), [0, begin_ratio, end_ratio, 1])
+ begin_index = _get_nearest_split(self._size_cumsums[1:], begin_ratio * self.num_tokens)
+ end_index = _get_nearest_split(self._size_cumsums[1:], end_ratio * self.num_tokens)
+
+ return (
+ begin_index,
+ end_index,
+ {
+ "num_tokens": self._size_cumsums[end_index].item() - self._size_cumsums[begin_index].item(),
+ "num_documents": end_index - begin_index,
+ "data_type": str(self._config.data_type),
+ },
+ )
+
+
+def _get_nearest_split(cumsum: torch.Tensor, value: float) -> int:
+ left = torch.searchsorted(cumsum, value, side="right")
+ if left == len(cumsum):
+ return left.item()
+ return left.item() + 1 if (value - cumsum[left]) / (cumsum[left + 1] - cumsum[left]) > 0.5 else left.item()
+
class TokenWriter(MemmapWriter):
def __enter__(self):
diff --git a/fast_llm/engine/checkpoint/config.py b/fast_llm/engine/checkpoint/config.py
index 3f1970538..98303539e 100644
--- a/fast_llm/engine/checkpoint/config.py
+++ b/fast_llm/engine/checkpoint/config.py
@@ -141,11 +141,12 @@ class CheckpointSaveConfigBase(CheckpointConfigBase):
@config_class()
class CheckpointStateSaveConfigBase(CheckpointSaveConfigBase, CheckpointStateConfigBase):
+ _abstract = False
model_weights: bool = FieldUpdate(desc="Save the model weights.")
optimizer_state: bool = FieldUpdate(desc="Save the optimizer state. Default: save if supported by the `format`.")
def _validate(self) -> None:
- if self.optimizer_state is None:
+ if self.optimizer_state is None and hasattr(self.format, "support_optimizer"):
with self._set_implicit_default():
# TODO: Make sure it's a type
self.optimizer_state = self.format.support_optimizer
diff --git a/fast_llm/engine/checkpoint/distributed.py b/fast_llm/engine/checkpoint/distributed.py
index d953ea35d..fecc35ef7 100644
--- a/fast_llm/engine/checkpoint/distributed.py
+++ b/fast_llm/engine/checkpoint/distributed.py
@@ -123,7 +123,7 @@ def _copy_shard_overlaps(self, loaded_model, loaded_shards, context):
for loaded_stage, loaded_fsdp, loaded_fsdp_shards in loaded_model.split_shards_by_fsdp(loaded_shards):
# Skip tied weight copies to avoid duplicate loads.
# We can't call `loaded_stage.is_tied_weight_copy` because the loaded model isn't setup.
- if not loaded_stage.index not in loaded_model.stages_owned:
+ if loaded_stage.index in loaded_model.stages_owned:
for self_stage, self_fsdp, self_fsdp_shards in self._model.split_shards_by_fsdp(self_shards):
counter = self_fsdp.copy_shard_overlaps(
loaded_fsdp,
diff --git a/fast_llm/engine/checkpoint/safe_load.py b/fast_llm/engine/checkpoint/safe_load.py
index 2e2a01881..d3f72a47c 100644
--- a/fast_llm/engine/checkpoint/safe_load.py
+++ b/fast_llm/engine/checkpoint/safe_load.py
@@ -4,7 +4,7 @@
import torch
from torch.distributed import all_reduce
-from fast_llm.core.distributed import add_ephemeral_timeout
+from fast_llm.core.distributed import set_timeout
from fast_llm.engine.multi_stage.config import ShardName
from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel
from fast_llm.functional.triton.pointwise import triton_fill
@@ -171,8 +171,8 @@ def _check_parameters(self, errors: list[str]) -> None:
if self._distributed.world_group is not None:
counter_tensor = torch.tensor(counters, dtype=torch.int64).to(self._distributed.device)
# This may be the first distributed barrier after loading, so we need to wait for everyone to finish.
- add_ephemeral_timeout(self._distributed.world_group, self._timeout)
- all_reduce(counter_tensor, group=self._distributed.world_group)
+ with set_timeout(self._distributed.world_group, self._timeout):
+ all_reduce(counter_tensor, group=self._distributed.world_group)
counters = counter_tensor.tolist()
# Compare global counts against expected values.
diff --git a/fast_llm/engine/checkpoint/state_dict.py b/fast_llm/engine/checkpoint/state_dict.py
index bbb0fa34b..32eea2db6 100644
--- a/fast_llm/engine/checkpoint/state_dict.py
+++ b/fast_llm/engine/checkpoint/state_dict.py
@@ -15,6 +15,7 @@
CheckpointLoadMetadataConfig,
CheckpointSaveConfig,
CheckpointSaveMetadataConfig,
+ CheckpointStateSaveConfigBase,
FastLLMCheckpointFormat,
export_safetensors_metadata,
)
@@ -72,7 +73,7 @@ def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> No
self._save_serialized_metadata(config, serialized_metadata, index)
def iter_tensors(
- self, config: CheckpointSaveConfig, metadata: "CheckpointMetadata"
+ self, config: CheckpointStateSaveConfigBase, metadata: "CheckpointMetadata"
) -> typing.Iterator[tuple[str, str, torch.Tensor]]:
# The tensor mapping may not be one-to-one. `convert_state_dict` pops all tensors from
# `state_dict` that are ready for conversion,
diff --git a/fast_llm/engine/distributed/config.py b/fast_llm/engine/distributed/config.py
index 8b9c0c13f..c2d6d1405 100644
--- a/fast_llm/engine/distributed/config.py
+++ b/fast_llm/engine/distributed/config.py
@@ -40,7 +40,7 @@
MAX_SEED = 2**64
-class PhaseType(str, enum.Enum):
+class PhaseType(enum.StrEnum):
training = "Training"
validation = "Validation"
test = "Test"
@@ -51,6 +51,21 @@ def is_training(self) -> bool:
return self == PhaseType.training
+class DistributedBackend(enum.StrEnum):
+ nccl = "nccl"
+ gloo = "gloo"
+
+ @property
+ def process_group_class(self) -> type["ProcessGroup"]:
+ import torch
+
+ return (
+ torch.distributed.ProcessGroupNCCL
+ if self == DistributedBackend.nccl
+ else torch.distributed.ProcessGroupGloo
+ )
+
+
@dataclasses.dataclass
class DistributedDim:
"""
@@ -87,7 +102,9 @@ def from_sizes_and_strides(cls, name: str, global_rank: int, *sizes_and_strides:
start = global_rank
rank = 0
world_size = 1
- for size, stride in sizes_and_strides:
+ for i, (size, stride) in enumerate(sizes_and_strides):
+ if i > 0:
+ Assert.multiple(stride, sizes_and_strides[i - 1][1])
rank_ = global_rank // stride % size
start -= rank_ * stride
rank += world_size * rank_
@@ -96,13 +113,13 @@ def from_sizes_and_strides(cls, name: str, global_rank: int, *sizes_and_strides:
for size, stride in sizes_and_strides:
if size == 1:
continue
- if len(global_ranks) == 1 or (
- isinstance(global_ranks, range) and stride == global_ranks.stop - global_ranks.start
- ):
- global_ranks = range(start, start + size * stride, sizes_and_strides[0][0])
+ if len(global_ranks) == 1:
+ global_ranks = range(start, start + size * stride, stride)
+ elif isinstance(global_ranks, range) and stride == global_ranks.stop - global_ranks.start:
+ global_ranks = range(start, start + size * stride, global_ranks.step)
else:
- global_ranks = (rank0 + rank1 for rank1 in range(0, size, stride) for rank0 in global_ranks)
- global_ranks = global_ranks if isinstance(global_ranks, range) else list(global_ranks)
+ global_ranks = [rank0 + rank1 for rank1 in range(0, size * stride, stride) for rank0 in global_ranks]
+ Assert.eq(len(global_ranks), world_size)
return DistributedDim(name=name, size=world_size, rank=rank, global_ranks=global_ranks)
@@ -199,6 +216,11 @@ class DistributedConfig(Config):
desc="Prioritize the pipeline groups for placement of nearby ranks over data groups.",
hint=FieldHint.expert,
)
+ backend: DistributedBackend = Field(
+ default=DistributedBackend.nccl,
+ desc="The distributed backend to use.",
+ hint=FieldHint.expert,
+ )
timeout: float = Field(
default=60,
desc="Timeout for distributed operations.",
@@ -267,8 +289,6 @@ class DistributedConfig(Config):
)
def _validate(self) -> None:
- super()._validate()
-
if self.world_size is None:
self.world_size = self.default_world_size
if self.rank is None:
@@ -348,10 +368,19 @@ def _validate(self) -> None:
self._add_distributed_dim_from_sizes_and_strides(
DistributedDimNames.model_and_sequence_data,
(self.tensor_parallel, tensor_stride),
- (self.sequence_data_parallel, sequence_data_stride),
- (self.pipeline_rank, pipeline_stride),
+ (
+ (self.pipeline_parallel, pipeline_stride)
+ if self.pipeline_first
+ else (self.sequence_data_parallel, sequence_data_stride)
+ ),
+ (
+ (self.sequence_data_parallel, sequence_data_stride)
+ if self.pipeline_first
+ else (self.pipeline_parallel, pipeline_stride)
+ ),
)
+ super()._validate()
if self.reference_config is not None:
self.compare(self.reference_config, ValueError)
Assert.in_range(self.rank, 0, self.world_size)
diff --git a/fast_llm/engine/distributed/distributed.py b/fast_llm/engine/distributed/distributed.py
index 7b95cecfb..d93e17d1c 100644
--- a/fast_llm/engine/distributed/distributed.py
+++ b/fast_llm/engine/distributed/distributed.py
@@ -8,6 +8,7 @@
from fast_llm.core.distributed import ProcessGroup
from fast_llm.engine.distributed.config import (
MAX_SEED,
+ DistributedBackend,
DistributedConfig,
DistributedDim,
DistributedDimNames,
@@ -27,6 +28,8 @@ def __init__(
local_world_size: int | None = None,
timeout: float = 60,
use_cpu: bool = False,
+ init_method: str = "env://",
+ backend: DistributedBackend = DistributedBackend.nccl,
):
self._rank = DistributedConfig.default_rank if rank is None else rank
@@ -36,10 +39,12 @@ def __init__(
)
self._timeout = timeout
self._use_cpu = use_cpu
+ self._backend = backend
self._process_groups = {}
if self._use_cpu:
- Assert.eq(self._world_size, 1)
+ if backend == DistributedBackend.nccl:
+ Assert.eq(self._world_size, 1)
self._device = torch.device("cpu")
else:
Assert.in_range_incl(self._local_world_size, 1, torch.cuda.device_count())
@@ -54,7 +59,7 @@ def __init__(
# TODO: Allow other init methods?
self.store, _, _ = next(
torch.distributed.rendezvous(
- "env://",
+ init_method,
self._rank,
self._world_size,
timeout=datetime.timedelta(seconds=timeout),
@@ -77,6 +82,10 @@ def local_world_size(self):
def device(self):
return self._device
+ @property
+ def backend(self):
+ return self._backend
+
def get_process_group(self, global_ranks: range | tuple, group_rank: int) -> ProcessGroup | None:
"""
Get the requested process group from the pool, or create it if it doesn't exist.
@@ -100,8 +109,7 @@ def get_process_group(self, global_ranks: range | tuple, group_rank: int) -> Pro
if isinstance(global_ranks, range)
else f"ranks_{"_".join(str(rank) for rank in global_ranks)}"
)
-
- group = torch.distributed.ProcessGroupNCCL(
+ group = self._backend.process_group_class(
torch.distributed.PrefixStore(prefix + "/", self.store),
group_rank,
group_size,
@@ -157,6 +165,7 @@ def __init__(self, config: DistributedConfig, use_cpu: bool = False):
self._config.local_world_size,
self._config.timeout,
use_cpu,
+ self._config.backend,
)
else:
self._pool = _default_pool
@@ -164,6 +173,7 @@ def __init__(self, config: DistributedConfig, use_cpu: bool = False):
Assert.eq(self._pool.rank, self._config.rank)
Assert.geq(self._pool.local_world_size, self._config.local_world_size)
Assert.eq(self._pool.device.type, "cpu" if use_cpu else "cuda")
+ Assert.eq(self._pool.backend, self._config.backend)
self.world_group = self.add_group(self._config.distributed_dims[DistributedDimNames.world])
self.data_group = self.add_group(self._config.distributed_dims[DistributedDimNames.data])
@@ -171,16 +181,10 @@ def __init__(self, config: DistributedConfig, use_cpu: bool = False):
self.tensor_group = self.add_group(self._config.distributed_dims[DistributedDimNames.tensor])
self.sequence_data_group = self.add_group(self._config.distributed_dims[DistributedDimNames.sequence_data])
self.batch_data_group = self.add_group(self._config.distributed_dims[DistributedDimNames.batch_data])
-
- # Global ranks wrong with pipeline first, so we hide the dims as a safety check.
- if not self._config.pipeline_first:
- self.tensor_and_sequence_data_group = self.add_group(
- self._config.distributed_dims[DistributedDimNames.tensor_and_sequence_data]
- )
- self.tensor_and_data_group = self.add_group(
- self._config.distributed_dims[DistributedDimNames.tensor_and_data]
- )
-
+ self.tensor_and_sequence_data_group = self.add_group(
+ self._config.distributed_dims[DistributedDimNames.tensor_and_sequence_data]
+ )
+ self.tensor_and_data_group = self.add_group(self._config.distributed_dims[DistributedDimNames.tensor_and_data])
self.model_and_sequence_data_group = self.add_group(
self._config.distributed_dims[DistributedDimNames.model_and_sequence_data]
)
diff --git a/fast_llm/engine/multi_stage/fast_llm_model.py b/fast_llm/engine/multi_stage/fast_llm_model.py
index ed6835140..68fd41af5 100644
--- a/fast_llm/engine/multi_stage/fast_llm_model.py
+++ b/fast_llm/engine/multi_stage/fast_llm_model.py
@@ -5,7 +5,7 @@
from fast_llm.config import UpdateType
from fast_llm.core.distributed import broadcast
-from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, CheckpointSaveConfig
+from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, CheckpointSaveConfig, CheckpointStateSaveConfigBase
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode
from fast_llm.engine.multi_stage.multi_stage import MultiStageModel
@@ -34,7 +34,7 @@ def save_checkpoint(
def iter_checkpoint(
self,
- config: CheckpointSaveConfig,
+ config: CheckpointStateSaveConfigBase,
extra_metadata: dict | None = None,
) -> typing.Iterator[tuple[str, str, torch.Tensor]]:
# TODO: Handle barriers, ok file, mkdir, etc. here
diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py
index 7624c72c4..0b492703c 100644
--- a/fast_llm/engine/training/config.py
+++ b/fast_llm/engine/training/config.py
@@ -7,6 +7,7 @@
from fast_llm.config import (
Config,
+ Configurable,
Field,
FieldHint,
FieldUpdate,
@@ -16,6 +17,7 @@
skip_valid_if_none,
)
from fast_llm.data.data.config import DataConfig
+from fast_llm.data.dataset.config import RedisConfig
from fast_llm.engine.checkpoint.config import (
CheckpointLoadConfig,
CheckpointSaveConfig,
@@ -24,15 +26,17 @@
)
from fast_llm.engine.config_utils.run import ExperimentConfig
from fast_llm.engine.config_utils.runnable import RunnableConfig
+from fast_llm.engine.distributed.config import DistributedBackend
from fast_llm.engine.evaluation.config import EvaluatorConfig, EvaluatorConfigBase
from fast_llm.engine.multi_stage.config import PretrainedFastLLMModelConfig
from fast_llm.engine.optimizer.config import OptimizerConfig
from fast_llm.engine.schedule.config import BatchConfig, ScheduleConfig
from fast_llm.profile import ProfilingConfig
-from fast_llm.redis.config import RedisConfig
from fast_llm.utils import Assert
if typing.TYPE_CHECKING:
+ from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel
+ from fast_llm.engine.training.streaming import StreamingTrainerCallback
from fast_llm.engine.training.trainer import Trainer, TrainingEvaluator
@@ -322,111 +326,63 @@ def _validate(self) -> None:
self.wandb.alert.assert_sub_interval(self.logs)
-@config_class()
-class TrainerEventsRedisConfig(RedisConfig):
- stream_key: str = FieldUpdate(default="fast_llm_events")
-
- payload_key: str = FieldUpdate(default="event")
-
+@config_class(registry=True)
+class TrainerCallbackConfig(Config):
+ def get_callback(self, model: "FastLLMModel") -> "TrainerCallback":
+ raise NotImplementedError()
-@config_class()
-class TrainerEvent(Config):
- enabled: bool = Field(
- default=False,
- desc="Flag indicating whether this event is enabled. If False, the event will be skipped.",
- hint=FieldHint.feature,
- )
+ def setup(self, config: "TrainerConfig") -> None:
+ pass
@config_class()
-class WeightsBroadcastEventConfig(TrainerEvent):
- """
- Event sent to indicate that updated weights are ready for broadcast.
- """
-
- initial_weights_step_message_type: str = Field(
- default="initial_weights_step",
- desc="Message indicating that weights the training starting/ continuing from.",
- hint=FieldHint.feature,
- )
-
- initial_weights_step_message_includes_weights: bool = Field(
- default=False,
- desc=(
- "Whether to include the loaded model weights in the initial event message. "
- "Useful when training restarts from an internal checkpoint format that "
- "which does not have an exported checkpoint for that step."
- ),
- hint=FieldHint.feature,
- )
-
- weights_ready_message_type: str = Field(
- default="weights_ready",
- desc="Message indicating that weights are ready to be broadcast.",
- hint=FieldHint.feature,
- )
-
- # NCCL rendezvous details
- rdvz_master_address: str | None = Field(
- default=None,
+class WeightsBroadcastConfig(Config):
+ # TODO: Have the external model send these instead?
+ host: str = Field(
+ default="localhost",
desc="Master address for the external NCCL process group.",
hint=FieldHint.feature,
)
-
- rdvz_master_port: int | None = Field(
- default=None,
+ port: int = Field(
+ default=23456,
desc="Master port for the external NCCL process group.",
hint=FieldHint.feature,
)
-
- world_size: int | None = Field(
- default=None,
+ external_world_size: int = Field(
+ default=1,
desc="World size of the external NCCL process group.",
hint=FieldHint.feature,
)
-
- rank: int | None = Field(
- default=None,
- desc="Rank of this process in the external NCCL process group.",
- hint=FieldHint.feature,
- )
-
-
-@config_class()
-class TrainingFinishedEventConfig(TrainerEvent):
- """
- Event sent to indicate that training has completed.
- """
-
- training_finished_message_type: str = Field(
- default="training_finished",
- desc="Message indicating that weights the training starting/ continuing from.",
+ backend: DistributedBackend = Field(
+ default=DistributedBackend.nccl,
+ desc="Backend for the external NCCL process group.",
hint=FieldHint.feature,
)
-@config_class()
-class TrainerEventsConfig(Config):
+@config_class(dynamic_type={TrainerCallbackConfig: "streaming"})
+class StreamingTrainerCallbackConfig(TrainerCallbackConfig, RedisConfig):
"""
Aggregates all trainer-side Redis-based event configurations.
"""
- redis: TrainerEventsRedisConfig = Field(
- desc="Redis connection and stream settings used to fetch incoming training data.",
+ broadcast: WeightsBroadcastConfig = Field(
+ desc="Configuration for signaling weight-ready events via Redis.",
hint=FieldHint.core,
)
- weights_broadcast: WeightsBroadcastEventConfig = Field(
- default=None,
- desc="Configuration for signaling weight-ready events via Redis.",
- hint=FieldHint.feature,
+ export: CheckpointStateSaveConfigBase = Field(
+ desc="Configuration for exporting checkpoints before broadcasting them.",
+ hint=FieldHint.core,
)
- training_finished: TrainingFinishedEventConfig = Field(
- default=None,
- desc="Configuration for signaling training-finished events via Redis.",
- hint=FieldHint.feature,
- )
+ def get_callback(self, model: "FastLLMModel") -> "StreamingTrainerCallback":
+ from fast_llm.engine.training.streaming import StreamingTrainerCallback
+
+ return StreamingTrainerCallback(self, model)
+
+ def setup(self, config: "TrainerConfig") -> None:
+ self.export.setup(config.model)
@config_class(registry=True, dynamic_type={RunnableConfig: "train"})
@@ -460,9 +416,9 @@ class TrainerConfig(PretrainedFastLLMModelConfig, ExperimentConfig):
hint=FieldHint.feature,
)
- events: TrainerEventsConfig = Field(
- default=None,
- desc="Optional Trainer event configurations (weight broadcast, training finished, etc.).",
+ callbacks: dict[str, TrainerCallbackConfig] = Field(
+ default_factory=dict,
+ desc="Configuration for training callbacks.",
hint=FieldHint.feature,
)
@@ -470,6 +426,9 @@ def _validate(self) -> None:
self.training.export.setup(self.model)
for reference_model in self.reference_models.values():
self._add_reference_distributed_to_pretrained(reference_model)
+ for callback in self.callbacks.values():
+ # We don't know anything about the callbacks, so we forward `self` and let them handle their own setup.
+ callback.setup(self)
super()._validate()
if self.reference_models:
# TODO: Add support.
@@ -517,3 +476,21 @@ def new_setup():
old_setup()
object.__setattr__(pretrained, "_setup", new_setup)
+
+
+class TrainerCallback[ConfigType: TrainerCallbackConfig](Configurable[ConfigType]):
+ # TODO: Make a more exhaustive set of events and arguments.
+ def run_begin(self, step: int):
+ pass
+
+ def step_end(
+ self,
+ step: int,
+ reduced_losses: dict[str, float | int],
+ update_successful: bool,
+ train_metrics: dict[str, typing.Any] | None,
+ ):
+ pass
+
+ def train_end(self, step: int):
+ pass
diff --git a/fast_llm/engine/training/streaming.py b/fast_llm/engine/training/streaming.py
new file mode 100644
index 000000000..9a8bbc723
--- /dev/null
+++ b/fast_llm/engine/training/streaming.py
@@ -0,0 +1,77 @@
+import json
+import logging
+import typing
+
+import torch.distributed
+
+from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel
+from fast_llm.engine.training.config import StreamingTrainerCallbackConfig, TrainerCallback
+
+logger = logging.getLogger(__name__)
+
+
+REDIS_TRAINING_STREAM = "fast_llm_events"
+REDIS_TRAINING_FIELD = "event"
+
+
+class StreamingTrainerCallback[ConfigType: StreamingTrainerCallbackConfig](TrainerCallback[ConfigType]):
+ def __init__(self, config: ConfigType, model: "FastLLMModel"):
+ super().__init__(config)
+ self._model = model
+ self._do_broadcast = self._model.config.distributed.rank == 0
+ if self._do_broadcast:
+ self._client = self._config.get_client()
+ init_method = f"tcp://{config.broadcast.host}:{config.broadcast.port}"
+ logger.info(f"Waiting for weights broadcast rendezvous at {init_method} ...")
+ # TODO: Create a custom process group instead.
+ self._process_group = torch.distributed.init_process_group(
+ backend="nccl",
+ init_method=init_method,
+ world_size=config.broadcast.external_world_size + 1,
+ rank=0,
+ )
+ logger.info(f"Weights broadcast rendezvous at {init_method} connected")
+
+ def run_begin(self, step: int):
+ # TODO: ====== Send a train / run begin signal? ======
+ self._broadcast_weights(step)
+
+ def step_end(
+ self,
+ step: int,
+ reduced_losses: dict[str, float | int],
+ update_successful: bool,
+ train_metrics: dict[str, typing.Any] | None,
+ ):
+ if update_successful:
+ self._broadcast_weights(step)
+
+ def train_end(self, step: int):
+ # TODO: ====== Send something on unsuccessful ends? ======
+ if self._do_broadcast:
+ self._client.xadd(REDIS_TRAINING_STREAM, {REDIS_TRAINING_FIELD: json.dumps({"type": "training_finished"})})
+ self._clear()
+
+ def __del__(self):
+ self._clear()
+
+ def _clear(self):
+ if hasattr(self, "_process_group"):
+ torch.distributed.destroy_process_group(self._process_group)
+ del self._process_group
+
+ def _broadcast_weights(self, step: int):
+ if self._do_broadcast:
+ self._client.xadd(
+ REDIS_TRAINING_STREAM, {REDIS_TRAINING_FIELD: json.dumps({"type": "weights_ready", "step": step})}
+ )
+ for shard_name, layer_name, tensor in self._model.iter_checkpoint(self._config.export, {}):
+ if self._do_broadcast:
+ # TODO: ====== Broadcast metadata in advance =======
+ meta = [(shard_name, layer_name, tensor.shape, tensor.dtype)]
+ torch.distributed.broadcast_object_list(meta, group=self._process_group, group_src=0)
+ torch.distributed.broadcast(tensor, group=self._process_group, group_src=0)
+ # Broadcast end of weights broadcast
+ if self._do_broadcast:
+ meta = [None]
+ torch.distributed.broadcast_object_list(meta, group=self._process_group, group_src=0)
diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py
index a2f98c05f..5ef625316 100644
--- a/fast_llm/engine/training/trainer.py
+++ b/fast_llm/engine/training/trainer.py
@@ -36,7 +36,6 @@
TrainingCheckpointConfig,
TrainingEvaluatorConfig,
)
-from fast_llm.engine.training.trainer_events import TrainerEvents
from fast_llm.engine.training.wandb import Wandb
from fast_llm.logging import format_metrics, log_memory_usage
from fast_llm.utils import Assert, Interrupter, get_and_reset_memory_usage_mib
@@ -132,8 +131,6 @@ def __init__(self, config: TrainerConfig):
self._is_evaluation_only = config.training.train_iters == 0
- self.trainer_events = TrainerEvents(config.events)
-
self._data = self._get_data()
log_main_rank("Creating model...")
self._multi_stage = self._config.model.get_model_class()(
@@ -154,6 +151,9 @@ def __init__(self, config: TrainerConfig):
distributed_config=self._config.model.distributed,
)
self._loss_definitions = self._multi_stage.base_model.get_loss_definitions()
+ self._callbacks = {
+ name: config.get_callback(self._multi_stage) for name, config in self._config.callbacks.items()
+ }
if not self._is_evaluation_only:
steps_per_split = {
@@ -289,7 +289,8 @@ def run(self) -> None:
assert self._is_setup
with self._wandb:
self._run_training()
- self.trainer_events.send_training_finished()
+ for callback in self._callbacks.values():
+ callback.train_end(self._completed_steps)
def _run_training(self) -> None:
self._prepare_training_state()
@@ -363,9 +364,8 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]:
# TODO: Synchronization is probably unnecessary.
safe_barrier(self._distributed.world_group, "train begin")
- self.trainer_events.send_initial_weights_step(
- self._completed_steps, self._multi_stage, self._config.training.export
- )
+ for callback in self._callbacks.values():
+ callback.run_begin(self._completed_steps)
torch.cuda.synchronize()
start_time = time.perf_counter()
@@ -393,13 +393,12 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]:
advanced_iters += 1
for name, value in reduced_losses.items():
total_losses[name] += value
- self.trainer_events.send_weights(
- self._completed_steps, self._multi_stage, self._config.training.export
- )
else:
skipped_iters += 1
nan_iters += not all(math.isfinite(loss) for loss in reduced_losses.values())
+ for callback in self._callbacks.values():
+ callback.step_end(self._completed_steps, reduced_losses, update_successful, train_metrics)
# Logging.
metrics = {}
if is_logging:
diff --git a/fast_llm/engine/training/trainer_events.py b/fast_llm/engine/training/trainer_events.py
deleted file mode 100644
index 8bce3e6de..000000000
--- a/fast_llm/engine/training/trainer_events.py
+++ /dev/null
@@ -1,105 +0,0 @@
-import logging
-
-import orjson
-import redis
-import torch.distributed
-
-from fast_llm.engine.config_utils.run import is_main_rank
-from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel
-from fast_llm.engine.training.config import TrainerEventsConfig, TrainerEventsRedisConfig, TrainingExportConfig
-
-logger = logging.getLogger(__name__)
-
-
-class RedisEventSender:
- def __init__(self, config: TrainerEventsRedisConfig):
- self.config = config
- self.client = None
-
- if is_main_rank():
- self.client = redis.Redis(
- host=config.host,
- port=config.port,
- )
-
- def send(self, msg_type: str, payload: dict | None = None):
- if not is_main_rank():
- return
-
- if not payload:
- payload = {}
- payload.update({"type": msg_type})
-
- self.client.xadd(self.config.stream_key, {self.config.payload_key: orjson.dumps(payload)})
-
-
-class TrainerEvents:
- """
- Main helper class holding all event channels.
- Each event may have its own RedisConfig.
-
- Usage:
- events = TrainerEvents(cfg.events)
- events.weights_broadcast.send({"step": 100})
- events.training_finished.send()
- """
-
- def __init__(self, config: TrainerEventsConfig):
- self.config = config
-
- if config.weights_broadcast.enabled or config.training_finished.enabled:
- self.sender = RedisEventSender(config.redis)
- else:
- self.sender = None
-
- if config.weights_broadcast.enabled and is_main_rank():
- init_method = (
- f"tcp://{config.weights_broadcast.rdvz_master_address}:{config.weights_broadcast.rdvz_master_port}"
- )
- logger.info(f"Waiting for weights broadcast rendezvous at {init_method} ...")
- self.weights_pg = torch.distributed.init_process_group(
- backend="nccl",
- init_method=init_method,
- world_size=config.weights_broadcast.world_size,
- rank=config.weights_broadcast.rank,
- )
- logger.info(f"Weights broadcast rendezvous at {init_method} connected")
- else:
- self.weights_pg = None
-
- def send_initial_weights_step(self, step: int, model: FastLLMModel, export_config: TrainingExportConfig):
- if self.config.weights_broadcast.enabled:
- self.sender.send(
- msg_type=self.config.weights_broadcast.initial_weights_step_message_type, payload={"step": step}
- )
- if self.config.weights_broadcast.initial_weights_step_message_includes_weights:
- self._broadcast_weights(model, export_config)
-
- def send_weights(self, step: int, model: FastLLMModel, export_config: TrainingExportConfig):
- if self.config.weights_broadcast.enabled:
- self.sender.send(msg_type=self.config.weights_broadcast.weights_ready_message_type, payload={"step": step})
- self._broadcast_weights(model, export_config)
-
- def send_training_finished(self):
- if self.config.training_finished.enabled:
- self.sender.send(msg_type=self.config.training_finished.training_finished_message_type)
-
- if is_main_rank() and self.config.weights_broadcast.enabled:
- torch.distributed.destroy_process_group()
-
- def _broadcast_weights(self, model: FastLLMModel, export_config: TrainingExportConfig):
- for shard_name, layer_name, tensor in model.iter_checkpoint(export_config.get_save_config("", 10), {}):
- if is_main_rank():
- meta = [(shard_name, layer_name, tensor.shape, tensor.dtype)]
- torch.distributed.broadcast_object_list(
- meta, group=self.weights_pg, group_src=self.config.weights_broadcast.rank
- )
- torch.distributed.broadcast(
- tensor, group=self.weights_pg, group_src=self.config.weights_broadcast.rank
- )
- # Broadcast end of weights broadcast
- if is_main_rank():
- meta = [None]
- torch.distributed.broadcast_object_list(
- meta, group=self.weights_pg, group_src=self.config.weights_broadcast.rank
- )
diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py
index 8c9ea9399..a12516b5d 100644
--- a/fast_llm/functional/cross_entropy.py
+++ b/fast_llm/functional/cross_entropy.py
@@ -227,6 +227,7 @@ def distributed_log_softmax(
return logits_norm - sum_exp_logits.log() # log_softmax
+@torch.compile
def _reverse_kl_forward_backward(
logits: torch.Tensor,
target: torch.Tensor,
@@ -259,24 +260,21 @@ def _reverse_kl_forward_backward(
if loss_mask is not None:
Assert.eq(loss_mask.shape, logits.shape[:-1])
- # Compute log probabilities
teacher_log_probs = distributed_log_softmax(target.float(), group=group)
- student_log_probs = distributed_log_softmax(logits, group=group)
-
- # Reverse KL: input=teacher_log_probs, target=student_probs
- loss_terms = torch.nn.functional.kl_div(
- teacher_log_probs, # input = log(p)
- student_log_probs, # target = log(q)
- reduction="none",
- log_target=True,
- ).sum(dim=-1)
+ log_ratio = distributed_log_softmax(logits, group=group)
+
+ student_probs = log_ratio.exp()
+ log_ratio = log_ratio - teacher_log_probs # In-place: log_ratio = student_log_probs - teacher_log_probs
+ del teacher_log_probs
+ # Compute loss terms: student_probs * log_ratio, then sum over vocab
+ # This is equivalent to kl_div(..., log_target=True) but more memory efficient
+ loss_terms = (student_probs * log_ratio).sum(dim=-1)
+
if loss_mask is not None:
# loss mask is the same on all ranks for TP over vocab.
valid = loss_mask.to(loss_terms.dtype)
loss_terms = loss_terms * valid
- valid_tokens = valid.sum()
- else:
- valid_tokens = torch.prod(torch.tensor(loss_terms.shape, device=loss_terms.device, dtype=loss_terms.dtype))
+ valid_tokens = torch.prod(torch.tensor(loss_terms.shape, device=loss_terms.device, dtype=loss_terms.dtype))
loss = loss_terms.sum() # sums over batch and seq. len.
if group is not None:
@@ -284,20 +282,20 @@ def _reverse_kl_forward_backward(
loss /= valid_tokens
if grad_output is not None:
- # need to calculate gradient manually, backprop through all reduce can be problematic, see https://github.com/pytorch/pytorch/issues/58005
- log_ratio = student_log_probs - teacher_log_probs
- expected = torch.sum(torch.exp(student_log_probs) * log_ratio, dim=-1, keepdim=True)
- # expected E_q(log s - log t) -- this is actually dependent on the full vocab!
+ # Gradient: d/d(logits) KL(q||p) = q * (log(q/p) - E_q[log(q/p)])
+ # where E_q[log(q/p)] is the expected log ratio under the student distribution
+ expected = torch.sum(student_probs * log_ratio, dim=-1, keepdim=True)
if group is not None:
all_reduce(expected, op=ReduceOp.SUM, group=group)
- grad_base = torch.exp(student_log_probs) * (log_ratio - expected)
+ log_ratio = log_ratio - expected
+ log_ratio = log_ratio * student_probs
+ del student_probs # Free after use
if loss_mask is not None:
- valid = loss_mask.to(logits.dtype).unsqueeze(-1)
- grad_base = grad_base * valid
+ log_ratio = log_ratio * loss_mask.to(logits.dtype).unsqueeze(-1)
- grad = grad_base.mul(grad_output / valid_tokens)
- grad = grad.to(logits.dtype)
+ log_ratio = log_ratio * (grad_output / valid_tokens)
+ grad = log_ratio.to(logits.dtype)
else:
grad = None
diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py
index 261d54025..b06f69ee5 100644
--- a/fast_llm/layers/block/config.py
+++ b/fast_llm/layers/block/config.py
@@ -45,6 +45,7 @@ class BlockKwargs:
device = "device"
hidden_states = "hidden_states"
output_hidden_states = "output_hidden_states"
+ activation_mask = "activation_mask"
@config_class(registry=True)
diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py
index 148dabd5c..f5abd1f6d 100644
--- a/fast_llm/layers/decoder/block.py
+++ b/fast_llm/layers/decoder/block.py
@@ -136,9 +136,9 @@ def forward(
fw_input = input_
hidden_states = self.norm_1(input_)
self._debug(hidden_states, "norm_1", kwargs.get(BlockKwargs.hidden_dims), kwargs)
- hidden_states, bias = self.mixer(hidden_states, kwargs)
+ hidden_states, bias = self.mixer(hidden_states, kwargs, metrics=metrics)
- hidden_states, bias = self.activation_distillation_loss(hidden_states, bias, kwargs, losses)
+ hidden_states, bias = self.activation_distillation_loss(hidden_states, bias, kwargs, losses, metrics)
with set_generator(generator):
input_ = self._bias_dropout_add(hidden_states, bias, input_)
@@ -154,7 +154,7 @@ def forward(
hidden_states = torch.stack((fw_input, hidden_states), dim=0)
return hidden_states
- def activation_distillation_loss(self, hidden_states, bias, kwargs, losses):
+ def activation_distillation_loss(self, hidden_states, bias, kwargs, losses, metrics):
"""
Maybe apply activation distillation loss and setup backward hooks.
"""
@@ -178,26 +178,91 @@ def activation_distillation_loss(self, hidden_states, bias, kwargs, losses):
# L2 loss
activation_loss_factor = self._config.activation_distillation_factor
# (batch, sequence, hidden) or (sequence, batch, hidden). Take the norm over hidden dim.
- # TODO: handle possible padding?
- local_loss_sum = torch.sum(torch.norm(mixer_output - teacher_tensor, p=2, dim=(2)))
- # mixer_output.shape is (batch, sequence, hidden) or (sequence, batch, hidden)
- # In either case, dims 0 and 1 are batch and sequence
- total_count = mixer_output.shape[0] * mixer_output.shape[1]
+
+ # Handle possible padding by using pre-computed activation mask
+ sequence_first = kwargs.get(BlockKwargs.sequence_first, False)
+ activation_mask = kwargs.get(BlockKwargs.activation_mask)
+
+ if activation_mask is not None:
+ # Use pre-computed activation mask (bool tensor where True = valid token)
+ mask = activation_mask.to(dtype=mixer_output.dtype)
+ if sequence_first:
+ # (batch, sequence) -> (sequence, batch)
+ mask = mask.T
+
+ # Compute masked L2 loss: norm over hidden dim, then apply mask
+ per_token_loss = torch.norm(
+ mixer_output - teacher_tensor, p=2, dim=-1
+ ) # (batch, sequence) or (sequence, batch)
+
+ # Slice mask to match per_token_loss shape (for sequence parallelism)
+ # When sequence_tensor_parallel is enabled, per_token_loss only has local sequence length
+ if mask.shape != per_token_loss.shape:
+ # Calculate the sequence offset for this rank using the hidden_dims parallel rank
+ hidden_dims = kwargs.get(BlockKwargs.hidden_dims)
+ seq_dim_idx = 0 if sequence_first else 1
+ hidden_seq_dim = hidden_dims[seq_dim_idx] if hidden_dims else None
+
+ if hidden_seq_dim and hidden_seq_dim.parallel_dim:
+ # Use the rank from the actual parallel dimension used by hidden states
+ local_seq_length = per_token_loss.shape[0] if sequence_first else per_token_loss.shape[1]
+ seq_offset = hidden_seq_dim.parallel_dim.rank * local_seq_length
+ else:
+ seq_offset = 0
+
+ if sequence_first:
+ # mask: (sequence, batch), per_token_loss: (local_sequence, batch)
+ mask = mask[seq_offset : seq_offset + per_token_loss.shape[0], :]
+ else:
+ # mask: (batch, sequence), per_token_loss: (batch, local_sequence)
+ mask = mask[:, seq_offset : seq_offset + per_token_loss.shape[1]]
+
+ masked_loss = per_token_loss * mask
+ local_loss_sum = torch.sum(masked_loss)
+ total_count = int(mask.sum().item())
+ else:
+ # No activation_mask available, compute loss on all tokens
+ per_token_loss = torch.norm(
+ mixer_output - teacher_tensor, p=2, dim=-1
+ ) # (batch, sequence) or (sequence, batch)
+ local_loss_sum = torch.sum(per_token_loss)
+ # mixer_output.shape is (batch, sequence, hidden) or (sequence, batch, hidden)
+ # In either case, dims 0 and 1 are batch and sequence
+ total_count = mixer_output.shape[0] * mixer_output.shape[1]
# All-reduce across tensor-parallel group if sequence-parallel is enabled
if self._sequence_parallel and self._distributed.tensor_group is not None:
all_reduce(local_loss_sum, group=self._distributed.tensor_group, op=ReduceOp.SUM)
- # Assume all ranks contribute the same count (not the case if padding)
- total_count *= self._distributed.tensor_group.size()
+ if activation_mask is not None:
+ # Different ranks may have different amounts of padding
+ total_count_tensor = torch.tensor(total_count, device=mixer_output.device, dtype=torch.int64)
+ all_reduce(total_count_tensor, group=self._distributed.tensor_group, op=ReduceOp.SUM)
+ total_count = int(total_count_tensor.item())
+ else:
+ # All ranks contribute the same count
+ total_count *= self._distributed.tensor_group.size()
- activation_loss = activation_loss_factor * (local_loss_sum / total_count)
+ activation_loss = local_loss_sum / total_count
+ scaled_activation_loss = activation_loss_factor * activation_loss
# Backward hooks
- hidden_states = AuxiliaryLoss.apply(hidden_states, activation_loss, 1.0)
- bias = AuxiliaryLoss.apply(bias, activation_loss, 1.0) if bias is not None else None
+ hidden_states = AuxiliaryLoss.apply(hidden_states, scaled_activation_loss, 1.0)
+ bias = AuxiliaryLoss.apply(bias, scaled_activation_loss, 1.0) if bias is not None else None
# Logging
if losses is not None and self._activation_distillation_loss_name in losses:
losses[self._activation_distillation_loss_name].append(activation_loss.detach())
+ # Per-layer metrics
+ if metrics is not None:
+ metrics[f"{self.module_name}/activation_distillation_loss"] = activation_loss.detach()
+
+ # If using stochastic mixer, also log per-mixer-type activation distillation loss
+ from fast_llm.layers.decoder.stochastic_mixer import StochasticMixer
+
+ if isinstance(self.mixer, StochasticMixer):
+ selected_mixer = self.mixer._last_selected_mixer
+ metrics[f"{self.module_name}/activation_distillation_loss/{selected_mixer}"] = (
+ activation_loss.detach()
+ )
return hidden_states, bias
def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int:
diff --git a/fast_llm/layers/decoder/stochastic_mixer.py b/fast_llm/layers/decoder/stochastic_mixer.py
index 673c64034..984f34b80 100644
--- a/fast_llm/layers/decoder/stochastic_mixer.py
+++ b/fast_llm/layers/decoder/stochastic_mixer.py
@@ -94,6 +94,10 @@ def __init__(
if hasattr(param, "allow_no_grad"):
param.allow_no_grad = True
+ # Track mixer selection counts for logging actual proportions during training
+ self._mixer_counts_total = {name: 0 for name in self.mixers.keys()}
+ self._last_selected_mixer = None
+
def setup(self, distributed: Distributed) -> None:
"""Setup all mixers with the distributed context."""
super().setup(distributed)
@@ -117,6 +121,24 @@ def _forward(
) -> tuple[torch.Tensor, torch.Tensor | None]:
mixer_name = self._sample_mixer_name(kwargs)
+ if self.training:
+ self._mixer_counts_total[mixer_name] += 1
+ self._last_selected_mixer = mixer_name
+
+ if metrics is not None:
+ # Use module_name as prefix to distinguish between different layer indices
+ metric_prefix = f"{self.module_name}/stochastic"
+
+ # Instantaneous metric: last selected mixer
+ metrics[f"{metric_prefix}/last_selected"] = mixer_name
+
+ # Cumulative metrics: total counts and proportions
+ total_count = sum(self._mixer_counts_total.values())
+ for name, count in self._mixer_counts_total.items():
+ proportion = count / total_count if total_count > 0 else 0.0
+ metrics[f"{metric_prefix}/{name}_count_total"] = count
+ metrics[f"{metric_prefix}/{name}_proportion_total"] = proportion
+
if get_model_debug_level() > 0:
from fast_llm.layers.block.config import BlockKwargs
diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py
index fc8794b5a..93850d24c 100644
--- a/fast_llm/layers/language_model/embedding.py
+++ b/fast_llm/layers/language_model/embedding.py
@@ -106,6 +106,9 @@ def _forward(
if self._sequence_parallel:
embeddings = split(embeddings, group=group, dim=0)
+ if isinstance(group, torch.distributed.ProcessGroupGloo):
+ # Somehow needed is some rare cases to prevent autograd from complaining, ex. in `stp2_pp2s1_bf4`.
+ embeddings = embeddings.clone()
else:
if self._sequence_parallel:
token_ids = split(token_ids, group=group, dim=0)
diff --git a/fast_llm/models/gpt/conversion/apriel2.py b/fast_llm/models/gpt/conversion/apriel2.py
index 4ed588ed5..91e3be508 100644
--- a/fast_llm/models/gpt/conversion/apriel2.py
+++ b/fast_llm/models/gpt/conversion/apriel2.py
@@ -39,8 +39,20 @@ def import_config(cls, config: dict) -> dict:
"head_groups": config["head_groups"],
"head_size": config["head_size"],
"rotary": rotary,
- "add_linear_biases": config["add_linear_biases"],
}
+ # Per-layer bias configuration mirroring Fast-LLM structure
+ # If per-layer configs exist, use them; otherwise fall back to add_linear_biases
+ if "query_layer" in config:
+ result["query_layer"] = config["query_layer"]
+ if "key_layer" in config:
+ result["key_layer"] = config["key_layer"]
+ if "value_layer" in config:
+ result["value_layer"] = config["value_layer"]
+ if "dense_layer" in config:
+ result["dense_layer"] = config["dense_layer"]
+ # add_linear_biases serves as default for layers without explicit config
+ if "add_linear_biases" in config:
+ result["add_linear_biases"] = config["add_linear_biases"]
if "window_size" in config:
result["window_size"] = config["window_size"]
return result
@@ -58,18 +70,37 @@ def export_config(cls, config: AttentionConfig) -> dict:
else:
raise NotImplementedError(f"Unsupported rotary type: {type(config.rotary).__name__}")
- return {
+ result = {
"type": "attention",
"heads": config.heads,
"head_groups": config.head_groups,
"head_size": config.head_size,
- "add_linear_biases": config.add_linear_biases,
"rotary": {
"type": rotary_type,
"theta": config.rotary.theta,
},
"window_size": config.window_size,
}
+ # Export per-layer bias configuration
+ # Only include if explicitly set (not None)
+ if config.query_layer.bias.enabled is not None:
+ result["query_layer"] = {"bias": {"enabled": config.query_layer.bias.enabled}}
+ if config.key_layer.bias.enabled is not None:
+ result["key_layer"] = {"bias": {"enabled": config.key_layer.bias.enabled}}
+ if config.value_layer.bias.enabled is not None:
+ result["value_layer"] = {"bias": {"enabled": config.value_layer.bias.enabled}}
+ if config.dense_layer.bias.enabled is not None:
+ result["dense_layer"] = {"bias": {"enabled": config.dense_layer.bias.enabled}}
+ # add_linear_biases as fallback default
+ result["add_linear_biases"] = config.add_linear_biases
+ return result
+
+ @classmethod
+ def _get_effective_bias(cls, layer_config, default: bool) -> bool:
+ """Get effective bias setting: use layer-specific if set, else default."""
+ if layer_config.bias.enabled is not None:
+ return layer_config.bias.enabled
+ return default
@classmethod
def get_converters(
@@ -79,11 +110,20 @@ def get_converters(
hf_prefix: str,
drop_on_export: bool = False,
) -> list[WeightConverter]:
+ # Determine effective bias for each projection
+ q_bias = cls._get_effective_bias(config.query_layer, config.add_linear_biases)
+ k_bias = cls._get_effective_bias(config.key_layer, config.add_linear_biases)
+ v_bias = cls._get_effective_bias(config.value_layer, config.add_linear_biases)
+ o_bias = cls._get_effective_bias(config.dense_layer, config.add_linear_biases)
+ # For key_value, both k and v must have same bias setting
+ # (they're combined in Fast-LLM's key_value layer)
+ kv_bias = k_bias and v_bias
+
return [
*get_weight_and_bias_converters(
f"{fast_llm_prefix}.query",
f"{hf_prefix}.q_proj",
- config.add_linear_biases,
+ q_bias,
QueryWeightConverter,
config,
drop_on_export=drop_on_export,
@@ -91,7 +131,7 @@ def get_converters(
*get_weight_and_bias_converters(
f"{fast_llm_prefix}.key_value",
(f"{hf_prefix}.k_proj", f"{hf_prefix}.v_proj"),
- config.add_linear_biases,
+ kv_bias,
KeyValueWeightConverter,
config,
drop_on_export=drop_on_export,
@@ -99,7 +139,7 @@ def get_converters(
*get_weight_and_bias_converters(
f"{fast_llm_prefix}.dense",
f"{hf_prefix}.o_proj",
- config.add_linear_biases,
+ o_bias,
drop_on_export=drop_on_export,
),
]
@@ -524,6 +564,12 @@ def import_config(cls, config: dict, block_config: dict) -> dict:
"gated": mlp_config["gated"],
"add_linear_biases": mlp_config["add_linear_biases"],
}
+ # Import per-layer MLP bias settings (layer_1, layer_2)
+ for layer_name in ("layer_1", "layer_2"):
+ if layer_name in mlp_config:
+ layer_cfg = mlp_config[layer_name]
+ if "bias" in layer_cfg:
+ mlp[layer_name] = {"bias": layer_cfg["bias"]}
normalization = block_config["normalization"]
@@ -578,6 +624,11 @@ def export_config(cls, config: DecoderBlockConfig) -> dict:
"gated": config.mlp.gated,
"add_linear_biases": config.mlp.add_linear_biases,
}
+ # Export per-layer MLP bias settings (layer_1, layer_2)
+ if config.mlp.layer_1.bias.enabled is not None:
+ mlp["layer_1"] = {"bias": {"enabled": config.mlp.layer_1.bias.enabled}}
+ if config.mlp.layer_2.bias.enabled is not None:
+ mlp["layer_2"] = {"bias": {"enabled": config.mlp.layer_2.bias.enabled}}
normalization = {"type": norm_type_str, "epsilon": config.normalization.epsilon}
@@ -624,24 +675,56 @@ def get_converters(
)
)
- converters.extend(
- [
- *get_weight_and_bias_converters(
- f"{fast_llm_prefix}.mlp.layer_1",
- (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"),
- config.mlp.add_linear_biases,
- SplitWeightConverter,
- drop_on_export=drop_on_export,
- ),
- *get_weight_and_bias_converters(
- f"{fast_llm_prefix}.mlp.layer_2",
- f"{hf_prefix}.mlp.down_proj",
- config.mlp.add_linear_biases,
- MLPLayer2Converter,
- drop_on_export=drop_on_export,
- ),
- ]
- )
+ # Per-layer MLP bias: use layer-specific setting if set, else default
+ def get_mlp_layer_bias(layer_config, default: bool) -> bool:
+ if layer_config.bias.enabled is not None:
+ return layer_config.bias.enabled
+ return default
+
+ layer_1_bias = get_mlp_layer_bias(config.mlp.layer_1, config.mlp.add_linear_biases)
+ layer_2_bias = get_mlp_layer_bias(config.mlp.layer_2, config.mlp.add_linear_biases)
+
+ if config.mlp.gated:
+ # Gated MLP: gate_proj + up_proj -> layer_1 (split), down_proj -> layer_2
+ converters.extend(
+ [
+ *get_weight_and_bias_converters(
+ f"{fast_llm_prefix}.mlp.layer_1",
+ (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"),
+ layer_1_bias,
+ SplitWeightConverter,
+ drop_on_export=drop_on_export,
+ ),
+ *get_weight_and_bias_converters(
+ f"{fast_llm_prefix}.mlp.layer_2",
+ f"{hf_prefix}.mlp.down_proj",
+ layer_2_bias,
+ MLPLayer2Converter,
+ drop_on_export=drop_on_export,
+ ),
+ ]
+ )
+ else:
+ # Non-gated MLP: up_proj -> layer_1, down_proj -> layer_2
+ # Note: layer_2 still needs MLPLayer2Converter for the transpose
+ converters.extend(
+ [
+ *get_weight_and_bias_converters(
+ f"{fast_llm_prefix}.mlp.layer_1",
+ f"{hf_prefix}.mlp.up_proj",
+ layer_1_bias,
+ WeightConverter,
+ drop_on_export=drop_on_export,
+ ),
+ *get_weight_and_bias_converters(
+ f"{fast_llm_prefix}.mlp.layer_2",
+ f"{hf_prefix}.mlp.down_proj",
+ layer_2_bias,
+ MLPLayer2Converter,
+ drop_on_export=drop_on_export,
+ ),
+ ]
+ )
converters.extend(
[
diff --git a/fast_llm/models/gpt/conversion/qwen2.py b/fast_llm/models/gpt/conversion/qwen2.py
index 57c9614bd..4ebf18c3a 100644
--- a/fast_llm/models/gpt/conversion/qwen2.py
+++ b/fast_llm/models/gpt/conversion/qwen2.py
@@ -1,10 +1,12 @@
import typing
from fast_llm.engine.checkpoint.config import CheckpointFormat
+from fast_llm.engine.checkpoint.external import WeightConverter
from fast_llm.layers.attention.config import AttentionConfig
from fast_llm.layers.decoder.mlp.config import MLPConfig
from fast_llm.models.gpt.conversion.config import Qwen2CheckpointFormat
from fast_llm.models.gpt.conversion.llama import (
+ KeyValueWeightConverter,
LlamaAttentionConverter,
LlamaBaseModelConverter,
LlamaBlockConverter,
@@ -12,6 +14,8 @@
LlamaHeadConverter,
LlamaHuggingfaceCheckpointHandler,
LlamaMLPConverter,
+ QueryWeightConverter,
+ get_weight_and_bias_converters,
)
from fast_llm.utils import Assert
@@ -50,6 +54,39 @@ def _check_config(cls, config: AttentionConfig) -> None:
Assert.is_(config.value_layer.bias.enabled, True)
Assert.incl(config.dense_layer.bias.enabled, (None, False))
+ @classmethod
+ def get_converters(
+ cls,
+ config: AttentionConfig,
+ fast_llm_prefix: str,
+ hf_prefix: str,
+ drop_on_export: bool = False,
+ ) -> list[WeightConverter]:
+ return [
+ *get_weight_and_bias_converters(
+ f"{fast_llm_prefix}.query",
+ f"{hf_prefix}.q_proj",
+ True,
+ QueryWeightConverter,
+ config,
+ drop_on_export=drop_on_export,
+ ),
+ *get_weight_and_bias_converters(
+ f"{fast_llm_prefix}.key_value",
+ (f"{hf_prefix}.k_proj", f"{hf_prefix}.v_proj"),
+ True,
+ KeyValueWeightConverter,
+ config,
+ drop_on_export=drop_on_export,
+ ),
+ *get_weight_and_bias_converters(
+ f"{fast_llm_prefix}.dense",
+ f"{hf_prefix}.o_proj",
+ False,
+ drop_on_export=drop_on_export,
+ ),
+ ]
+
class Qwen2MLPConverter(LlamaMLPConverter):
@classmethod
diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py
index 64e7f1cbd..2f43d1e41 100644
--- a/fast_llm/models/gpt/model.py
+++ b/fast_llm/models/gpt/model.py
@@ -217,12 +217,37 @@ def preprocess_batch(
pasts = presents
presents = None if i == len(preprocessed_meta) - 1 else []
+ # Create activation mask for activation distillation
+ # This mask should:
+ # - Be 0 on padding tokens (added at the end when documents aren't truncated)
+ # - Be 1 on image placeholder tokens (token value -100 but not padding)
+ # - Be 1 on all other valid tokens (ignores loss-masking-spans)
+ #
+ # Note: Padding is added as a separate document with all tokens = -100
+ # We detect padding by checking if all tokens in a document segment are -100
+ activation_mask = torch.ones_like(cropped_tokens.tokens, dtype=torch.bool)
+
+ for sample_index, sample_lengths in enumerate(cropped_tokens.lengths):
+ # Iterate through documents in this sample
+ pos = 0
+ for doc_length in sample_lengths:
+ # Check if this document is padding (all tokens are -100)
+ doc_tokens = cropped_tokens.tokens[sample_index, pos : pos + doc_length]
+ is_padding_doc = torch.all(doc_tokens == -100).item()
+
+ if is_padding_doc:
+ # This is a padding document, mask it out
+ activation_mask[sample_index, pos : pos + doc_length] = False
+
+ pos += doc_length
+
kwargs: dict[str, typing.Any] = {
**kwargs_meta,
AttentionKwargs.past_key_values: pasts,
AttentionKwargs.presents: presents,
BlockKwargs.iteration: iteration,
AttentionKwargs.sequence_lengths: cropped_tokens.lengths,
+ BlockKwargs.activation_mask: activation_mask,
AttentionKwargs.device: self._distributed.device,
BlockKwargs.hidden_states: {},
**reference_logits[i],
diff --git a/fast_llm/models/multimodal/conversion/apriel2.py b/fast_llm/models/multimodal/conversion/apriel2.py
index b4147a8bf..307a67c63 100644
--- a/fast_llm/models/multimodal/conversion/apriel2.py
+++ b/fast_llm/models/multimodal/conversion/apriel2.py
@@ -326,9 +326,7 @@ class Apriel2MultimodalBaseModelConverter:
@classmethod
def import_config(cls, config: dict) -> dict:
text_config = Apriel2BaseModelConverter.import_config(config)
- vision_config = (
- cls.vision_model_converter_class.import_config(config) if "vision_encoder" in config else None
- )
+ vision_config = cls.vision_model_converter_class.import_config(config) if "vision_encoder" in config else None
result = safe_merge_dicts(
text_config,
@@ -388,10 +386,7 @@ def get_transformers_configuration_class(cls):
@classmethod
def get_model_files(cls) -> tuple[str, str, str | None]:
- from fast_llm_external_models.apriel2 import (
- configuration_apriel2,
- modeling_apriel2,
- )
+ from fast_llm_external_models.apriel2 import configuration_apriel2, modeling_apriel2
return configuration_apriel2.__file__, modeling_apriel2.__file__, None
diff --git a/fast_llm/redis/config.py b/fast_llm/redis/config.py
index 5b6bfbddd..e69de29bb 100644
--- a/fast_llm/redis/config.py
+++ b/fast_llm/redis/config.py
@@ -1,28 +0,0 @@
-from fast_llm.config import Config, Field, FieldHint, config_class
-
-
-@config_class()
-class RedisConfig(Config):
- host: str = Field(
- default="localhost",
- desc="Hostname or IP address of the Redis server.",
- hint=FieldHint.core,
- )
-
- port: int = Field(
- default=6379,
- desc="Port number on which the Redis server is running.",
- hint=FieldHint.core,
- )
-
- stream_key: str = Field(
- default=None,
- desc="Name of the Redis stream to read data from.",
- hint=FieldHint.core,
- )
-
- payload_key: str = Field(
- default=None,
- desc="Key under which the message data is stored inside the Redis payload dict.",
- hint=FieldHint.core,
- )
diff --git a/fast_llm/utils.py b/fast_llm/utils.py
index 2ca61aa0e..fa4ea4c2f 100644
--- a/fast_llm/utils.py
+++ b/fast_llm/utils.py
@@ -167,7 +167,7 @@ def rms_close_relative(x, y, threshold, min_threshold=0, *, msg=None):
)
@staticmethod
- def all_equal(x, *args):
+ def all_equal(x, *args, msg=None):
import torch
# Make it work for lists and numpy arrays.
@@ -181,7 +181,9 @@ def all_equal(x, *args):
index = None if x.numel() == 1 else torch.where(neq) # noqa
raise AssertionError(
f"Tensors have {index[0].numel()} different entries out of "
- f"{x.numel()}: {x[index]} != {arg[index]} at index {torch.stack(index, -1)}"
+ f"{x.numel()}: {x[index]} != {arg[index]} at index {torch.stack(index, -1)}" + ""
+ if msg is None
+ else f"| {msg}"
)
@staticmethod
diff --git a/fast_llm_external_models/apriel2/cache.py b/fast_llm_external_models/apriel2/cache.py
index 86c67a085..f83ae87d6 100644
--- a/fast_llm_external_models/apriel2/cache.py
+++ b/fast_llm_external_models/apriel2/cache.py
@@ -1,17 +1,22 @@
from __future__ import annotations
+
import torch
from transformers.cache_utils import Cache
class _AttentionCache:
- __slots__ = ["key", "value", "window"]
+ __slots__ = ["key", "value", "window", "cumulative_length"]
def __init__(self, window=None):
self.key = None
self.value = None
self.window = window
+ self.cumulative_length = 0
def update(self, key, value):
+ new_tokens = key.shape[-2]
+ self.cumulative_length += new_tokens
+
if self.key is None:
if self.window and key.shape[-2] > self.window:
self.key = key[..., -self.window :, :].contiguous()
@@ -35,6 +40,40 @@ def _window(self, cache, new):
return cache
return torch.cat([cache, new], -2)[..., -self.window :, :].contiguous()
+ def reset(self):
+ self.key = None
+ self.value = None
+ self.cumulative_length = 0
+
+ def reorder(self, beam_idx):
+ if self.key is not None:
+ self.key = self.key.index_select(0, beam_idx.to(self.key.device))
+ self.value = self.value.index_select(0, beam_idx.to(self.value.device))
+
+ def crop(self, max_length):
+ if self.key is not None:
+ self.key = self.key[..., :max_length, :]
+ self.value = self.value[..., :max_length, :]
+ self.cumulative_length = self.key.shape[-2]
+
+ def batch_repeat(self, repeats):
+ if self.key is not None:
+ self.key = self.key.repeat_interleave(repeats, dim=0)
+ self.value = self.value.repeat_interleave(repeats, dim=0)
+
+ def batch_select(self, indices):
+ if self.key is not None:
+ self.key = self.key.index_select(0, indices.to(self.key.device))
+ self.value = self.value.index_select(0, indices.to(self.value.device))
+
+ @property
+ def is_initialized(self):
+ return self.key is not None
+
+ @property
+ def batch_size(self):
+ return self.key.shape[0] if self.key is not None else None
+
class _SSMCache:
__slots__ = ["conv", "recurrent"]
@@ -43,6 +82,52 @@ def __init__(self):
self.conv = None
self.recurrent = None
+ def reset(self):
+ self.conv = None
+ self.recurrent = None
+
+ def reorder(self, beam_idx):
+ if self.conv is not None:
+ if isinstance(self.conv, tuple):
+ self.conv = tuple(c.index_select(0, beam_idx.to(c.device)) for c in self.conv)
+ else:
+ self.conv = self.conv.index_select(0, beam_idx.to(self.conv.device))
+ if self.recurrent is not None:
+ self.recurrent = self.recurrent.index_select(0, beam_idx.to(self.recurrent.device))
+
+ def crop(self, max_length):
+ pass # SSM caches don't have sequence dimension to crop
+
+ def batch_repeat(self, repeats):
+ if self.conv is not None:
+ if isinstance(self.conv, tuple):
+ self.conv = tuple(c.repeat_interleave(repeats, dim=0) for c in self.conv)
+ else:
+ self.conv = self.conv.repeat_interleave(repeats, dim=0)
+ if self.recurrent is not None:
+ self.recurrent = self.recurrent.repeat_interleave(repeats, dim=0)
+
+ def batch_select(self, indices):
+ if self.conv is not None:
+ if isinstance(self.conv, tuple):
+ self.conv = tuple(c.index_select(0, indices.to(c.device)) for c in self.conv)
+ else:
+ self.conv = self.conv.index_select(0, indices.to(self.conv.device))
+ if self.recurrent is not None:
+ self.recurrent = self.recurrent.index_select(0, indices.to(self.recurrent.device))
+
+ @property
+ def is_initialized(self):
+ return self.conv is not None
+
+ @property
+ def batch_size(self):
+ if self.conv is None:
+ return None
+ if isinstance(self.conv, tuple):
+ return self.conv[0].shape[0]
+ return self.conv.shape[0]
+
class _DummyCacheLayer:
pass
@@ -93,14 +178,19 @@ def set_active_mixer(self, layer_idx, mixer_name):
self.active_mixers[layer_idx] = mixer_name
def get_seq_length(self, layer_idx=0):
+ """Returns the cumulative sequence length of tokens seen by the cache.
+
+ For sliding window caches, this returns the total tokens seen (not just cached).
+ This matches HuggingFace's DynamicSlidingWindowLayer behavior.
+ """
layer = self.layers[layer_idx]
if isinstance(layer, dict):
mixer = self.active_mixers[layer_idx]
if mixer and isinstance(layer[mixer], _AttentionCache):
- return layer[mixer].key.shape[-2] if layer[mixer].key is not None else 0
+ return layer[mixer].cumulative_length
return 0
if isinstance(layer, _AttentionCache):
- return layer.key.shape[-2] if layer.key is not None else 0
+ return layer.cumulative_length
return 0
def get_max_cache_shape(self, layer_idx=0):
@@ -114,22 +204,61 @@ def get_max_cache_shape(self, layer_idx=0):
return None
def get_mask_sizes(self, cache_position, layer_idx):
+ """Return the length and offset of the cache, used to generate the attention mask.
+
+ For standard (non-sliding) attention:
+ kv_offset = 0 (KV[0] corresponds to sequence position 0)
+ kv_length = cumulative_length + query_length
+
+ For sliding window attention:
+ kv_offset = max(cumulative_length - window + 1, 0)
+ kv_length = min(cumulative_length, window - 1) + query_length
+
+ For SSM/linear layers:
+ kv_offset = 0, kv_length = query_length (no KV cache to attend to)
+ """
query_length = cache_position.shape[0]
- past_seen_tokens = self.get_seq_length(layer_idx)
- kv_length = query_length + past_seen_tokens
- kv_offset = past_seen_tokens
- return kv_length, kv_offset
+ layer = self.layers[layer_idx]
+
+ # Handle stochastic layers by getting the active mixer's cache
+ if isinstance(layer, dict):
+ mixer = self.active_mixers[layer_idx]
+ if mixer is None:
+ # No active mixer set, return defaults
+ return query_length, 0
+ cache = layer[mixer]
+ else:
+ cache = layer
+
+ # SSM layers don't have KV cache for attention mask purposes
+ if isinstance(cache, _SSMCache):
+ return query_length, 0
+
+ # Attention cache - check if sliding window
+ if isinstance(cache, _AttentionCache):
+ cumulative = cache.cumulative_length
+ window = cache.window
+
+ if window is not None:
+ # Sliding window attention
+ kv_offset = max(cumulative - window + 1, 0)
+ if cumulative >= window:
+ kv_length = window - 1 + query_length
+ else:
+ kv_length = cumulative + query_length
+ else:
+ # Full attention
+ kv_offset = 0
+ kv_length = cumulative + query_length
+
+ return kv_length, kv_offset
+
+ # Fallback
+ return query_length, 0
@property
def has_previous_state(self):
- for layer in self.layers:
- if isinstance(layer, dict):
- for cache in layer.values():
- if isinstance(cache, _SSMCache) and cache.conv is not None:
- return True
- elif isinstance(layer, _SSMCache) and layer.conv is not None:
- return True
- return False
+ return any(isinstance(cache, _SSMCache) and cache.conv is not None for cache in self._iter_caches())
@property
def key_cache(self):
@@ -147,101 +276,33 @@ def conv_states(self):
def recurrent_states(self):
return _LayerListAccessor(self, "recurrent")
- def reorder_cache(self, beam_idx):
- for i, layer in enumerate(self.layers):
+ def _iter_caches(self):
+ """Iterate over all leaf cache objects (flattening stochastic layer dicts)."""
+ for layer in self.layers:
if isinstance(layer, dict):
- for cache in layer.values():
- self._reorder_cache_obj(cache, beam_idx)
+ yield from layer.values()
else:
- self._reorder_cache_obj(layer, beam_idx)
+ yield layer
- def _reorder_cache_obj(self, cache, beam_idx):
- if isinstance(cache, _AttentionCache):
- if cache.key is not None:
- cache.key = cache.key.index_select(0, beam_idx.to(cache.key.device))
- cache.value = cache.value.index_select(0, beam_idx.to(cache.value.device))
- elif isinstance(cache, _SSMCache):
- if cache.conv is not None:
- # Handle both single tensor (GDN/Mamba) and tuple (KDA) conv states
- if isinstance(cache.conv, tuple):
- cache.conv = tuple(c.index_select(0, beam_idx.to(c.device)) for c in cache.conv)
- else:
- cache.conv = cache.conv.index_select(0, beam_idx.to(cache.conv.device))
- if cache.recurrent is not None:
- cache.recurrent = cache.recurrent.index_select(0, beam_idx.to(cache.recurrent.device))
+ def reorder_cache(self, beam_idx):
+ for cache in self._iter_caches():
+ cache.reorder(beam_idx)
def reset(self):
- for layer in self.layers:
- if isinstance(layer, dict):
- for cache in layer.values():
- self._reset_cache_obj(cache)
- else:
- self._reset_cache_obj(layer)
-
- def _reset_cache_obj(self, cache):
- if isinstance(cache, _AttentionCache):
- cache.key = None
- cache.value = None
- elif isinstance(cache, _SSMCache):
- cache.conv = None
- cache.recurrent = None
+ for cache in self._iter_caches():
+ cache.reset()
def crop(self, max_length):
- for layer in self.layers:
- if isinstance(layer, dict):
- for cache in layer.values():
- if isinstance(cache, _AttentionCache) and cache.key is not None:
- cache.key = cache.key[..., :max_length, :]
- cache.value = cache.value[..., :max_length, :]
- elif isinstance(layer, _AttentionCache) and layer.key is not None:
- layer.key = layer.key[..., :max_length, :]
- layer.value = layer.value[..., :max_length, :]
+ for cache in self._iter_caches():
+ cache.crop(max_length)
def batch_repeat_interleave(self, repeats):
- for layer in self.layers:
- if isinstance(layer, dict):
- for cache in layer.values():
- self._batch_repeat_cache_obj(cache, repeats)
- else:
- self._batch_repeat_cache_obj(layer, repeats)
-
- def _batch_repeat_cache_obj(self, cache, repeats):
- if isinstance(cache, _AttentionCache):
- if cache.key is not None:
- cache.key = cache.key.repeat_interleave(repeats, dim=0)
- cache.value = cache.value.repeat_interleave(repeats, dim=0)
- elif isinstance(cache, _SSMCache):
- if cache.conv is not None:
- # Handle both single tensor (GDN/Mamba) and tuple (KDA) conv states
- if isinstance(cache.conv, tuple):
- cache.conv = tuple(c.repeat_interleave(repeats, dim=0) for c in cache.conv)
- else:
- cache.conv = cache.conv.repeat_interleave(repeats, dim=0)
- if cache.recurrent is not None:
- cache.recurrent = cache.recurrent.repeat_interleave(repeats, dim=0)
+ for cache in self._iter_caches():
+ cache.batch_repeat(repeats)
def batch_select_indices(self, indices):
- for layer in self.layers:
- if isinstance(layer, dict):
- for cache in layer.values():
- self._batch_select_cache_obj(cache, indices)
- else:
- self._batch_select_cache_obj(layer, indices)
-
- def _batch_select_cache_obj(self, cache, indices):
- if isinstance(cache, _AttentionCache):
- if cache.key is not None:
- cache.key = cache.key.index_select(0, indices.to(cache.key.device))
- cache.value = cache.value.index_select(0, indices.to(cache.value.device))
- elif isinstance(cache, _SSMCache):
- if cache.conv is not None:
- # Handle both single tensor (GDN/Mamba) and tuple (KDA) conv states
- if isinstance(cache.conv, tuple):
- cache.conv = tuple(c.index_select(0, indices.to(c.device)) for c in cache.conv)
- else:
- cache.conv = cache.conv.index_select(0, indices.to(cache.conv.device))
- if cache.recurrent is not None:
- cache.recurrent = cache.recurrent.index_select(0, indices.to(cache.recurrent.device))
+ for cache in self._iter_caches():
+ cache.batch_select(indices)
@property
def is_compileable(self):
@@ -249,19 +310,7 @@ def is_compileable(self):
@property
def is_initialized(self):
- for layer in self.layers:
- if isinstance(layer, dict):
- for cache in layer.values():
- if isinstance(cache, _AttentionCache) and cache.key is not None:
- return True
- if isinstance(cache, _SSMCache) and cache.conv is not None:
- return True
- else:
- if isinstance(layer, _AttentionCache) and layer.key is not None:
- return True
- if isinstance(layer, _SSMCache) and layer.conv is not None:
- return True
- return False
+ return any(cache.is_initialized for cache in self._iter_caches())
@property
def is_sliding(self):
@@ -280,39 +329,20 @@ def is_sliding(self):
@property
def max_batch_size(self):
- for layer in self.layers:
- if isinstance(layer, dict):
- for cache in layer.values():
- if isinstance(cache, _AttentionCache) and cache.key is not None:
- return cache.key.shape[0]
- if isinstance(cache, _SSMCache) and cache.conv is not None:
- # Handle both single tensor and tuple conv states
- if isinstance(cache.conv, tuple):
- return cache.conv[0].shape[0]
- return cache.conv.shape[0]
- else:
- if isinstance(layer, _AttentionCache) and layer.key is not None:
- return layer.key.shape[0]
- if isinstance(layer, _SSMCache) and layer.conv is not None:
- # Handle both single tensor and tuple conv states
- if isinstance(layer.conv, tuple):
- return layer.conv[0].shape[0]
- return layer.conv.shape[0]
+ for cache in self._iter_caches():
+ bs = cache.batch_size
+ if bs is not None:
+ return bs
return None
@property
def max_cache_len(self):
- max_len = None
- for layer in self.layers:
- if isinstance(layer, dict):
- for cache in layer.values():
- if isinstance(cache, _AttentionCache):
- if cache.window is not None:
- max_len = cache.window if max_len is None else min(max_len, cache.window)
- elif isinstance(layer, _AttentionCache):
- if layer.window is not None:
- max_len = layer.window if max_len is None else min(max_len, layer.window)
- return max_len
+ windows = [
+ cache.window
+ for cache in self._iter_caches()
+ if isinstance(cache, _AttentionCache) and cache.window is not None
+ ]
+ return min(windows) if windows else None
def __len__(self):
return len(self.layers)
diff --git a/fast_llm_external_models/apriel2/conversion/__init__.py b/fast_llm_external_models/apriel2/conversion/__init__.py
index 983a632e0..2c28d1e87 100644
--- a/fast_llm_external_models/apriel2/conversion/__init__.py
+++ b/fast_llm_external_models/apriel2/conversion/__init__.py
@@ -1,88 +1,138 @@
"""Weight conversion system for Apriel2 models.
-Architecture Overview
-=====================
+Overview
+========
-This package implements a declarative weight transformation system with two
-orthogonal concerns:
+This package implements a declarative weight transformation system. The core
+abstraction separates config composition (structural) from plan execution (weights).
-1. **Config Composition** - Structural transformations of model configs
-2. **Plan Building & Execution** - Weight transformations between configs
+Conceptual Types
+================
-These concerns are intentionally separated:
-- Config composition determines WHAT the target architecture looks like
-- Plan building determines HOW weights are transformed to match
-- The `init` field bridges them: it's config metadata consumed by the plan builder
+All configs are ``dict``, but we distinguish three conceptual types:
-Key Design Decisions
-====================
+**State (S)** - A complete model config without ``init`` fields.
+ What you load from disk or save after conversion.
-**Declarative Plans**
- Plans are DATA (JSON-serializable expressions), not functions. This enables:
- - Inspection and debugging of transformations
- - Serialization for distributed execution
- - Composition via substitution rather than function composition
-
-**Separation of Config and Weights**
- The `init` field in surgery specs controls weight handling (transfer vs random)
- but does NOT affect config composition. Config composition is purely structural.
- After composition, `init` fields are stripped from complete configs.
-
-**Composition Semantics**
- Surgery specs use declarative (merge) composition, not operational (function)
- composition. For "additive" surgeries (modifying existing structure), the
- monoid action law holds. For "replacement" surgeries (defining complete new
- structure), sequential application differs from composed application by design.
-
-**Cross-Type Derivation**
- When converting between mixer types (e.g., attention → mamba), geometric
- parameters are derived where possible:
- - attention.heads → mamba dimensions (MIL conversion)
- - attention.heads → gdn heads (DIL conversion)
+**Partial Surgery (P)** - An incomplete config specifying changes.
+ May contain ``init`` fields (``transfer`` or ``random``).
-Module Structure
-================
+**Transition Spec (T)** - A complete config WITH ``init`` fields.
+ The result of applying surgery to a state. Describes both target
+ structure and weight initialization mode.
+
+Algebraic Structure
+===================
+
+**Monoid**: Partial surgeries compose via deep merge::
+
+ compose_configs : P × P → P
+
+**Action**: Surgeries act on states to produce transition specs::
+
+ compose_configs : S × P → T
+ compose_configs : T × P → T
+
+**Extraction**: Strip init to get a state::
+
+ strip_init_fields : T → S
+
+**Planning**: Build weight transformation from source state + transition spec::
+
+ plan_surgery : S × T → Plan
-- `config.py` - Config composition (compose_configs, apply_surgery)
-- `converters.py` - Plan builders (plan_surgery, plan_mil_attention_to_mamba, etc.)
-- `expr.py` - Expression types and plan class (Ref, Slice, Concat, Init, ExprPlan)
-- `executor.py` - Plan execution (StreamingExecutor, execute)
-- `io.py` - Streaming I/O (SafetensorLoader, ShardedSafetensorWriter)
-- `llava/` - Source-specific converter for Llava → Apriel2
+The ``init`` Field
+==================
-Example Usage
+The ``init`` field in surgeries specifies weight initialization:
+
+- ``init: transfer`` → transfer/convert weights from source
+- ``init: random`` → randomly initialize weights
+
+This field is preserved through ``compose_configs`` so ``plan_surgery`` can read it.
+Use ``strip_init_fields`` before saving configs to disk.
+
+Typical Usage
=============
+::
+
from fast_llm_external_models.apriel2.conversion import (
compose_configs,
plan_surgery,
+ strip_init_fields,
execute,
)
- # 1. Compose configs to get target architecture
- target_config = compose_configs(source_config, surgery_spec)
+ # Load source state
+ source_state = load_config(...) # S
- # 2. Build plan for weight transformation
- plan = plan_surgery(source_config, surgery_spec)
+ # Apply surgery
+ surgery = {"decoder": {"block": {"mixer": {"type": "gdn", "init": "random"}}}} # P
+ transition = compose_configs(source_state, surgery) # T
- # 3. Execute plan to transform weights
- target_weights = execute(plan, source_weights, seed=42)
+ # Build and execute plan
+ plan = plan_surgery(source_state, transition)
+ weights = execute(plan, source_weights, seed=42)
-For streaming I/O with large models:
+ # Save (strip init first)
+ target_state = strip_init_fields(transition) # S
+ save_config(target_state)
- from fast_llm_external_models.apriel2.conversion import (
- StreamingExecutor,
- SafetensorLoader,
- ShardedSafetensorWriter,
- )
+For chained surgeries::
+
+ current_state = source_state # S
+ current_plan = identity_plan
+
+ for surgery in surgery_chain: # each P
+ transition = compose_configs(current_state, surgery) # T
+ plan = plan_surgery(current_state, transition)
+ current_plan = compose(current_plan, plan)
+ current_state = strip_init_fields(transition) # S <- IMPORTANT!
+
+**Note**: The ``strip_init_fields`` call is critical. It ensures that ``init: random``
+applies only to the surgery that introduces a component. Without stripping, subsequent
+surgeries would re-randomize existing components. See ``config.py`` docstring for details.
+
+Key Design Decisions
+====================
+
+**Declarative Plans**
+ Plans are data (expressions), not functions. Enables inspection,
+ serialization, and composition via substitution.
+
+**Inheritance Semantics**
+ When S × P → T, unspecified fields inherit from source.
+ Cross-type derivation maps geometry (attention.heads → gdn.value_heads).
- with SafetensorLoader(source_files) as loader:
- executor = StreamingExecutor(plan, loader)
- with ShardedSafetensorWriter(output_dir) as writer:
- for key, tensor in executor.execute(seed=42):
- writer.add(key, tensor)
+**Additive vs Replacement Surgeries**
+ Additive surgeries (no ``type:`` declaration) satisfy the action law.
+ Replacement surgeries (explicit ``type:``) use last-write-wins.
+
+Module Structure
+================
+
+- ``config.py`` - Config composition (compose_configs, strip_init_fields)
+- ``converters.py`` - Plan builders (plan_surgery, plan_mil_attention_to_mamba)
+- ``expr.py`` - Expression types (Ref, Slice, Concat, Init, ExprPlan)
+- ``executor.py`` - Plan execution (StreamingExecutor, execute)
+- ``io.py`` - Streaming I/O (SafetensorLoader, ShardedSafetensorWriter)
"""
+# Config composition
+from fast_llm_external_models.apriel2.conversion.config import compose_configs, strip_init_fields
+
+# Plan builders (generic)
+from fast_llm_external_models.apriel2.conversion.converters import (
+ plan_dil_attention_to_gdn,
+ plan_kil_attention_to_kda,
+ plan_mil_attention_to_mamba,
+ plan_surgery,
+)
+
+# Execution
+from fast_llm_external_models.apriel2.conversion.executor import MAX_SEED, StreamingExecutor, execute
+
# Core types and plan operations
from fast_llm_external_models.apriel2.conversion.expr import (
Concat,
@@ -104,13 +154,6 @@
substitute,
)
-# Execution
-from fast_llm_external_models.apriel2.conversion.executor import (
- MAX_SEED,
- StreamingExecutor,
- execute,
-)
-
# I/O utilities
from fast_llm_external_models.apriel2.conversion.io import (
DEFAULT_MAX_SHARD_SIZE,
@@ -118,22 +161,9 @@
ShardedSafetensorWriter,
)
-# Plan builders (generic)
-from fast_llm_external_models.apriel2.conversion.converters import (
- plan_mil_attention_to_mamba,
- plan_dil_attention_to_gdn,
- plan_kil_attention_to_kda,
- plan_surgery,
-)
-
-# Config composition
-from fast_llm_external_models.apriel2.conversion.config import compose_configs
-
# Source-specific converters
-from fast_llm_external_models.apriel2.conversion.llava import (
- convert_config as convert_llava_config,
- plan_llava_to_apriel2,
-)
+from fast_llm_external_models.apriel2.conversion.llava import convert_config as convert_llava_config
+from fast_llm_external_models.apriel2.conversion.llava import plan_llava_to_apriel2
# Rendering (optional, imported lazily by ExprPlan.render_tree)
# from fast_llm_external_models.apriel2.conversion.render import render_tree
@@ -175,6 +205,7 @@
"plan_kil_attention_to_kda",
# Config composition
"compose_configs",
+ "strip_init_fields",
# Source-specific converters
"convert_llava_config",
"plan_llava_to_apriel2",
diff --git a/fast_llm_external_models/apriel2/conversion/config.py b/fast_llm_external_models/apriel2/conversion/config.py
index 48f8ff44b..3752688c1 100644
--- a/fast_llm_external_models/apriel2/conversion/config.py
+++ b/fast_llm_external_models/apriel2/conversion/config.py
@@ -1,56 +1,136 @@
"""Config composition for Apriel2 architecture transformations.
-This module handles STRUCTURAL composition of configs, independent of weight handling.
-The `init` field in surgery specs is preserved as metadata for the plan builder but
-does not affect how configs are composed.
+Conceptual Types
+================
-Composition Cases
-=================
+The system operates on three conceptual types, all represented as ``dict``:
-compose_configs(base, overlay) handles four cases based on completeness:
+**State (S)**
+ A complete structural description of a model. Has ``hidden_size`` and ``decoder``.
+ Does NOT contain ``init`` fields. Represents WHAT a model looks like.
-1. **Complete + Partial** → Apply surgery semantics (inheritance, cross-type derivation)
-2. **Partial + Partial** → Deep merge (monoid operation on surgery specs)
-3. **Partial + Complete** → Overlay wins (complete config replaces partial)
-4. **Complete + Complete** → Deep merge, then strip `init` fields
+ Example: A saved config.json, or a model you're about to transform.
-A config is "complete" if it has `hidden_size` and `decoder` (i.e., it's a full model
-config, not a surgery spec).
+**Partial Surgery (P)**
+ An incomplete config specifying fields to change. Missing ``hidden_size`` or
+ ``decoder``. May contain ``init`` fields specifying weight initialization mode.
-Surgery Semantics
-=================
+ Example: ``{"decoder": {"block": {"mixer": {"type": "gdn", "init": "random"}}}}``
-When applying a surgery spec to a complete config:
+**Transition Spec (T)**
+ A complete config WITH ``init`` fields. Describes both the target structure
+ AND how to initialize weights. This is the output of applying a surgery to
+ a state - it's a complete specification of the transformation.
-**Inheritance**
- Unspecified parameters inherit from the source config. New blocks inherit
- from the "default" block (first block in pattern, or the single fixed block).
+ Example: The result of ``compose_configs(state, surgery)`` before stripping.
-**Cross-Type Derivation**
- When changing mixer types, geometric parameters are derived where possible:
- - attention → sliding_window: preserve heads, head_groups, head_size
- - attention → gdn: heads → value_heads, head_groups → key_heads
- - attention → mamba: derive d_inner, d_xb, dt_rank from hidden_size
- - attention → kda: preserve heads, head_size → head_dim
+The distinction between S and T is semantic (presence of ``init``), not structural.
+Both are "complete" in the sense of having ``hidden_size`` and ``decoder``.
-**Stochastic Mixer Composition**
- Two semantics based on whether surgery declares `type: stochastic`:
- - Replacement: surgery declares type → only surgery's sub-mixers included
- - Additive: surgery omits type → source sub-mixers preserved, surgery adds/modifies
+Algebraic Structure
+===================
- This distinction means the monoid action law holds for additive surgeries but
- intentionally fails for replacement surgeries (they have "last-write-wins" semantics).
+**Partial Surgeries form a Monoid (P, ∘, {})**::
-The `init` Field
-================
+ compose_configs : P × P → P (deep merge, overlay wins)
+
+ Identity: compose_configs(p, {}) = compose_configs({}, p) = p
+ Associativity: compose_configs(compose_configs(a, b), c)
+ = compose_configs(a, compose_configs(b, c))
+
+**Surgeries act on States to produce Transition Specs**::
+
+ compose_configs : S × P → T (apply surgery with inheritance)
+ compose_configs : T × P → T (extend transition with more surgery)
+
+**Action Law (for additive surgeries)**::
+
+ compose_configs(compose_configs(s, p₁), p₂) = compose_configs(s, compose_configs(p₁, p₂))
+
+This law holds when surgeries are "additive" (modifying existing structure without
+declaring new types). For "replacement" surgeries (explicitly declaring ``type:``),
+the action law intentionally fails - this is last-write-wins semantics.
+
+**State Extraction**::
+
+ strip_init_fields : T → S (remove init metadata for saving)
+
+Operations Summary
+==================
+
+``compose_configs(base, overlay)`` dispatches based on completeness:
+
+1. **S × P → T** : Apply surgery to state (inheritance, cross-type derivation)
+2. **T × P → T** : Extend transition spec with more surgery
+3. **P × P → P** : Merge partial surgeries (monoid operation)
+4. **S × S → S** : Merge states (deep merge, rare)
+5. **P × S → S** : Overlay wins (complete replaces partial)
+
+``strip_init_fields(config)`` removes all ``init`` fields, converting T → S.
+
+Inheritance Semantics
+=====================
+
+When applying a surgery (S × P → T):
+
+- Unspecified fields inherit from source state
+- New decoder blocks inherit from the "default" block
+- Cross-type derivation maps geometry (attention.heads → gdn.value_heads, etc.)
+- Stochastic mixers: additive surgery preserves source mixers, replacement replaces
+
+The ``init`` Field
+==================
+
+The ``init`` field specifies weight initialization mode for ``plan_surgery()``:
+
+- ``init: transfer`` → transfer weights from source (possibly with conversion)
+- ``init: random`` → randomly initialize weights
+
+**Key invariant**: ``init`` is preserved through composition so ``plan_surgery()``
+can read it. Use ``strip_init_fields()`` to obtain a pure state for:
+
+- Saving to disk (config.json should not contain ``init``)
+- Starting the next surgery iteration (current_state should be S, not T)
+
+Typical Usage Pattern
+=====================
+
+::
+
+ current_state: S = load_config(...)
+
+ for surgery: P in surgery_chain:
+ transition: T = compose_configs(current_state, surgery) # S × P → T
+ plan = plan_surgery(current_state, transition) # plan reads init from T
+ current_state: S = strip_init_fields(transition) # T → S for next iteration
+
+ save_config(current_state) # S has no init fields
+
+Sequential vs Merged Surgery Application
+========================================
+
+**IMPORTANT**: Applying surgeries sequentially (with stripping) differs from merging
+surgeries first then applying once. This affects ``init`` semantics:
+
+**Sequential** (recommended)::
-The `init` field is metadata for the plan builder, NOT for config composition:
-- `init: transfer` → plan builder creates weight transfer mappings
-- `init: random` → plan builder creates random initialization
+ t1 = compose_configs(s, p1) # GDN gets init: random
+ s1 = strip_init_fields(t1) # GDN loses init
+ t2 = compose_configs(s1, p2) # GDN has init: None → transfer mode
-After surgery is applied to produce a complete config, ALL `init` fields are stripped.
-This ensures configs are purely structural and plan creation is Markovian (depends only
-on current config + surgery, not on history).
+**Merged**::
+
+ merged = compose_configs(p1, p2) # GDN keeps init: random from p1
+ t = compose_configs(s, merged) # GDN has init: random → random mode
+
+The sequential approach means ``init: random`` applies **only to the surgery that
+introduces a component**. Subsequent surgeries transfer existing weights by default.
+
+This is the intended behavior: if surgery 1 adds GDN with random init, and surgery 2
+adds sliding window (not mentioning GDN), GDN keeps its weights from surgery 1.
+
+The merged approach would re-randomize GDN in every execution, which is rarely desired.
+Always use the sequential pattern shown in "Typical Usage Pattern" above.
"""
from __future__ import annotations
@@ -65,14 +145,42 @@ def is_complete(config: dict) -> bool:
def compose_configs(base: dict, overlay: dict | None) -> dict:
- """Compose two configs.
+ """Compose configs. Dispatches based on completeness of arguments.
+
+ Type Signatures (see module docstring for S, P, T definitions)::
+
+ S × P → T Apply surgery to state, get transition spec
+ T × P → T Extend transition spec with more surgery
+ P × P → P Merge partial surgeries (monoid operation)
+ S × S → S Merge states (deep merge)
+ P × S → S Overlay wins
+
+ The ``init`` field is preserved in all cases. Use ``strip_init_fields()``
+ to convert T → S for saving or iteration.
Args:
- base: Base config (complete or partial surgery spec).
- overlay: Overlay config (complete or partial surgery spec).
+ base: State (S), transition spec (T), or partial surgery (P).
+ overlay: Partial surgery (P) or state (S).
Returns:
- Composed config.
+ Composed config. Type depends on inputs (see signatures above).
+
+ Algebraic Properties:
+ Monoid: ``compose(compose(p1, p2), p3) == compose(p1, compose(p2, p3))``
+
+ Action law (additive surgeries):
+ ``compose(compose(s, p1), p2) == compose(s, compose(p1, p2))``
+
+ Example::
+
+ # S × P → T (apply surgery to state)
+ state = {"hidden_size": 256, "decoder": {...}}
+ surgery = {"decoder": {"block": {"mixer": {"init": "random"}}}}
+ transition = compose_configs(state, surgery) # T, has init
+
+ # Build plan, then extract state
+ plan = plan_surgery(state, transition)
+ new_state = strip_init_fields(transition) # S, no init
"""
if not overlay:
return copy.deepcopy(base)
@@ -94,9 +202,8 @@ def compose_configs(base: dict, overlay: dict | None) -> dict:
if not base_complete and overlay_complete:
return copy.deepcopy(overlay)
- # Case 4: Both complete -> deep merge
+ # Case 4: Both complete -> deep merge (init preserved for plan_surgery)
result = _deep_merge(base, overlay)
- _strip_keys(result, {"init"})
return result
@@ -128,26 +235,53 @@ def _strip_keys(config: Any, keys_to_strip: set[str]) -> None:
_strip_keys(item, keys_to_strip)
+def strip_init_fields(config: dict) -> dict:
+ """Return a copy of config with all ``init`` fields stripped (T → S).
+
+ Converts a transition spec (T) to a state (S) by removing ``init`` metadata.
+ Use this:
+
+ 1. Before saving configs to disk (config.json should be purely structural)
+ 2. Between surgery iterations (so subsequent surgeries don't re-randomize)
+
+ See module docstring section "Sequential vs Merged Surgery Application" for
+ why stripping between iterations is critical.
+
+ Args:
+ config: Config dict (not modified). Typically a transition spec (T).
+
+ Returns:
+ A deep copy with all ``init`` fields recursively removed (a state S).
+ """
+ result = copy.deepcopy(config)
+ _strip_keys(result, {"init"})
+ return result
+
+
# =============================================================================
# Surgery application with full semantics
# =============================================================================
def apply_surgery(source_config: dict, surgery_config: dict | None) -> dict:
- """Apply surgery specification to a complete source config.
+ """Apply surgery spec to complete config (the monoid action).
- This handles:
- - Top-level scalar overrides
- - Decoder composition (fixed vs pattern)
- - Stochastic mixer sub-mixer inheritance
- - Cross-type derivation (attention → gdn, attention → mamba)
+ This is the internal implementation of the monoid action: surgery specs
+ acting on complete configs. Called by compose_configs when base is complete
+ and overlay is partial.
+
+ Implements inheritance semantics:
+ - Unspecified fields inherit from source
+ - Cross-type derivation maps geometry (attention → gdn, etc.)
+ - Stochastic sub-mixers inherit from source's main mixer
+ - `init` fields are PRESERVED for plan_surgery() to see
Args:
- source_config: Complete Apriel2 config.
- surgery_config: Partial surgery specification.
+ source_config: Complete Apriel2 config (the state being acted on).
+ surgery_config: Partial surgery spec (the monoid element acting).
Returns:
- Complete Apriel2 config with surgery applied.
+ Complete config with surgery applied. `init` fields preserved.
"""
if not surgery_config:
return copy.deepcopy(source_config)
@@ -189,8 +323,9 @@ def apply_surgery(source_config: dict, surgery_config: dict | None) -> dict:
surgery_config["vision_encoder"],
)
- # Strip init keys from final result
- _strip_keys(result, {"init"})
+ # NOTE: We do NOT strip init keys here. The `init` field is preserved through
+ # composition so that plan_surgery() can see it and decide between transfer
+ # vs random initialization. The caller (convert.py) strips init before saving.
return result
@@ -392,6 +527,12 @@ def _compose_single_mixer(source: dict, surgery: dict, hidden_size: int) -> dict
result[key] = surgery[key]
elif key in source:
result[key] = source[key]
+ # Copy per-layer bias settings (query_layer, key_layer, value_layer, dense_layer)
+ for key in ["query_layer", "key_layer", "value_layer", "dense_layer", "add_linear_biases"]:
+ if key in surgery:
+ result[key] = surgery[key]
+ elif key in source:
+ result[key] = copy.deepcopy(source[key])
# Preserve init
if "init" in surgery:
result["init"] = surgery["init"]
diff --git a/fast_llm_external_models/apriel2/conversion/converters.py b/fast_llm_external_models/apriel2/conversion/converters.py
index 6d1350c54..9c9238bb0 100644
--- a/fast_llm_external_models/apriel2/conversion/converters.py
+++ b/fast_llm_external_models/apriel2/conversion/converters.py
@@ -61,16 +61,7 @@
from __future__ import annotations
-from fast_llm_external_models.apriel2.conversion.expr import (
- Concat,
- Expr,
- ExprPlan,
- Init,
- Ref,
- Slice,
- W,
-)
-
+from fast_llm_external_models.apriel2.conversion.expr import Concat, Expr, ExprPlan, Init, Ref, Slice, W
# =============================================================================
# SECTION 1: Per-Mixer Plan Functions
@@ -79,6 +70,21 @@
# This is the single source of truth for each mixer's weight schema.
+def _get_attention_bias_enabled(config: dict, layer_name: str) -> bool:
+ """Get whether bias is enabled for an attention layer.
+
+ Checks per-layer bias config (e.g., query_layer.bias.enabled).
+ Falls back to add_linear_biases if not set.
+ """
+ layer_cfg = config.get(layer_name, {})
+ bias_cfg = layer_cfg.get("bias", {})
+ enabled = bias_cfg.get("enabled")
+ if enabled is not None:
+ return enabled
+ # Fall back to add_linear_biases
+ return config.get("add_linear_biases", False)
+
+
def _plan_attention_mixer(
*,
prefix: W,
@@ -90,9 +96,13 @@ def _plan_attention_mixer(
Weight schema:
- q_proj.weight: (q_size, hidden_size)
+ - q_proj.bias: (q_size,) [if query_layer.bias.enabled]
- k_proj.weight: (kv_size, hidden_size)
+ - k_proj.bias: (kv_size,) [if key_layer.bias.enabled]
- v_proj.weight: (kv_size, hidden_size)
+ - v_proj.bias: (kv_size,) [if value_layer.bias.enabled]
- o_proj.weight: (hidden_size, q_size)
+ - o_proj.bias: (hidden_size,) [if dense_layer.bias.enabled]
Args:
prefix: Target weight path prefix.
@@ -100,12 +110,28 @@ def _plan_attention_mixer(
hidden_size: Model hidden size.
source_prefix: If provided, passthrough from source. If None, random init.
"""
+ # Check per-layer bias configuration
+ q_bias = _get_attention_bias_enabled(config, "query_layer")
+ k_bias = _get_attention_bias_enabled(config, "key_layer")
+ v_bias = _get_attention_bias_enabled(config, "value_layer")
+ o_bias = _get_attention_bias_enabled(config, "dense_layer")
+
if source_prefix is not None:
- # Passthrough
- return ExprPlan(mappings={
+ # Passthrough weights
+ mappings: dict[W, Expr] = {
prefix / proj / "weight": Ref(key=source_prefix / proj / "weight")
for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]
- })
+ }
+ # Passthrough biases if enabled
+ if q_bias:
+ mappings[prefix / "q_proj" / "bias"] = Ref(key=source_prefix / "q_proj" / "bias")
+ if k_bias:
+ mappings[prefix / "k_proj" / "bias"] = Ref(key=source_prefix / "k_proj" / "bias")
+ if v_bias:
+ mappings[prefix / "v_proj" / "bias"] = Ref(key=source_prefix / "v_proj" / "bias")
+ if o_bias:
+ mappings[prefix / "o_proj" / "bias"] = Ref(key=source_prefix / "o_proj" / "bias")
+ return ExprPlan(mappings=mappings)
# Random init
heads = config["heads"]
@@ -114,12 +140,22 @@ def _plan_attention_mixer(
q_size = heads * head_size
kv_size = head_groups * head_size
- return ExprPlan(mappings={
+ mappings = {
prefix / "q_proj" / "weight": Init(shape=(q_size, hidden_size), init_type="kaiming"),
prefix / "k_proj" / "weight": Init(shape=(kv_size, hidden_size), init_type="kaiming"),
prefix / "v_proj" / "weight": Init(shape=(kv_size, hidden_size), init_type="kaiming"),
prefix / "o_proj" / "weight": Init(shape=(hidden_size, q_size), init_type="kaiming"),
- })
+ }
+ # Random init biases if enabled
+ if q_bias:
+ mappings[prefix / "q_proj" / "bias"] = Init(shape=(q_size,), init_type="zeros")
+ if k_bias:
+ mappings[prefix / "k_proj" / "bias"] = Init(shape=(kv_size,), init_type="zeros")
+ if v_bias:
+ mappings[prefix / "v_proj" / "bias"] = Init(shape=(kv_size,), init_type="zeros")
+ if o_bias:
+ mappings[prefix / "o_proj" / "bias"] = Init(shape=(hidden_size,), init_type="zeros")
+ return ExprPlan(mappings=mappings)
def _plan_mamba_mixer(
@@ -150,20 +186,22 @@ def _plan_mamba_mixer(
"""
if source_prefix is not None:
# Passthrough - include all possible weights
- return ExprPlan(mappings={
- prefix / name: Ref(key=source_prefix / name)
- for name in [
- "in_proj.weight",
- "out_proj.weight",
- "dt_in_proj.weight",
- "dt_proj.weight",
- "dt_proj.bias",
- "conv1d.weight",
- "conv1d.bias",
- "A_log",
- "D",
- ]
- })
+ return ExprPlan(
+ mappings={
+ prefix / name: Ref(key=source_prefix / name)
+ for name in [
+ "in_proj.weight",
+ "out_proj.weight",
+ "dt_in_proj.weight",
+ "dt_proj.weight",
+ "dt_proj.bias",
+ "conv1d.weight",
+ "conv1d.bias",
+ "A_log",
+ "D",
+ ]
+ }
+ )
# Random init
d_inner = config["d_inner"]
@@ -181,9 +219,7 @@ def _plan_mamba_mixer(
conv_channels = d_inner if repeat_kv_before_conv else d_xb
mappings: dict[W, Expr] = {
- prefix / "in_proj" / "weight": Init(
- shape=(2 * d_inner + 2 * d_xb, hidden_size), init_type="kaiming"
- ),
+ prefix / "in_proj" / "weight": Init(shape=(2 * d_inner + 2 * d_xb, hidden_size), init_type="kaiming"),
prefix / "out_proj" / "weight": Init(shape=(hidden_size, d_inner), init_type="kaiming"),
prefix / "dt_in_proj" / "weight": Init(shape=(dt_rank, hidden_size), init_type="kaiming"),
prefix / "dt_proj" / "weight": Init(shape=(d_inner, dt_rank), init_type="kaiming"),
@@ -230,18 +266,20 @@ def _plan_gdn_mixer(
"""
if source_prefix is not None:
# Passthrough
- return ExprPlan(mappings={
- prefix / name: Ref(key=source_prefix / name)
- for name in [
- "in_proj_qkvz.weight",
- "in_proj_ba.weight",
- "out_proj.weight",
- "convolution.weight",
- "A_log",
- "dt_bias",
- "norm.weight",
- ]
- })
+ return ExprPlan(
+ mappings={
+ prefix / name: Ref(key=source_prefix / name)
+ for name in [
+ "in_proj_qkvz.weight",
+ "in_proj_ba.weight",
+ "out_proj.weight",
+ "convolution.weight",
+ "A_log",
+ "dt_bias",
+ "norm.weight",
+ ]
+ }
+ )
# Random init
num_v_heads = config["value_heads"]
@@ -255,17 +293,19 @@ def _plan_gdn_mixer(
conv_dim = key_dim * 2 + value_dim
qkvz_size = key_dim * 2 + value_dim * 2 # Q, K both key_dim; V, Z both value_dim
- return ExprPlan(mappings={
- prefix / "in_proj_qkvz" / "weight": Init(shape=(qkvz_size, hidden_size), init_type="kaiming"),
- prefix / "in_proj_ba" / "weight": Init(shape=(num_v_heads * 2, hidden_size), init_type="zeros"),
- prefix / "out_proj" / "weight": Init(shape=(hidden_size, value_dim), init_type="kaiming"),
- prefix / "convolution" / "weight": Init(
- shape=(conv_dim, 1, conv_kernel_size), init_type="scaled_identity_conv"
- ),
- prefix / "A_log": Init(shape=(num_v_heads,), init_type="slow_decay"),
- prefix / "dt_bias": Init(shape=(num_v_heads,), init_type="zeros"),
- prefix / "norm" / "weight": Init(shape=(head_v_dim,), init_type="ones"),
- })
+ return ExprPlan(
+ mappings={
+ prefix / "in_proj_qkvz" / "weight": Init(shape=(qkvz_size, hidden_size), init_type="kaiming"),
+ prefix / "in_proj_ba" / "weight": Init(shape=(num_v_heads * 2, hidden_size), init_type="zeros"),
+ prefix / "out_proj" / "weight": Init(shape=(hidden_size, value_dim), init_type="kaiming"),
+ prefix
+ / "convolution"
+ / "weight": Init(shape=(conv_dim, 1, conv_kernel_size), init_type="scaled_identity_conv"),
+ prefix / "A_log": Init(shape=(num_v_heads,), init_type="slow_decay"),
+ prefix / "dt_bias": Init(shape=(num_v_heads,), init_type="zeros"),
+ prefix / "norm" / "weight": Init(shape=(head_v_dim,), init_type="ones"),
+ }
+ )
def _plan_kda_mixer(
@@ -298,26 +338,28 @@ def _plan_kda_mixer(
"""
if source_prefix is not None:
# Passthrough
- return ExprPlan(mappings={
- prefix / name: Ref(key=source_prefix / name)
- for name in [
- "q_proj.weight",
- "k_proj.weight",
- "v_proj.weight",
- "o_proj.weight",
- "q_conv.weight",
- "k_conv.weight",
- "v_conv.weight",
- "f_a_proj.weight",
- "f_b_proj.weight",
- "g_a_proj.weight",
- "g_b_proj.weight",
- "beta_proj.weight",
- "A_log",
- "dt_bias",
- "norm.weight",
- ]
- })
+ return ExprPlan(
+ mappings={
+ prefix / name: Ref(key=source_prefix / name)
+ for name in [
+ "q_proj.weight",
+ "k_proj.weight",
+ "v_proj.weight",
+ "o_proj.weight",
+ "q_conv.weight",
+ "k_conv.weight",
+ "v_conv.weight",
+ "f_a_proj.weight",
+ "f_b_proj.weight",
+ "g_a_proj.weight",
+ "g_b_proj.weight",
+ "beta_proj.weight",
+ "A_log",
+ "dt_bias",
+ "norm.weight",
+ ]
+ }
+ )
# Random init
num_heads = config["heads"]
@@ -325,36 +367,38 @@ def _plan_kda_mixer(
projection_size = num_heads * head_dim
conv_kernel_size = config.get("convolution_layer", {}).get("kernel_size", 4)
- return ExprPlan(mappings={
- # Main projections
- prefix / "q_proj" / "weight": Init(shape=(projection_size, hidden_size), init_type="kaiming"),
- prefix / "k_proj" / "weight": Init(shape=(projection_size, hidden_size), init_type="kaiming"),
- prefix / "v_proj" / "weight": Init(shape=(projection_size, hidden_size), init_type="kaiming"),
- prefix / "o_proj" / "weight": Init(shape=(hidden_size, projection_size), init_type="kaiming"),
- # Convolutions
- prefix / "q_conv" / "weight": Init(
- shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"
- ),
- prefix / "k_conv" / "weight": Init(
- shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"
- ),
- prefix / "v_conv" / "weight": Init(
- shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"
- ),
- # Gate kernels (low-rank factorization)
- prefix / "f_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"),
- prefix / "f_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"),
- # Output gate (low-rank factorization)
- prefix / "g_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"),
- prefix / "g_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"),
- # Beta projection
- prefix / "beta_proj" / "weight": Init(shape=(num_heads, hidden_size), init_type="kaiming"),
- # Learnable parameters
- prefix / "A_log": Init(shape=(num_heads,), init_type="slow_decay"),
- prefix / "dt_bias": Init(shape=(projection_size,), init_type="zeros"),
- # Normalization
- prefix / "norm" / "weight": Init(shape=(head_dim,), init_type="ones"),
- })
+ return ExprPlan(
+ mappings={
+ # Main projections
+ prefix / "q_proj" / "weight": Init(shape=(projection_size, hidden_size), init_type="kaiming"),
+ prefix / "k_proj" / "weight": Init(shape=(projection_size, hidden_size), init_type="kaiming"),
+ prefix / "v_proj" / "weight": Init(shape=(projection_size, hidden_size), init_type="kaiming"),
+ prefix / "o_proj" / "weight": Init(shape=(hidden_size, projection_size), init_type="kaiming"),
+ # Convolutions
+ prefix
+ / "q_conv"
+ / "weight": Init(shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"),
+ prefix
+ / "k_conv"
+ / "weight": Init(shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"),
+ prefix
+ / "v_conv"
+ / "weight": Init(shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"),
+ # Gate kernels (low-rank factorization)
+ prefix / "f_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"),
+ prefix / "f_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"),
+ # Output gate (low-rank factorization)
+ prefix / "g_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"),
+ prefix / "g_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"),
+ # Beta projection
+ prefix / "beta_proj" / "weight": Init(shape=(num_heads, hidden_size), init_type="kaiming"),
+ # Learnable parameters
+ prefix / "A_log": Init(shape=(num_heads,), init_type="slow_decay"),
+ prefix / "dt_bias": Init(shape=(projection_size,), init_type="zeros"),
+ # Normalization
+ prefix / "norm" / "weight": Init(shape=(head_dim,), init_type="ones"),
+ }
+ )
# Dispatcher for per-mixer plan functions
@@ -409,16 +453,13 @@ def plan_mil_attention_to_mamba(
exprs=(
Init(shape=(d_inner, hidden_size), init_type="kaiming"), # z: random
Slice(
- expr=Ref(key=source_prefix / "v_proj" / "weight"),
- slices=((0, d_xb, None), (None, None, None))
+ expr=Ref(key=source_prefix / "v_proj" / "weight"), slices=((0, d_xb, None), (None, None, None))
), # x <- V
Slice(
- expr=Ref(key=source_prefix / "k_proj" / "weight"),
- slices=((0, d_xb, None), (None, None, None))
+ expr=Ref(key=source_prefix / "k_proj" / "weight"), slices=((0, d_xb, None), (None, None, None))
), # B <- K
Slice(
- expr=Ref(key=source_prefix / "q_proj" / "weight"),
- slices=((0, d_inner, None), (None, None, None))
+ expr=Ref(key=source_prefix / "q_proj" / "weight"), slices=((0, d_inner, None), (None, None, None))
), # C <- Q
),
dim=0,
@@ -532,19 +573,21 @@ def plan_dil_attention_to_gdn(
dim=0,
)
- return ExprPlan(mappings={
- target_prefix / "in_proj_qkvz" / "weight": in_proj_qkvz_expr,
- target_prefix / "in_proj_ba" / "weight": Init(
- shape=(2 * num_v_heads, hidden_size), init_type="zeros"
- ), # b=a=0 → β=0.5
- target_prefix / "out_proj" / "weight": Ref(key=source_prefix / "o_proj" / "weight"),
- target_prefix / "convolution" / "weight": Init(
- shape=(conv_dim, 1, conv_kernel_size), init_type="scaled_identity_conv"
- ),
- target_prefix / "A_log": Init(shape=(num_v_heads,), init_type="slow_decay"),
- target_prefix / "dt_bias": Init(shape=(num_v_heads,), init_type="zeros"),
- target_prefix / "norm" / "weight": Init(shape=(head_v_dim,), init_type="ones"),
- })
+ return ExprPlan(
+ mappings={
+ target_prefix / "in_proj_qkvz" / "weight": in_proj_qkvz_expr,
+ target_prefix
+ / "in_proj_ba"
+ / "weight": Init(shape=(2 * num_v_heads, hidden_size), init_type="zeros"), # b=a=0 → β=0.5
+ target_prefix / "out_proj" / "weight": Ref(key=source_prefix / "o_proj" / "weight"),
+ target_prefix
+ / "convolution"
+ / "weight": Init(shape=(conv_dim, 1, conv_kernel_size), init_type="scaled_identity_conv"),
+ target_prefix / "A_log": Init(shape=(num_v_heads,), init_type="slow_decay"),
+ target_prefix / "dt_bias": Init(shape=(num_v_heads,), init_type="zeros"),
+ target_prefix / "norm" / "weight": Init(shape=(head_v_dim,), init_type="ones"),
+ }
+ )
def plan_kil_attention_to_kda(
@@ -595,9 +638,7 @@ def plan_kil_attention_to_kda(
for h in range(num_heads):
src_h = h % source_num_q_heads
row_start = src_h * source_head_dim
- q_slices.append(
- Slice(expr=q_ref, slices=((row_start, row_start + head_dim, None), (None, None, None)))
- )
+ q_slices.append(Slice(expr=q_ref, slices=((row_start, row_start + head_dim, None), (None, None, None))))
q_expr = Concat(exprs=tuple(q_slices), dim=0)
# K: tile source KV heads to fill target projection_size
@@ -608,9 +649,7 @@ def plan_kil_attention_to_kda(
for h in range(num_heads):
src_h = h % source_num_kv_heads
row_start = src_h * source_head_dim
- k_slices.append(
- Slice(expr=k_ref, slices=((row_start, row_start + head_dim, None), (None, None, None)))
- )
+ k_slices.append(Slice(expr=k_ref, slices=((row_start, row_start + head_dim, None), (None, None, None))))
k_expr = Concat(exprs=tuple(k_slices), dim=0)
# V: tile source KV heads to fill target projection_size
@@ -621,41 +660,41 @@ def plan_kil_attention_to_kda(
for h in range(num_heads):
src_h = h % source_num_kv_heads
row_start = src_h * source_head_dim
- v_slices.append(
- Slice(expr=v_ref, slices=((row_start, row_start + head_dim, None), (None, None, None)))
- )
+ v_slices.append(Slice(expr=v_ref, slices=((row_start, row_start + head_dim, None), (None, None, None))))
v_expr = Concat(exprs=tuple(v_slices), dim=0)
- return ExprPlan(mappings={
- # Transfer main projections
- target_prefix / "q_proj" / "weight": q_expr,
- target_prefix / "k_proj" / "weight": k_expr,
- target_prefix / "v_proj" / "weight": v_expr,
- target_prefix / "o_proj" / "weight": Ref(key=source_prefix / "o_proj" / "weight"),
- # Random init: convolutions (scaled identity for near-passthrough initially)
- target_prefix / "q_conv" / "weight": Init(
- shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"
- ),
- target_prefix / "k_conv" / "weight": Init(
- shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"
- ),
- target_prefix / "v_conv" / "weight": Init(
- shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"
- ),
- # Random init: gate kernels (low-rank factorization)
- target_prefix / "f_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"),
- target_prefix / "f_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"),
- # Random init: output gate (low-rank factorization)
- target_prefix / "g_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"),
- target_prefix / "g_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"),
- # Random init: beta projection
- target_prefix / "beta_proj" / "weight": Init(shape=(num_heads, hidden_size), init_type="kaiming"),
- # Random init: learnable parameters
- target_prefix / "A_log": Init(shape=(num_heads,), init_type="slow_decay"),
- target_prefix / "dt_bias": Init(shape=(projection_size,), init_type="zeros"),
- # Random init: normalization
- target_prefix / "norm" / "weight": Init(shape=(head_dim,), init_type="ones"),
- })
+ return ExprPlan(
+ mappings={
+ # Transfer main projections
+ target_prefix / "q_proj" / "weight": q_expr,
+ target_prefix / "k_proj" / "weight": k_expr,
+ target_prefix / "v_proj" / "weight": v_expr,
+ target_prefix / "o_proj" / "weight": Ref(key=source_prefix / "o_proj" / "weight"),
+ # Random init: convolutions (scaled identity for near-passthrough initially)
+ target_prefix
+ / "q_conv"
+ / "weight": Init(shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"),
+ target_prefix
+ / "k_conv"
+ / "weight": Init(shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"),
+ target_prefix
+ / "v_conv"
+ / "weight": Init(shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"),
+ # Random init: gate kernels (low-rank factorization)
+ target_prefix / "f_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"),
+ target_prefix / "f_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"),
+ # Random init: output gate (low-rank factorization)
+ target_prefix / "g_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"),
+ target_prefix / "g_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"),
+ # Random init: beta projection
+ target_prefix / "beta_proj" / "weight": Init(shape=(num_heads, hidden_size), init_type="kaiming"),
+ # Random init: learnable parameters
+ target_prefix / "A_log": Init(shape=(num_heads,), init_type="slow_decay"),
+ target_prefix / "dt_bias": Init(shape=(projection_size,), init_type="zeros"),
+ # Random init: normalization
+ target_prefix / "norm" / "weight": Init(shape=(head_dim,), init_type="ones"),
+ }
+ )
# =============================================================================
@@ -786,7 +825,70 @@ def plan_surgery(
source_config: dict,
target_config: dict,
) -> ExprPlan:
- """Build plan for Apriel2→Apriel2 surgery (MIL, DIL, KIL, stochastic mixers, etc.)."""
+ """Build a weight conversion plan: S × T → Plan.
+
+ Creates an ExprPlan mapping target weight keys to expressions over source weights.
+ Handles same-type passthrough, cross-type conversions (MIL, DIL, KIL), and
+ stochastic mixer routing.
+
+ Type Signature::
+
+ plan_surgery : S × T → Plan
+
+ Where S is a state (source) and T is a transition spec (target with ``init`` fields).
+
+ The ``init`` Field
+ ------------------
+
+ The ``init`` field in ``target_config`` controls weight initialization:
+
+ - ``init: transfer`` (or absent) → create Ref expressions (transfer from source)
+ - ``init: random`` → create Init expressions (random initialization)
+
+ This is why ``target_config`` should be a transition spec (T) from ``compose_configs``,
+ not a stripped state (S). If ``init`` fields are missing, all components default to
+ transfer mode.
+
+ Args:
+ source_config: State (S) - complete config describing source architecture.
+ Must have hidden_size, decoder, etc. No ``init`` fields expected.
+ target_config: Transition spec (T) - complete config with ``init`` fields.
+ Use ``compose_configs(source, surgery)`` to produce this.
+
+ Returns:
+ ExprPlan mapping target weight keys to expressions over source weights.
+
+ Example::
+
+ # Apply a surgery that wraps attention in a stochastic mixer
+ surgery_spec = {
+ "decoder": {"block": {"mixer": {
+ "type": "stochastic",
+ "mixers": {
+ "attention": {"init": "transfer"},
+ "gdn": {"type": "gdn", "init": "random"},
+ }
+ }}}
+ }
+
+ # S × P → T
+ transition = compose_configs(source_config, surgery_spec)
+
+ # S × T → Plan
+ plan = plan_surgery(source_config, transition)
+
+ # Execute
+ new_weights = execute(plan, source_weights, seed=42)
+
+ # T → S for saving
+ target_state = strip_init_fields(transition)
+
+ Note:
+ Both arguments must be complete (have hidden_size and decoder).
+ The target_config should retain ``init`` fields from the surgery spec.
+ Passing a stripped state as target will cause all components to use
+ transfer mode, which may not be intended.
+ """
hidden_size = target_config.get("hidden_size", source_config.get("hidden_size"))
assert hidden_size is not None, "hidden_size must be specified in source or target config"
@@ -804,18 +906,24 @@ def plan_surgery(
target_block = _get_block_config(target_decoder, target_layer_idx)
plan += _plan_mixer(
- target_layer_idx, source_layer_idx,
- source_block.get("mixer", {}), target_block.get("mixer", {}),
+ target_layer_idx,
+ source_layer_idx,
+ source_block.get("mixer", {}),
+ target_block.get("mixer", {}),
hidden_size,
)
plan += _plan_mlp(
- target_layer_idx, source_layer_idx,
- source_block.get("mlp", {}), target_block.get("mlp", {}),
+ target_layer_idx,
+ source_layer_idx,
+ source_block.get("mlp", {}),
+ target_block.get("mlp", {}),
hidden_size,
)
plan += _plan_norms(
- target_layer_idx, source_layer_idx,
- source_block, target_block,
+ target_layer_idx,
+ source_layer_idx,
+ source_block,
+ target_block,
hidden_size,
)
@@ -839,14 +947,16 @@ def _plan_non_decoder_weights(config: dict) -> ExprPlan:
embed = W("model", "embed_tokens", "weight")
mappings[embed] = Ref(key=embed)
- head = W("lm_head", "weight")
- mappings[head] = Ref(key=head)
+ # lm_head only if not tied to embeddings
+ if not config.get("tie_word_embeddings", False):
+ head = W("lm_head", "weight")
+ mappings[head] = Ref(key=head)
norm = W("model", "norm", "weight")
mappings[norm] = Ref(key=norm)
- if "vision_encoder" in config:
- vision_config = config["vision_encoder"]
+ vision_config = config.get("vision_encoder")
+ if vision_config:
vision = W("model", "vision_encoder")
patch_emb = vision / "embeddings" / "patch_embeddings" / "weight"
@@ -950,9 +1060,13 @@ def _plan_mixer(
source_prefix = source_mixer_base
plan += _plan_mixer_transfer(
- matched_source_type, sub_type,
- matched_source, sub_config,
- source_prefix, target_prefix, hidden_size,
+ matched_source_type,
+ sub_type,
+ matched_source,
+ sub_config,
+ source_prefix,
+ target_prefix,
+ hidden_size,
)
# Passthrough source sub-mixers not in target spec
@@ -963,8 +1077,13 @@ def _plan_mixer(
source_prefix = source_layer / "mixer" / "mixers" / sub_name
target_prefix = target_layer / "mixer" / "mixers" / sub_name
plan += _plan_mixer_transfer(
- sub_type, sub_type, sub_config, sub_config,
- source_prefix, target_prefix, hidden_size,
+ sub_type,
+ sub_type,
+ sub_config,
+ sub_config,
+ source_prefix,
+ target_prefix,
+ hidden_size,
)
return plan
@@ -980,12 +1099,34 @@ def _plan_mixer(
source_prefix = source_layer / "mixer"
return _plan_mixer_transfer(
- main_source_type, target_type,
- main_source, target_mixer,
- source_prefix, target_prefix, hidden_size,
+ main_source_type,
+ target_type,
+ main_source,
+ target_mixer,
+ source_prefix,
+ target_prefix,
+ hidden_size,
)
+def _get_mlp_bias_enabled(config: dict, layer_name: str) -> bool:
+ """Get whether bias is enabled for an MLP layer.
+
+ Checks per-layer bias config (e.g., layer_1.bias.enabled, layer_2.bias.enabled).
+ Falls back to add_linear_biases if not set.
+
+ Note: layer_1 corresponds to gate_proj and up_proj (gated MLP) or just up_proj (non-gated)
+ layer_2 corresponds to down_proj
+ """
+ layer_cfg = config.get(layer_name, {})
+ bias_cfg = layer_cfg.get("bias", {})
+ enabled = bias_cfg.get("enabled")
+ if enabled is not None:
+ return enabled
+ # Fall back to add_linear_biases
+ return config.get("add_linear_biases", False)
+
+
def _plan_mlp(
target_layer_idx: int,
source_layer_idx: int,
@@ -1006,7 +1147,7 @@ def _plan_mlp_transfer(
target_mlp: dict,
hidden_size: int,
) -> ExprPlan:
- """Passthrough for MLP weights."""
+ """Passthrough for MLP weights and biases."""
source_mlp_path = W("model", "decoder", "blocks", source_layer_idx, "mlp")
target_mlp_path = W("model", "decoder", "blocks", target_layer_idx, "mlp")
@@ -1019,10 +1160,36 @@ def _plan_mlp_transfer(
f"Use 'init: random' to initialize randomly."
)
- return ExprPlan(mappings={
- target_mlp_path / proj / "weight": Ref(key=source_mlp_path / proj / "weight")
- for proj in ["gate_proj", "up_proj", "down_proj"]
- })
+ # Check per-layer bias configuration
+ layer_1_bias = _get_mlp_bias_enabled(target_mlp, "layer_1")
+ layer_2_bias = _get_mlp_bias_enabled(target_mlp, "layer_2")
+
+ # Check if gated MLP (has gate_proj) or non-gated (only up_proj)
+ gated = target_mlp.get("gated", True) # Default to gated for backwards compatibility
+
+ # Passthrough weights
+ # layer_1 = gate_proj + up_proj (gated) or just up_proj (non-gated)
+ # layer_2 = down_proj
+ if gated:
+ weight_projs = ["gate_proj", "up_proj", "down_proj"]
+ else:
+ weight_projs = ["up_proj", "down_proj"]
+
+ mappings: dict[W, Expr] = {
+ target_mlp_path / proj / "weight": Ref(key=source_mlp_path / proj / "weight") for proj in weight_projs
+ }
+
+ # Passthrough biases if enabled
+ if layer_1_bias:
+ if gated:
+ mappings[target_mlp_path / "gate_proj" / "bias"] = Ref(key=source_mlp_path / "gate_proj" / "bias")
+ mappings[target_mlp_path / "up_proj" / "bias"] = Ref(key=source_mlp_path / "up_proj" / "bias")
+
+ # layer_2 = down_proj
+ if layer_2_bias:
+ mappings[target_mlp_path / "down_proj" / "bias"] = Ref(key=source_mlp_path / "down_proj" / "bias")
+
+ return ExprPlan(mappings=mappings)
def _plan_random_mlp(
@@ -1030,20 +1197,41 @@ def _plan_random_mlp(
target_mlp: dict,
hidden_size: int,
) -> ExprPlan:
- """Random initialization for MLP."""
+ """Random initialization for MLP weights and biases."""
target_mlp_path = W("model", "decoder", "blocks", target_layer_idx, "mlp")
intermediate_size = target_mlp["intermediate_size"]
- return ExprPlan(mappings={
- target_mlp_path / "gate_proj" / "weight": Init(
- shape=(intermediate_size, hidden_size), init_type="kaiming"
- ),
- target_mlp_path / "up_proj" / "weight": Init(
+
+ # Check per-layer bias configuration
+ layer_1_bias = _get_mlp_bias_enabled(target_mlp, "layer_1")
+ layer_2_bias = _get_mlp_bias_enabled(target_mlp, "layer_2")
+
+ # Check if gated MLP (has gate_proj) or non-gated (only up_proj)
+ gated = target_mlp.get("gated", True) # Default to gated for backwards compatibility
+
+ # Random init weights
+ mappings: dict[W, Expr] = {}
+ if gated:
+ mappings[target_mlp_path / "gate_proj" / "weight"] = Init(
shape=(intermediate_size, hidden_size), init_type="kaiming"
- ),
- target_mlp_path / "down_proj" / "weight": Init(
- shape=(hidden_size, intermediate_size), init_type="kaiming"
- ),
- })
+ )
+ mappings[target_mlp_path / "up_proj" / "weight"] = Init(
+ shape=(intermediate_size, hidden_size), init_type="kaiming"
+ )
+ mappings[target_mlp_path / "down_proj" / "weight"] = Init(
+ shape=(hidden_size, intermediate_size), init_type="kaiming"
+ )
+
+ # Random init biases if enabled
+ if layer_1_bias:
+ if gated:
+ mappings[target_mlp_path / "gate_proj" / "bias"] = Init(shape=(intermediate_size,), init_type="zeros")
+ mappings[target_mlp_path / "up_proj" / "bias"] = Init(shape=(intermediate_size,), init_type="zeros")
+
+ # layer_2 = down_proj
+ if layer_2_bias:
+ mappings[target_mlp_path / "down_proj" / "bias"] = Init(shape=(hidden_size,), init_type="zeros")
+
+ return ExprPlan(mappings=mappings)
def _plan_norms(
@@ -1083,10 +1271,12 @@ def _plan_norms_transfer(
f"Use 'init: random' to initialize randomly."
)
- return ExprPlan(mappings={
- target_layer / norm_name / "weight": Ref(key=source_layer / norm_name / "weight")
- for norm_name in ["input_layernorm", "post_attention_layernorm"]
- })
+ return ExprPlan(
+ mappings={
+ target_layer / norm_name / "weight": Ref(key=source_layer / norm_name / "weight")
+ for norm_name in ["input_layernorm", "post_attention_layernorm"]
+ }
+ )
def _plan_random_norms(
@@ -1095,7 +1285,9 @@ def _plan_random_norms(
) -> ExprPlan:
"""Random initialization for normalization layers."""
target_layer = W("model", "decoder", "blocks", target_layer_idx)
- return ExprPlan(mappings={
- target_layer / norm_name / "weight": Init(shape=(hidden_size,), init_type="ones")
- for norm_name in ["input_layernorm", "post_attention_layernorm"]
- })
+ return ExprPlan(
+ mappings={
+ target_layer / norm_name / "weight": Init(shape=(hidden_size,), init_type="ones")
+ for norm_name in ["input_layernorm", "post_attention_layernorm"]
+ }
+ )
diff --git a/fast_llm_external_models/apriel2/conversion/executor.py b/fast_llm_external_models/apriel2/conversion/executor.py
index a6c5672f0..b0779c97f 100644
--- a/fast_llm_external_models/apriel2/conversion/executor.py
+++ b/fast_llm_external_models/apriel2/conversion/executor.py
@@ -29,7 +29,8 @@
from __future__ import annotations
import hashlib
-from typing import Callable, Iterator
+from collections.abc import Iterator
+from typing import Callable
import torch
from torch import Tensor
@@ -81,8 +82,7 @@ def execute(
break
else:
raise ValueError(
- "Cannot infer device/dtype: plan has no source references. "
- "Provide device and dtype explicitly."
+ "Cannot infer device/dtype: plan has no source references. " "Provide device and dtype explicitly."
)
generator = torch.Generator(device=device)
@@ -94,10 +94,7 @@ def execute(
# Verify device/dtype consistency
for key, tensor in sources.items():
if tensor.device != device or tensor.dtype != dtype:
- raise ValueError(
- f"Source {key} has {tensor.device}/{tensor.dtype}, "
- f"expected {device}/{dtype}"
- )
+ raise ValueError(f"Source {key} has {tensor.device}/{tensor.dtype}, " f"expected {device}/{dtype}")
# Deterministic per-target seed
key_offset = int(hashlib.md5(str(target_key).encode()).hexdigest()[:8], 16)
diff --git a/fast_llm_external_models/apriel2/conversion/expr.py b/fast_llm_external_models/apriel2/conversion/expr.py
index 4867a27ae..34ea106fc 100644
--- a/fast_llm_external_models/apriel2/conversion/expr.py
+++ b/fast_llm_external_models/apriel2/conversion/expr.py
@@ -52,7 +52,8 @@
import math
from collections import defaultdict
-from typing import Annotated, Any, Callable, Iterator, Literal, TypedDict, Union, Unpack
+from collections.abc import Iterator
+from typing import Annotated, Any, Callable, Literal, TypedDict, Union, Unpack
import torch
from pydantic import BaseModel, ConfigDict, Field, GetCoreSchemaHandler, TypeAdapter
@@ -60,7 +61,6 @@
from pydantic_core import CoreSchema, core_schema
from torch import Tensor
-
# =============================================================================
# Weight Path Builder
# =============================================================================
@@ -78,7 +78,7 @@ class W(str):
mappings[q] = Ref(key=source_q)
"""
- def __new__(cls, *parts) -> "W":
+ def __new__(cls, *parts) -> W:
# Join parts, stripping any leading/trailing dots from each
cleaned = []
for p in parts:
@@ -89,12 +89,12 @@ def __new__(cls, *parts) -> "W":
cleaned.append(s)
return super().__new__(cls, ".".join(cleaned))
- def __truediv__(self, other) -> "W":
+ def __truediv__(self, other) -> W:
if isinstance(other, (list, tuple)):
return W(self, *other)
return W(self, other)
- def __rtruediv__(self, other) -> "W":
+ def __rtruediv__(self, other) -> W:
return W(other, self)
@classmethod
@@ -156,7 +156,7 @@ class Slice(BaseModel):
model_config = ConfigDict(frozen=True)
type: Literal["slice"] = "slice"
- expr: "Expr"
+ expr: Expr
slices: tuple[tuple[int | None, int | None, int | None], ...]
def find_refs(self) -> set[W]:
@@ -184,7 +184,7 @@ class Concat(BaseModel):
model_config = ConfigDict(frozen=True)
type: Literal["concat"] = "concat"
- exprs: tuple["Expr", ...]
+ exprs: tuple[Expr, ...]
dim: int = 0
def find_refs(self) -> set[W]:
@@ -303,7 +303,7 @@ class Reshape(BaseModel):
model_config = ConfigDict(frozen=True)
type: Literal["reshape"] = "reshape"
- expr: "Expr"
+ expr: Expr
shape: tuple[int, ...]
def find_refs(self) -> set[W]:
@@ -442,10 +442,10 @@ def __getitem__(self, key: W) -> Expr:
def __contains__(self, key: W) -> bool:
return key in self.mappings
- def __or__(self, other: "ExprPlan") -> "ExprPlan":
+ def __or__(self, other: ExprPlan) -> ExprPlan:
return compose(self, other)
- def __add__(self, other: "ExprPlan") -> "ExprPlan":
+ def __add__(self, other: ExprPlan) -> ExprPlan:
return merge(self, other)
def source_keys(self) -> set[str]:
@@ -471,7 +471,7 @@ def summary(self) -> dict[str, Any]:
"metadata": self.metadata,
}
- def fuse(self) -> "ExprPlan":
+ def fuse(self) -> ExprPlan:
return ExprPlan(
mappings={k: fuse(v) for k, v in self.mappings.items()},
source_format=self.source_format,
diff --git a/fast_llm_external_models/apriel2/conversion/io.py b/fast_llm_external_models/apriel2/conversion/io.py
index e1a261d7e..1f64df0b9 100644
--- a/fast_llm_external_models/apriel2/conversion/io.py
+++ b/fast_llm_external_models/apriel2/conversion/io.py
@@ -62,7 +62,7 @@ def __init__(self, files: list[Path], device: str = "cpu"):
self._handles: dict[Path, Any] = {}
self._key_index: dict[str, Path] = {}
- def __enter__(self) -> "SafetensorLoader":
+ def __enter__(self) -> SafetensorLoader:
# Pre-build index: key -> file (one-time O(n×m), then O(1) lookups)
for f in self.files:
handle = safe_open(f, framework="pt", device=self.device)
@@ -128,7 +128,7 @@ def __init__(
self._finalized: bool = False
self._result_path: Path | None = None
- def __enter__(self) -> "ShardedSafetensorWriter":
+ def __enter__(self) -> ShardedSafetensorWriter:
return self
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
@@ -180,8 +180,7 @@ def _flush(self) -> None:
shard_file = self.output_dir / f"{self.base_name}-{self._shard_index:05d}.safetensors.tmp"
logger.debug(
- f"Writing shard {self._shard_index}: {len(self._buffer)} tensors, "
- f"{self._buffer_bytes / 1e9:.2f} GB"
+ f"Writing shard {self._shard_index}: {len(self._buffer)} tensors, " f"{self._buffer_bytes / 1e9:.2f} GB"
)
save_file(self._buffer, shard_file)
self._shard_files.append(shard_file)
diff --git a/fast_llm_external_models/apriel2/conversion/llava/plan.py b/fast_llm_external_models/apriel2/conversion/llava/plan.py
index df485efbd..a97e46c1a 100644
--- a/fast_llm_external_models/apriel2/conversion/llava/plan.py
+++ b/fast_llm_external_models/apriel2/conversion/llava/plan.py
@@ -1,11 +1,6 @@
"""Llava to Apriel2 weight conversion plan."""
-from fast_llm_external_models.apriel2.conversion.expr import (
- Expr,
- ExprPlan,
- Ref,
- W,
-)
+from fast_llm_external_models.apriel2.conversion.expr import Expr, ExprPlan, Ref, W
def plan_llava_to_apriel2(llava_config: dict) -> ExprPlan:
diff --git a/fast_llm_external_models/apriel2/conversion/qwen2/__init__.py b/fast_llm_external_models/apriel2/conversion/qwen2/__init__.py
new file mode 100644
index 000000000..d0a0b8e6e
--- /dev/null
+++ b/fast_llm_external_models/apriel2/conversion/qwen2/__init__.py
@@ -0,0 +1,6 @@
+"""Qwen2/Qwen2.5 to Apriel2 conversion module."""
+
+from fast_llm_external_models.apriel2.conversion.qwen2.config import convert_config
+from fast_llm_external_models.apriel2.conversion.qwen2.plan import plan_qwen2_to_apriel2
+
+__all__ = ["convert_config", "plan_qwen2_to_apriel2"]
diff --git a/fast_llm_external_models/apriel2/conversion/qwen2/config.py b/fast_llm_external_models/apriel2/conversion/qwen2/config.py
new file mode 100644
index 000000000..70629fe0e
--- /dev/null
+++ b/fast_llm_external_models/apriel2/conversion/qwen2/config.py
@@ -0,0 +1,79 @@
+"""Qwen2/Qwen2.5 to Apriel2 config conversion."""
+
+
+def convert_config(qwen2_config: dict) -> dict:
+ """Convert Qwen2/Qwen2.5 config to Apriel2TextConfig format.
+
+ Qwen2.5 architecture:
+ - Standard transformer with GQA (grouped query attention)
+ - QKV bias enabled, O bias disabled
+ - MLP bias disabled
+ - Gated SwiGLU MLP
+ - RMSNorm
+ - RoPE embeddings
+
+ Args:
+ qwen2_config: HuggingFace Qwen2Config as dict
+
+ Returns:
+ Apriel2TextConfig-compatible dict
+ """
+ hidden_size = qwen2_config["hidden_size"]
+ num_attention_heads = qwen2_config["num_attention_heads"]
+ num_key_value_heads = qwen2_config.get("num_key_value_heads", num_attention_heads)
+ head_dim = hidden_size // num_attention_heads
+
+ # Qwen2 uses QKV bias but not O bias - mirror Fast-LLM's per-layer config
+ return {
+ "model_type": "apriel2_text",
+ "architectures": ["Apriel2ForCausalLM"],
+ "auto_map": {
+ "AutoConfig": "configuration_apriel2.Apriel2TextConfig",
+ "AutoModel": "modeling_apriel2.Apriel2TextModel",
+ "AutoModelForCausalLM": "modeling_apriel2.Apriel2ForCausalLM",
+ },
+ "hidden_size": hidden_size,
+ "vocab_size": qwen2_config["vocab_size"],
+ "tie_word_embeddings": qwen2_config.get("tie_word_embeddings", False),
+ "decoder": {
+ "type": "fixed",
+ "num_blocks": qwen2_config["num_hidden_layers"],
+ "block": {
+ "mixer": {
+ "type": "attention",
+ "heads": num_attention_heads,
+ "head_groups": num_key_value_heads,
+ "head_size": head_dim,
+ # Per-layer bias config matching Fast-LLM structure
+ "query_layer": {"bias": {"enabled": True}},
+ "key_layer": {"bias": {"enabled": True}},
+ "value_layer": {"bias": {"enabled": True}},
+ "dense_layer": {"bias": {"enabled": False}},
+ "rotary": {
+ "type": "mistral_1d",
+ "theta": qwen2_config.get("rope_theta", 1000000.0),
+ },
+ },
+ "mlp": {
+ "type": "mlp",
+ "intermediate_size": qwen2_config["intermediate_size"],
+ "activation": qwen2_config.get("hidden_act", "silu"),
+ "gated": True,
+ "add_linear_biases": False,
+ },
+ "normalization": {
+ "type": "rms_norm",
+ "epsilon": qwen2_config.get("rms_norm_eps", 1e-6),
+ },
+ },
+ },
+ "head": {
+ "normalization": {
+ "type": "rms_norm",
+ "epsilon": qwen2_config.get("rms_norm_eps", 1e-6),
+ }
+ },
+ "embeddings": {
+ "max_position_embeddings": qwen2_config.get("max_position_embeddings", 32768),
+ },
+ }
diff --git a/fast_llm_external_models/apriel2/conversion/qwen2/plan.py b/fast_llm_external_models/apriel2/conversion/qwen2/plan.py
new file mode 100644
index 000000000..c1ec4af8b
--- /dev/null
+++ b/fast_llm_external_models/apriel2/conversion/qwen2/plan.py
@@ -0,0 +1,100 @@
+"""Qwen2/Qwen2.5 to Apriel2 weight conversion plan."""
+
+from fast_llm_external_models.apriel2.conversion.expr import Expr, ExprPlan, Ref, W
+
+
+def plan_qwen2_to_apriel2(qwen2_config: dict) -> ExprPlan:
+ """Build an expression plan for Qwen2/Qwen2.5 to Apriel2 conversion.
+
+ This is a pure mapping (all Ref expressions) since Qwen2→Apriel2
+ is just renaming keys. The weight tensors are identical.
+
+ Key mapping (source keys have "model." prefix in safetensors):
+ Qwen2 (safetensor key) Apriel2
+ ---------------------- -------
+ model.embed_tokens.weight -> model.embed_tokens.weight
+ model.norm.weight -> model.norm.weight
+ model.layers.{i}.input_layernorm.weight -> model.decoder.blocks.{i}.input_layernorm.weight
+ model.layers.{i}.post_attention_layernorm.weight -> model.decoder.blocks.{i}.post_attention_layernorm.weight
+ model.layers.{i}.self_attn.q_proj.weight -> model.decoder.blocks.{i}.mixer.q_proj.weight
+ model.layers.{i}.self_attn.q_proj.bias -> model.decoder.blocks.{i}.mixer.q_proj.bias
+ model.layers.{i}.self_attn.k_proj.weight -> model.decoder.blocks.{i}.mixer.k_proj.weight
+ model.layers.{i}.self_attn.k_proj.bias -> model.decoder.blocks.{i}.mixer.k_proj.bias
+ model.layers.{i}.self_attn.v_proj.weight -> model.decoder.blocks.{i}.mixer.v_proj.weight
+ model.layers.{i}.self_attn.v_proj.bias -> model.decoder.blocks.{i}.mixer.v_proj.bias
+ model.layers.{i}.self_attn.o_proj.weight -> model.decoder.blocks.{i}.mixer.o_proj.weight
+ model.layers.{i}.mlp.gate_proj.weight -> model.decoder.blocks.{i}.mlp.gate_proj.weight
+ model.layers.{i}.mlp.up_proj.weight -> model.decoder.blocks.{i}.mlp.up_proj.weight
+ model.layers.{i}.mlp.down_proj.weight -> model.decoder.blocks.{i}.mlp.down_proj.weight
+
+ Note: Qwen2 has QKV biases but no O bias. The Apriel2 config uses per-layer
+ bias settings (query_layer.bias.enabled=True, dense_layer.bias.enabled=False)
+ to match this exactly - no workarounds needed.
+
+ Args:
+ qwen2_config: HuggingFace Qwen2Config as dict
+
+ Returns:
+ ExprPlan with Ref mappings
+ """
+ mappings: dict[str, Expr] = {}
+
+ num_layers = qwen2_config["num_hidden_layers"]
+
+ # Static mappings (embeddings and final norm)
+ # Note: Qwen2 safetensor keys have "model." prefix
+ static_mappings = [
+ (W("model", "embed_tokens", "weight"), W("model", "embed_tokens", "weight")),
+ (W("model", "norm", "weight"), W("model", "norm", "weight")),
+ ]
+
+ # lm_head - only if not tied
+ if not qwen2_config.get("tie_word_embeddings", False):
+ static_mappings.append((W("lm_head", "weight"), W("lm_head", "weight")))
+
+ for src, tgt in static_mappings:
+ mappings[tgt] = Ref(key=src)
+
+ # Layer mappings
+ for layer in range(num_layers):
+ # Source has "model.layers.{i}" prefix
+ qwen_layer = W("model", "layers", layer)
+ apriel_layer = W("model", "decoder", "blocks", layer)
+
+ # Attention projection weights
+ for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]:
+ src = qwen_layer / "self_attn" / proj / "weight"
+ tgt = apriel_layer / "mixer" / proj / "weight"
+ mappings[tgt] = Ref(key=src)
+
+ # QKV biases (Qwen2 has these, but not O bias)
+ for proj in ["q_proj", "k_proj", "v_proj"]:
+ src = qwen_layer / "self_attn" / proj / "bias"
+ tgt = apriel_layer / "mixer" / proj / "bias"
+ mappings[tgt] = Ref(key=src)
+
+ # Note: o_proj has no bias in Qwen2, and Apriel2 config has dense_layer.bias.enabled=False
+
+ # MLP projections
+ for proj in ["gate_proj", "up_proj", "down_proj"]:
+ src = qwen_layer / "mlp" / proj / "weight"
+ tgt = apriel_layer / "mlp" / proj / "weight"
+ mappings[tgt] = Ref(key=src)
+
+ # Layer norms
+ mappings[apriel_layer / "input_layernorm" / "weight"] = Ref(key=qwen_layer / "input_layernorm" / "weight")
+ mappings[apriel_layer / "post_attention_layernorm" / "weight"] = Ref(
+ key=qwen_layer / "post_attention_layernorm" / "weight"
+ )
+
+ return ExprPlan(
+ mappings=mappings,
+ source_format="qwen2",
+ target_format="apriel2",
+ metadata={
+ "num_layers": num_layers,
+ "hidden_size": qwen2_config["hidden_size"],
+ "num_attention_heads": qwen2_config["num_attention_heads"],
+ "num_key_value_heads": qwen2_config.get("num_key_value_heads", qwen2_config["num_attention_heads"]),
+ },
+ )
diff --git a/fast_llm_external_models/apriel2/conversion/render.py b/fast_llm_external_models/apriel2/conversion/render.py
index d71fa03e1..f9a0c8ac1 100644
--- a/fast_llm_external_models/apriel2/conversion/render.py
+++ b/fast_llm_external_models/apriel2/conversion/render.py
@@ -8,17 +8,11 @@
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
+from fast_llm_external_models.apriel2.conversion.expr import Concat, Init, Ref, Reshape, Slice
+
if TYPE_CHECKING:
from fast_llm_external_models.apriel2.conversion.expr import Expr, ExprPlan
-from fast_llm_external_models.apriel2.conversion.expr import (
- Concat,
- Init,
- Ref,
- Reshape,
- Slice,
-)
-
@dataclass
class PlanTreeNode:
@@ -28,10 +22,10 @@ class PlanTreeNode:
After merging, leaf nodes contain aggregated values from multiple siblings.
"""
- children: dict[str, "PlanTreeNode"] = field(default_factory=dict)
+ children: dict[str, PlanTreeNode] = field(default_factory=dict)
# For leaf nodes: list of (sibling_key, expr) pairs
# Before merge: single item, after merge: multiple items from merged siblings
- values: list[tuple[str, "Expr"]] = field(default_factory=list)
+ values: list[tuple[str, Expr]] = field(default_factory=list)
def is_leaf(self) -> bool:
return len(self.children) == 0
@@ -61,7 +55,7 @@ def _build_plan_tree(plan: ExprPlan) -> PlanTreeNode:
return root
-def _expr_signature(expr: "Expr") -> tuple:
+def _expr_signature(expr: Expr) -> tuple:
"""Get a signature for an expression that determines merge compatibility.
Expressions with different signatures should not be merged together.
@@ -453,7 +447,7 @@ def _render_plan_tree(
)
-def _format_aggregated_leaf(values: list[tuple[str, "Expr"]]) -> str:
+def _format_aggregated_leaf(values: list[tuple[str, Expr]]) -> str:
"""Format a leaf with aggregated values using pattern discovery.
Args:
@@ -494,7 +488,7 @@ def _format_aggregated_leaf(values: list[tuple[str, "Expr"]]) -> str:
return _format_single_expr(first_expr)
-def _format_single_expr(expr: "Expr") -> str:
+def _format_single_expr(expr: Expr) -> str:
"""Format a single expression using ML notation."""
match expr:
case Ref(key=key):
@@ -531,7 +525,7 @@ def _format_single_expr(expr: "Expr") -> str:
return f"= {type(expr).__name__}"
-def _format_concat_part(expr: "Expr") -> str:
+def _format_concat_part(expr: Expr) -> str:
"""Format a single part of a concat (for short display)."""
match expr:
case Ref(key=key):
@@ -570,7 +564,7 @@ def _format_slice_notation(slices: tuple) -> str:
return f"[{', '.join(slice_strs)}]"
-def _format_aggregated_concat(values: list[tuple[str, "Expr"]]) -> str:
+def _format_aggregated_concat(values: list[tuple[str, Expr]]) -> str:
"""Format aggregated Concat expressions with pattern discovery."""
# Get the first concat to understand structure
first_concat = values[0][1]
@@ -590,7 +584,7 @@ def _format_aggregated_concat(values: list[tuple[str, "Expr"]]) -> str:
return f"= [{sep.join(formatted_parts)}]"
-def _format_aggregated_concat_part(values: list[tuple[str, "Expr"]]) -> str:
+def _format_aggregated_concat_part(values: list[tuple[str, Expr]]) -> str:
"""Format a single part of an aggregated concat."""
if len(values) == 1:
return _format_concat_part(values[0][1])
@@ -619,7 +613,7 @@ def _format_aggregated_concat_part(values: list[tuple[str, "Expr"]]) -> str:
return _format_concat_part(first_expr)
-def _format_aggregated_slice(values: list[tuple[str, "Expr"]]) -> str:
+def _format_aggregated_slice(values: list[tuple[str, Expr]]) -> str:
"""Format aggregated Slice expressions with pattern discovery."""
first_slice = values[0][1]
if not isinstance(first_slice, Slice):
diff --git a/fast_llm_external_models/apriel2/convert.py b/fast_llm_external_models/apriel2/convert.py
index cbf921b31..66c419dfd 100644
--- a/fast_llm_external_models/apriel2/convert.py
+++ b/fast_llm_external_models/apriel2/convert.py
@@ -15,6 +15,7 @@
Supported source formats:
- llava: Llava/Pixtral models
+- qwen2: Qwen2/Qwen2.5 models
- apriel2: Apriel2 models (surgery-only mode - no conversion, just apply surgeries)
"""
@@ -29,10 +30,7 @@
import yaml
from tqdm import tqdm
-# Allow running as script or module
-if __name__ == "__main__":
- sys.path.insert(0, str(Path(__file__).parent.parent.parent))
-
+# Import source-specific converters
from fast_llm_external_models.apriel2.conversion import (
DEFAULT_MAX_SHARD_SIZE,
ExprPlan,
@@ -41,11 +39,16 @@
StreamingExecutor,
compose,
compose_configs,
- plan_surgery,
)
-
-# Import source-specific converters
from fast_llm_external_models.apriel2.conversion import llava as llava_converter
+from fast_llm_external_models.apriel2.conversion import plan_surgery
+from fast_llm_external_models.apriel2.conversion import qwen2 as qwen2_converter
+from fast_llm_external_models.apriel2.conversion import strip_init_fields
+
+# Allow running as script or module
+if __name__ == "__main__":
+ sys.path.insert(0, str(Path(__file__).parent.parent.parent))
+
logger = logging.getLogger(__name__)
@@ -73,6 +76,7 @@ def _identity_plan(config: dict) -> ExprPlan:
# Each entry maps format name to (config_converter, plan_builder)
SOURCE_FORMATS: dict[str, tuple[Callable[[dict], dict], Callable[[dict], ExprPlan]]] = {
"llava": (llava_converter.convert_config, llava_converter.plan_llava_to_apriel2),
+ "qwen2": (qwen2_converter.convert_config, qwen2_converter.plan_qwen2_to_apriel2),
"apriel2": (_identity_config, _identity_plan),
}
@@ -88,8 +92,12 @@ def detect_source_format(config: dict) -> str | None:
if model_type in ("llava", "pixtral") or "text_config" in config:
return "llava"
+ # Qwen2/Qwen2.5 detection
+ if model_type == "qwen2":
+ return "qwen2"
+
# Apriel2 detection - check for Apriel2-specific structure
- if model_type == "apriel2" or "decoder" in config:
+ if model_type in ("apriel2", "apriel2_text") or "decoder" in config:
return "apriel2"
return None
@@ -142,15 +150,21 @@ def build_plan(
# Apply surgery chain if requested
if surgery_configs:
for i, surgery_config in enumerate(surgery_configs, 1):
- surgery_plan = plan_surgery(current_config, surgery_config)
- logger.info(f"Built surgery plan [{i}/{len(surgery_configs)}]: {surgery_plan.summary()['num_targets']} targets")
+ # S × P → T: compose state with surgery to get transition spec
+ target_config = compose_configs(current_config, surgery_config)
+
+ # S × T → Plan: build plan from source state and transition spec
+ surgery_plan = plan_surgery(current_config, target_config)
+ logger.info(
+ f"Built surgery plan [{i}/{len(surgery_configs)}]: {surgery_plan.summary()['num_targets']} targets"
+ )
- # Compose: current -> surgery
+ # Compose plans
current_plan = compose(current_plan, surgery_plan)
logger.info(f"Composed plan [{i}/{len(surgery_configs)}]: {current_plan.summary()['num_targets']} targets")
- # Compose configs: merge surgery spec into current config
- current_config = compose_configs(current_config, surgery_config)
+ # T → S: strip init for next iteration (init is consumed by plan_surgery)
+ current_config = strip_init_fields(target_config)
return current_plan, current_config
@@ -211,9 +225,7 @@ def convert(
executor = StreamingExecutor(full_plan, loader)
with ShardedSafetensorWriter(output_dir, max_shard_size=max_shard_size) as writer:
- for target_key, tensor in tqdm(
- executor.execute(seed), desc="Converting", total=len(full_plan)
- ):
+ for target_key, tensor in tqdm(executor.execute(seed), desc="Converting", total=len(full_plan)):
writer.add(target_key, tensor)
return final_config
@@ -282,9 +294,7 @@ def resolve_input(input_path: str) -> Path:
def main():
- parser = argparse.ArgumentParser(
- description="Convert HuggingFace checkpoint to Apriel2 HF format"
- )
+ parser = argparse.ArgumentParser(description="Convert HuggingFace checkpoint to Apriel2 HF format")
parser.add_argument(
"input",
type=str,
@@ -384,8 +394,7 @@ def main():
safetensor_files = sorted(input_dir.glob("*.safetensors"))
if not safetensor_files:
raise ValueError(
- f"No safetensor files found in {input_dir}. "
- "Plan-based conversion requires safetensor files."
+ f"No safetensor files found in {input_dir}. " "Plan-based conversion requires safetensor files."
)
# Convert using plan-based approach with streaming sharded output
@@ -400,11 +409,11 @@ def main():
show_plan=args.show_plan or args.verbose,
)
- # Save config
+ # Save config (build_plan returns S which has no init, but strip defensively)
output_config_file = args.output_dir / "config.json"
logger.info(f"Saving config to {output_config_file}")
with open(output_config_file, "w") as f:
- json.dump(apriel2_config, f, indent=2)
+ json.dump(strip_init_fields(apriel2_config), f, indent=2)
# Copy tokenizer files
copy_tokenizer_files(input_dir, args.output_dir)
diff --git a/fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml b/fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml
new file mode 100644
index 000000000..34672916c
--- /dev/null
+++ b/fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml
@@ -0,0 +1,103 @@
+# Dataset preparation config for Tulu 3 SFT mixture with Qwen2 tokenizer
+#
+# This config converts the Tulu 3 SFT dataset (conversation format) to
+# Fast-LLM's memmap format, with automatic loss masking span computation
+# to train only on assistant responses.
+#
+# =============================================================================
+# TOKENIZER SETUP (one-time)
+# =============================================================================
+#
+# The tokenizer must have a chat template with {% generation %} markers.
+# Qwen2's default template doesn't have these, so we need to patch it.
+#
+# IMPORTANT: The entire assistant turn (opening tag + content + closing tag)
+# must be inside the {% generation %} block. This ensures the model learns to
+# produce the full assistant response including special tokens like <|im_end|>.
+# Reference: https://huggingface.co/HuggingFaceTB/SmolLM3-3B/blob/main/chat_template.jinja
+#
+# Run this Python script to create a patched tokenizer:
+#
+# from transformers import AutoTokenizer
+#
+# tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
+#
+# # Patch chat template: wrap ENTIRE assistant turn in generation markers
+# tokenizer.chat_template = '''{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system
+# You are a helpful assistant.<|im_end|>
+# ' }}{% endif %}{% if message['role'] == 'assistant' %}{% generation %}{{ '<|im_start|>assistant
+# ' + message['content'] + '<|im_end|>
+# ' }}{% endgeneration %}{% else %}{{ '<|im_start|>' + message['role'] + '
+# ' + message['content'] + '<|im_end|>
+# ' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant
+# ' }}{% endif %}'''
+#
+# tokenizer.save_pretrained("/path/to/qwen2-instruct-with-markers")
+#
+# =============================================================================
+# DATA PREPARATION
+# =============================================================================
+#
+# Small dataset (for testing):
+#
+# fast-llm prepare gpt_memmap \
+# -c fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml \
+# dataset.split=train[:1000] \
+# output_path=/path/to/tulu3-prepared-small
+#
+# Full dataset (~939K samples, ~6 minutes):
+#
+# fast-llm prepare gpt_memmap \
+# -c fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml
+#
+# =============================================================================
+# VERIFICATION
+# =============================================================================
+#
+# To verify the prepared dataset has loss masking spans:
+#
+# import pathlib
+# from fast_llm.data.dataset.memmap import MemmapDataset
+# from fast_llm.data.sample.language_model import LanguageModelSample
+# from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig
+#
+# dataset = MemmapDataset[LanguageModelSample](
+# 'tulu3',
+# pathlib.Path('/path/to/tulu3-prepared/shard_0_0.fast_llm_dataset'),
+# LanguageModelPreprocessingConfig(use_loss_masking_spans=True)
+# )
+#
+# doc = dataset.get_document(0)
+# print(f'Tokens: {len(doc.tokens.tokens)}')
+# print(f'Loss masking spans: {doc.loss_masking_spans.ranges}')
+#
+# =============================================================================
+
+# Dataset configuration
+dataset:
+ # Tulu 3 SFT mixture from AllenAI
+ path: allenai/tulu-3-sft-mixture
+ split: train
+
+ # Source schema for conversation format
+ source_schema:
+ # Use conversation type (vs default "document" type)
+ type: conversation
+
+ # Column containing the messages list
+ messages: messages
+
+# Tokenizer configuration
+# IMPORTANT: Must use a tokenizer with {% generation %} markers in its chat template.
+# See instructions above to create a patched Qwen2 tokenizer.
+tokenizer:
+ path: /path/to/qwen2-instruct-with-markers
+ # Qwen2 doesn't have a BOS token by default, use <|endoftext|> as BOS
+ bos_token: "<|endoftext|>"
+
+# Output configuration
+output_path: /path/to/tulu3-prepared
+
+# Processing configuration
+num_workers: 8
+documents_per_shard: 100000
diff --git a/fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml b/fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml
new file mode 100644
index 000000000..5b190955f
--- /dev/null
+++ b/fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml
@@ -0,0 +1,193 @@
+# Training config for Qwen2-based Apriel2 stochastic supernet on Tulu 3 SFT data
+#
+# This config trains a stochastic supernet where each layer can sample from
+# multiple mixer types (attention, sliding window, gated delta net, KDA).
+# Only the mixer weights are trained; all other weights are frozen.
+# Activation-level distillation from a teacher model guides the training.
+#
+# =============================================================================
+# PREREQUISITES
+# =============================================================================
+#
+# 1. TOKENIZER SETUP
+#
+# Qwen2's default chat template doesn't have generation markers needed for
+# loss masking. Create a patched tokenizer following the SmolLM3 pattern:
+# https://huggingface.co/HuggingFaceTB/SmolLM3-3B/blob/main/chat_template.jinja
+#
+# IMPORTANT: The ENTIRE assistant turn (opening tag + content + closing tag)
+# must be inside {% generation %}...{% endgeneration %} markers.
+#
+# from transformers import AutoTokenizer
+# tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
+# # Wrap entire assistant turn in generation markers (NOT just content!)
+# tokenizer.chat_template = '''{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system
+# You are a helpful assistant.<|im_end|>
+# ' }}{% endif %}{% if message['role'] == 'assistant' %}{% generation %}{{ '<|im_start|>assistant
+# ' + message['content'] + '<|im_end|>
+# ' }}{% endgeneration %}{% else %}{{ '<|im_start|>' + message['role'] + '
+# ' + message['content'] + '<|im_end|>
+# ' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant
+# ' }}{% endif %}'''
+# tokenizer.save_pretrained("/path/to/qwen2-instruct-with-markers")
+#
+# 2. PREPARE TULU 3 DATASET
+#
+# Small dataset (for testing):
+#
+# fast-llm prepare gpt_memmap \
+# -c fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml \
+# tokenizer.path=/path/to/qwen2-instruct-with-markers \
+# dataset.split=train[:1000] \
+# output_path=/path/to/tulu3-prepared-small
+#
+# Full dataset (~939K samples, ~6 minutes):
+#
+# fast-llm prepare gpt_memmap \
+# -c fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml \
+# tokenizer.path=/path/to/qwen2-instruct-with-markers \
+# output_path=/path/to/tulu3-prepared
+#
+# 3. CONVERT QWEN2 TO APRIEL2 SUPERNET (student model)
+#
+# This creates a stochastic supernet with multiple mixer types per layer:
+#
+# python fast_llm_external_models/apriel2/convert.py \
+# Qwen/Qwen2.5-0.5B-Instruct \
+# /path/to/qwen2-supernet \
+# --surgery fast_llm_external_models/apriel2/examples/stochastic_supernet.yaml
+#
+# 4. CONVERT QWEN2 TO APRIEL2 (teacher model)
+#
+# The teacher is the original model without surgery, used for distillation:
+#
+# python fast_llm_external_models/apriel2/convert.py \
+# Qwen/Qwen2.5-0.5B-Instruct \
+# /path/to/qwen2-teacher
+#
+# 5. RUN TRAINING
+#
+# Update paths below and run:
+#
+# fast-llm train gpt \
+# -c fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml
+#
+# For long runs, use nohup:
+#
+# nohup fast-llm train gpt \
+# -c fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml \
+# > training.log 2>&1 &
+# tail -f training.log
+#
+# =============================================================================
+# PERFORMANCE TUNING
+# =============================================================================
+#
+# Default config uses seq=4096, micro_batch=2, batch=16 which gives:
+# - ~8k tokens/s/gpu throughput
+# - ~61GB GPU memory usage
+# - ~25 hours for 1B tokens on single GPU
+#
+# Adjust batch settings based on your GPU memory:
+# - Reduce micro_batch_size if OOM
+# - Increase micro_batch_size/batch_size if memory available
+#
+# =============================================================================
+# OUTPUT
+# =============================================================================
+#
+# Checkpoints: /path/to/qwen2-supernet-trained/checkpoints/{iteration}/
+# Exports: /path/to/qwen2-supernet-trained/export/apriel2_text/{iteration}/
+#
+# =============================================================================
+
+# Load pretrained model (Qwen2 converted to Apriel2 supernet)
+pretrained:
+ path: /path/to/qwen2-supernet
+ format: apriel2_text
+ model_weights: true
+ load_config: model
+
+# Model config
+model:
+ base_model:
+ # Freeze all components except the mixer
+ decoder:
+ block:
+ mlp:
+ lr_scale: 0.0 # Freeze MLP
+ normalization:
+ lr_scale: 0.0 # Freeze layer norms
+ # Activation-level distillation from teacher
+ distillation_model: teacher
+ activation_distillation_factor: 0.8
+ embeddings:
+ lr_scale: 0.0 # Freeze word embeddings
+ head:
+ lr_scale: 0.0 # Freeze output head
+ cross_entropy_implementation: torch
+ multi_stage:
+ zero_stage: 2
+ distributed:
+ compute_dtype: bf16
+ seed: 42
+
+# Teacher model for activation-level distillation
+reference_models:
+ teacher:
+ model:
+ type: gpt
+ pretrained:
+ path: /path/to/qwen2-teacher
+ format: apriel2_text
+ model_weights: true
+ load_config: model
+
+# Batch configuration (tuned for ~61GB GPU memory, ~8k tokens/s)
+batch:
+ sequence_length: 4096
+ micro_batch_size: 2
+ batch_size: 16
+
+# Data configuration (prepared Tulu 3 dataset)
+data:
+ datasets:
+ training:
+ type: file
+ path: /path/to/tulu3-prepared/fast_llm_config.yaml
+
+# Optimizer configuration
+optimizer:
+ learning_rate:
+ base: 1.0e-05
+ decay_style: cosine
+ warmup_iterations: 100
+ decay_iterations: 10000
+ minimum: 1.0e-06
+ weight_decay: 0.1
+ beta_1: 0.9
+ beta_2: 0.95
+
+# Training configuration
+# At seq=4096, batch=16: ~65k tokens/iter, ~280 iters/hour
+# 10000 iters ≈ 650M tokens ≈ 35 hours
+training:
+ train_iters: 10000
+ num_workers: 4
+ logs:
+ interval: 10
+ checkpoint:
+ interval: 280 # ~hourly
+ export:
+ interval: 280 # ~hourly (useful for development/testing during training)
+ format: apriel2_text
+ test_iters: 0
+ evaluators: {}
+ # Weights & Biases configuration (optional, uncomment to enable)
+ # wandb:
+ # entity_name: your-entity
+ # project_name: your-project
+
+# Experiment directory
+run:
+ experiment_dir: /path/to/qwen2-supernet-trained
diff --git a/fast_llm_external_models/apriel2/examples/train_supernet_small.yaml b/fast_llm_external_models/apriel2/examples/train_supernet_small.yaml
index 78c22e57f..be4d06e0a 100644
--- a/fast_llm_external_models/apriel2/examples/train_supernet_small.yaml
+++ b/fast_llm_external_models/apriel2/examples/train_supernet_small.yaml
@@ -107,7 +107,7 @@ model:
lr_scale: 0.0 # Freeze layer norms (norm_1 and norm_2 in each block)
# Activation-level distillation: teach mixers to mimic teacher's attention outputs
distillation_model: teacher
- activation_distillation_factor: 0.1
+ activation_distillation_factor: 0.8
embeddings:
lr_scale: 0.0 # Freeze word embeddings
head:
diff --git a/fast_llm_external_models/apriel2/modeling_apriel2.py b/fast_llm_external_models/apriel2/modeling_apriel2.py
index 4c263b4e2..240240cd6 100644
--- a/fast_llm_external_models/apriel2/modeling_apriel2.py
+++ b/fast_llm_external_models/apriel2/modeling_apriel2.py
@@ -24,8 +24,8 @@
is_torch_flex_attn_available,
)
-from fast_llm_external_models.apriel2.cache import Apriel2Cache
-from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config, Apriel2TextConfig
+from .cache import Apriel2Cache
+from .configuration_apriel2 import Apriel2Config, Apriel2TextConfig
# GDN implementation - matches Fast-LLM's gdn.py exactly
try:
@@ -395,14 +395,30 @@ def __init__(self, d_model: int, mixer_config: dict, layer_idx: int, config):
# cross_document_attention: if False, use cu_seqlens to isolate sequences (e.g., images)
self.cross_document_attention = mixer_config.get("cross_document_attention", True)
- # Whether to add biases to linear projections
- add_bias = mixer_config.get("add_linear_biases", False)
-
- # Projections (Fast-LLM weight names: q_proj, k_proj, v_proj, o_proj)
- self.q_proj = nn.Linear(d_model, self.num_heads * self.head_dim, bias=add_bias)
- self.k_proj = nn.Linear(d_model, self.num_key_value_heads * self.head_dim, bias=add_bias)
- self.v_proj = nn.Linear(d_model, self.num_key_value_heads * self.head_dim, bias=add_bias)
- self.o_proj = nn.Linear(self.num_heads * self.head_dim, d_model, bias=add_bias)
+ # Bias configuration mirroring Fast-LLM's structure:
+ # - add_linear_biases: bool (default for all projections)
+ # - query_layer: {"bias": {"enabled": bool}} (per-layer override)
+ # - key_layer: {"bias": {"enabled": bool}}
+ # - value_layer: {"bias": {"enabled": bool}}
+ # - dense_layer: {"bias": {"enabled": bool}}
+ default_bias = mixer_config.get("add_linear_biases", False)
+
+ def get_layer_bias(layer_name: str) -> bool:
+ layer_cfg = mixer_config.get(layer_name, {})
+ bias_cfg = layer_cfg.get("bias", {})
+ enabled = bias_cfg.get("enabled")
+ return default_bias if enabled is None else enabled
+
+ q_bias = get_layer_bias("query_layer")
+ k_bias = get_layer_bias("key_layer")
+ v_bias = get_layer_bias("value_layer")
+ o_bias = get_layer_bias("dense_layer")
+
+ # Projections
+ self.q_proj = nn.Linear(d_model, self.num_heads * self.head_dim, bias=q_bias)
+ self.k_proj = nn.Linear(d_model, self.num_key_value_heads * self.head_dim, bias=k_bias)
+ self.v_proj = nn.Linear(d_model, self.num_key_value_heads * self.head_dim, bias=v_bias)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, d_model, bias=o_bias)
@classmethod
def setup(
@@ -1017,6 +1033,8 @@ def torch_chunk_gated_delta_rule(
if not output_final_state:
last_recurrent_state = None
+ elif last_recurrent_state is not None:
+ last_recurrent_state = last_recurrent_state.to(initial_dtype)
core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1])
core_attn_out = core_attn_out[:, :, :sequence_length]
core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
@@ -1225,7 +1243,9 @@ def forward(self, hidden_states: torch.Tensor, past_key_values=None, attention_m
mixed_qkv = self.convolution.update(
mixed_qkv.squeeze(2), # [batch, conv_dim, 1] -> [batch, conv_dim]
conv_state,
- ).unsqueeze(2) # [batch, conv_dim] -> [batch, conv_dim, 1]
+ ).unsqueeze(
+ 2
+ ) # [batch, conv_dim] -> [batch, conv_dim, 1]
else:
# Prefill mode
use_cache = past_key_values is not None
@@ -1270,8 +1290,14 @@ def forward(self, hidden_states: torch.Tensor, past_key_values=None, attention_m
output_final_state=past_key_values is not None,
use_qk_l2norm_in_kernel=True,
)
+ # Ensure state is in same dtype as hidden_states (fla kernel may return float32)
+ if last_recurrent_state is not None:
+ last_recurrent_state = last_recurrent_state.to(hidden_states.dtype)
else:
# Recurrent mode for single token decode
+ # Convert recurrent_state to match hidden_states dtype if needed
+ if recurrent_state is not None and recurrent_state.dtype != hidden_states.dtype:
+ recurrent_state = recurrent_state.to(hidden_states.dtype)
output, last_recurrent_state = self._recurrent_gated_delta_rule(
query, key, value, g, beta_gate, recurrent_state
)
@@ -1294,7 +1320,16 @@ def forward(self, hidden_states: torch.Tensor, past_key_values=None, attention_m
return (output,)
def _recurrent_gated_delta_rule(self, query, key, value, g, beta, state):
- """Single-step recurrent update for cached inference."""
+ """Single-step recurrent update for cached inference.
+
+ Input shapes: [batch, seq=1, heads, dim]
+ Need shapes: [batch, heads, dim] for einsum operations
+ """
+ # Transpose from [batch, seq, heads, dim] to [batch, heads, seq, dim]
+ query = query.transpose(1, 2)
+ key = key.transpose(1, 2)
+ value = value.transpose(1, 2)
+
# L2 normalize query and key
query = _l2norm(query, dim=-1, eps=1e-6)
key = _l2norm(key, dim=-1, eps=1e-6)
@@ -1307,7 +1342,9 @@ def _recurrent_gated_delta_rule(self, query, key, value, g, beta, state):
beta = beta.squeeze(1)
# Update state: S = exp(g) * S + beta * k^T @ v
- decay = g.exp().unsqueeze(-1).unsqueeze(-1) # [batch, heads, 1, 1]
+ # Keep everything in the same dtype as input (exp() returns float32, need to convert back)
+ input_dtype = query.dtype
+ decay = g.exp().to(input_dtype).unsqueeze(-1).unsqueeze(-1) # [batch, heads, 1, 1]
k_outer_v = torch.einsum("bhk,bhv->bhkv", key * beta.unsqueeze(-1), value)
state = decay * state + k_outer_v
@@ -1315,6 +1352,12 @@ def _recurrent_gated_delta_rule(self, query, key, value, g, beta, state):
output = torch.einsum("bhk,bhkv->bhv", query, state)
output = output.unsqueeze(2) # [batch, heads, 1, v_dim]
+ # Transpose back to [batch, seq=1, heads, v_dim]
+ output = output.transpose(1, 2)
+
+ # Ensure state matches output dtype
+ state = state.to(output.dtype)
+
return output, state
@classmethod
@@ -1447,9 +1490,7 @@ def __init__(
# Normalization - use GatedRMSNormalization (same wrapper as GDN, with sigmoid activation)
self.norm = GatedRMSNormalization(self.head_dim, eps=self.norm_eps, activation=self.norm_activation)
- def _apply_conv(
- self, x: torch.Tensor, conv: CausalConv1d, conv_state: torch.Tensor | None, use_cache: bool
- ):
+ def _apply_conv(self, x: torch.Tensor, conv: CausalConv1d, conv_state: torch.Tensor | None, use_cache: bool):
"""
Apply causal convolution with cache support.
@@ -1828,16 +1869,36 @@ def __init__(
self.post_attention_layernorm = self._create_norm(norm_config, hidden_size, rms_norm_eps)
def _create_mlp(self, mlp_config: dict, hidden_size: int):
- """Create MLP based on config."""
+ """Create MLP based on config.
+
+ Supports per-layer bias configuration mirroring Fast-LLM:
+ - add_linear_biases: default bias setting for all layers
+ - layer_1.bias.enabled: override for up_proj/gate_proj
+ - layer_2.bias.enabled: override for down_proj
+ """
mlp_type = mlp_config.get("type", "mlp")
if mlp_type == "mlp":
intermediate_size = mlp_config["intermediate_size"]
activation = mlp_config.get("activation", "silu")
- gated = mlp_config["gated"]
- bias = mlp_config.get("add_linear_biases", False)
+ gated = mlp_config.get("gated", False)
+
+ # Per-layer bias configuration (mirrors Fast-LLM structure)
+ default_bias = mlp_config.get("add_linear_biases", False)
+
+ def get_layer_bias(layer_name: str) -> bool:
+ layer_cfg = mlp_config.get(layer_name, {})
+ bias_cfg = layer_cfg.get("bias", {})
+ enabled = bias_cfg.get("enabled")
+ return default_bias if enabled is None else enabled
+
+ layer_1_bias = get_layer_bias("layer_1")
+ layer_2_bias = get_layer_bias("layer_2")
if gated:
+ # MistralMLP uses gate_proj, up_proj, down_proj (all bias controlled together)
+ # For now, we use the default bias setting for gated MLPs
+ # TODO: Add per-layer bias support to gated MLP
mlp_cfg = SimpleNamespace(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
@@ -1845,7 +1906,13 @@ def _create_mlp(self, mlp_config: dict, hidden_size: int):
)
return MistralMLP(mlp_cfg)
else:
- return SimpleMLP(hidden_size, intermediate_size, activation, bias)
+ return SimpleMLP(
+ hidden_size,
+ intermediate_size,
+ activation,
+ layer_1_bias=layer_1_bias,
+ layer_2_bias=layer_2_bias,
+ )
else:
raise ValueError(f"Unknown MLP type: {mlp_type}")
@@ -2179,6 +2246,8 @@ def forward(
class Apriel2ForCausalLM(Apriel2PreTrainedModel, GenerationMixin):
"""Apriel2 model with a language modeling head (text-only)."""
+ _tied_weights_keys = ["lm_head.weight"]
+
def __init__(self, config: Apriel2TextConfig):
super().__init__(config)
self.model = Apriel2TextModel(config)
@@ -2186,6 +2255,7 @@ def __init__(self, config: Apriel2TextConfig):
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
+ # post_init() calls init_weights() which calls tie_weights() if config.tie_word_embeddings
self.post_init()
def get_input_embeddings(self):
@@ -2583,14 +2653,26 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
class SimpleMLP(nn.Module):
- """Non-gated MLP: up_proj -> activation -> down_proj."""
+ """Non-gated MLP: up_proj -> activation -> down_proj.
+
+ Supports per-layer bias configuration mirroring Fast-LLM:
+ - layer_1_bias: bias for up_proj (layer_1 in Fast-LLM naming)
+ - layer_2_bias: bias for down_proj (layer_2 in Fast-LLM naming)
+ """
- def __init__(self, hidden_size: int, intermediate_size: int, activation: str = "silu", bias: bool = False):
+ def __init__(
+ self,
+ hidden_size: int,
+ intermediate_size: int,
+ activation: str = "silu",
+ layer_1_bias: bool = False,
+ layer_2_bias: bool = False,
+ ):
super().__init__()
from transformers.activations import ACT2FN
- self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
- self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=bias)
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=layer_1_bias)
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=layer_2_bias)
self.act_fn = ACT2FN[activation]
def forward(self, x):
diff --git a/fast_llm_external_models/tests/test_apriel2/conftest.py b/fast_llm_external_models/tests/test_apriel2/conftest.py
index 8585aec65..21b90b097 100644
--- a/fast_llm_external_models/tests/test_apriel2/conftest.py
+++ b/fast_llm_external_models/tests/test_apriel2/conftest.py
@@ -1,23 +1,44 @@
"""Test fixtures for Apriel2 model tests."""
+from collections.abc import Generator
from pathlib import Path
-from typing import Generator
import pytest
import torch
from transformers import LlavaConfig, LlavaForConditionalGeneration, MistralConfig
+from fast_llm_external_models.apriel2.cache import _AttentionCache, _SSMCache
+
+
+# Register custom marks
+def pytest_configure(config):
+ config.addinivalue_line("markers", "slow: marks tests as slow (deselect with '-m \"not slow\"')")
+
+
+def _can_import_fast_llm():
+ """Check if Fast-LLM is available."""
+ try:
+ return True
+ except ImportError:
+ return False
+
# Skip marker for tests that require CUDA for Mamba forward pass
requires_cuda = pytest.mark.skipif(
- not torch.cuda.is_available(),
- reason="SSM mixers (Mamba) require CUDA for forward pass"
+ not torch.cuda.is_available(), reason="SSM mixers (Mamba) require CUDA for forward pass"
)
+# Skip marker for tests that require Fast-LLM
+requires_fastllm = pytest.mark.skipif(not _can_import_fast_llm(), reason="Fast-LLM not available")
-@pytest.fixture(autouse=True)
+
+@pytest.fixture(scope="module", autouse=True)
def set_default_device():
- """Set default device to CUDA for all tests (Mamba requires CUDA)."""
+ """Set default device to CUDA for all tests (Mamba requires CUDA).
+
+ Module-scoped to ensure it runs before any module-scoped fixtures
+ that load models (e.g., qwen2_model_and_tokenizer).
+ """
if torch.cuda.is_available():
old_device = torch.get_default_device()
torch.set_default_device("cuda")
@@ -27,9 +48,12 @@ def set_default_device():
yield
-@pytest.fixture(autouse=True)
+@pytest.fixture(scope="module", autouse=True)
def set_default_dtype():
- """Set default dtype to float32 for numerical comparison tests."""
+ """Set default dtype to float32 for numerical comparison tests.
+
+ Module-scoped to ensure it runs before any module-scoped fixtures.
+ """
old_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.float32)
yield
@@ -135,14 +159,11 @@ def model_pair(request, small_pixtral_model, tmp_path):
tuple: (source_model, target_model, expected_atol, variant_name)
"""
import json
+
from safetensors import safe_open
from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
- from fast_llm_external_models.apriel2.conversion import (
- convert_llava_config,
- execute,
- plan_llava_to_apriel2,
- )
+ from fast_llm_external_models.apriel2.conversion import convert_llava_config, execute, plan_llava_to_apriel2
from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForConditionalGeneration
source = small_pixtral_model
@@ -638,12 +659,12 @@ def apriel2_config_comprehensive():
"type": "pattern",
"num_blocks": 6,
"pattern": [
- "attn", # 0: pure full attention
- "swa", # 1: pure sliding window attention
- "mamba", # 2: pure mamba
- "gdn", # 3: pure gated delta net
- "stoch_attn_mamba", # 4: stochastic attention + mamba
- "stoch_swa_gdn", # 5: stochastic swa + gated delta net
+ "attn", # 0: pure full attention
+ "swa", # 1: pure sliding window attention
+ "mamba", # 2: pure mamba
+ "gdn", # 3: pure gated delta net
+ "stoch_attn_mamba", # 4: stochastic attention + mamba
+ "stoch_swa_gdn", # 5: stochastic swa + gated delta net
],
"blocks": {
"attn": {
@@ -761,6 +782,52 @@ def apriel2_config_comprehensive():
)
+@pytest.fixture
+def apriel2_config_with_bias():
+ """Apriel2 config with Qwen-style per-layer bias and non-gated MLP.
+
+ This config exercises:
+ - Per-layer attention bias (QKV bias enabled, O bias disabled)
+ - Non-gated MLP with per-layer bias (layer_1 enabled, layer_2 disabled)
+ - Config structure parity with Fast-LLM's AffineLinearConfig
+
+ Critical for testing bias inheritance through surgery operations.
+ """
+ from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
+
+ return Apriel2Config(
+ vocab_size=100,
+ hidden_size=64,
+ decoder={
+ "type": "fixed",
+ "num_blocks": 2,
+ "block": {
+ "mixer": {
+ "type": "attention",
+ "heads": 4,
+ "head_groups": 2,
+ "head_size": 16,
+ "rotary": {"type": "mistral_1d", "theta": 10000.0},
+ # Qwen-style: QKV bias enabled, O bias disabled
+ "query_layer": {"bias": {"enabled": True}},
+ "key_layer": {"bias": {"enabled": True}},
+ "value_layer": {"bias": {"enabled": True}},
+ "dense_layer": {"bias": {"enabled": False}},
+ },
+ "mlp": {
+ "type": "mlp",
+ "intermediate_size": 256,
+ "gated": False, # Non-gated MLP (SimpleMLP)
+ # Per-layer MLP bias
+ "layer_1": {"bias": {"enabled": True}},
+ "layer_2": {"bias": {"enabled": False}},
+ },
+ "normalization": {"type": "rms_norm", "epsilon": 1e-5},
+ },
+ },
+ )
+
+
@pytest.fixture
def apriel2_cache(apriel2_config_tiny):
"""Create empty Apriel2Cache from tiny config."""
@@ -863,6 +930,77 @@ def additive_surgery_chain():
]
+@pytest.fixture
+def bias_surgery_chain():
+ """Surgery chain that exercises bias inheritance through surgery operations.
+
+ Designed to be used with apriel2_config_with_bias as the source config.
+ Tests that per-layer bias settings (Qwen-style QKV bias, non-gated MLP bias)
+ are correctly inherited through:
+ - Stochastic wrapper creation
+ - Adding new sub-mixers that inherit from source
+ - Cross-type derivation (attention → sliding_window)
+
+ Source config has:
+ - Attention: query/key/value bias enabled, dense bias disabled
+ - MLP: layer_1 bias enabled, layer_2 bias disabled (non-gated)
+ """
+ return [
+ # S1: Wrap in stochastic - bias should transfer to attention sub-mixer
+ {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "type": "stochastic",
+ "main_mixer_name": "attention",
+ "mixers": {
+ "attention": {"init": "transfer"},
+ },
+ },
+ "mlp": {"init": "transfer"},
+ "normalization": {"init": "transfer"},
+ },
+ },
+ },
+ # S2: Add sliding_window - should inherit bias from source attention
+ {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "mixers": {
+ "sliding_window": {
+ "type": "attention",
+ "init": "transfer",
+ "window_size": 512,
+ },
+ },
+ },
+ },
+ },
+ },
+ # S3: Add new attention with DIFFERENT bias config (random init)
+ {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "mixers": {
+ "full_bias_attn": {
+ "type": "attention",
+ "init": "random",
+ "heads": 4,
+ "head_groups": 2,
+ "head_size": 16,
+ "rotary": {"type": "mistral_1d", "theta": 10000.0},
+ "add_linear_biases": True, # All biases enabled
+ },
+ },
+ },
+ },
+ },
+ },
+ ]
+
+
@pytest.fixture
def comprehensive_torture_chain():
"""Comprehensive torture chain exercising ALL conversion paths.
@@ -885,7 +1023,7 @@ def comprehensive_torture_chain():
# MIL requires: d_inner <= Q rows (256), d_xb <= K/V rows (128)
mamba_params = {
"d_inner": 256, # Must be <= heads*head_size = 256
- "d_xb": 64, # Must be <= head_groups*head_size = 128
+ "d_xb": 64, # Must be <= head_groups*head_size = 128
"dt_rank": 16,
"d_state": 16,
"d_conv": 4,
@@ -1532,3 +1670,330 @@ def torture_surgery_chain():
},
},
]
+
+
+# =============================================================================
+# Shared Config Dict Fixtures (for compose_configs / plan_surgery tests)
+# =============================================================================
+
+
+@pytest.fixture
+def base_config_dict():
+ """Complete Apriel2 config dict without biases (Llama-style).
+
+ Use this as the base config for testing compose_configs and plan_surgery.
+ Returns a dict (not Apriel2Config) for direct use with compose_configs.
+ """
+ return {
+ "model_type": "apriel2",
+ "hidden_size": 256,
+ "vocab_size": 1000,
+ "bos_token_id": 1,
+ "eos_token_id": 2,
+ "tie_word_embeddings": False,
+ "decoder": {
+ "type": "fixed",
+ "num_blocks": 2,
+ "block": {
+ "mixer": {
+ "type": "attention",
+ "heads": 8,
+ "head_groups": 4,
+ "head_size": 32,
+ "rotary": {"type": "mistral_1d", "theta": 10000.0},
+ },
+ "mlp": {"type": "mlp", "intermediate_size": 512, "gated": True},
+ "normalization": {"type": "rms_norm", "epsilon": 1e-5},
+ },
+ },
+ }
+
+
+@pytest.fixture
+def base_config_with_bias_dict():
+ """Complete Apriel2 config dict with Qwen-style biases.
+
+ - QKV bias enabled, O bias disabled
+ - Gated MLP (no per-layer bias control in this style)
+
+ Use this for testing bias inheritance through surgery operations.
+ Returns a dict (not Apriel2Config) for direct use with compose_configs.
+ """
+ return {
+ "model_type": "apriel2",
+ "hidden_size": 256,
+ "vocab_size": 1000,
+ "bos_token_id": 1,
+ "eos_token_id": 2,
+ "tie_word_embeddings": False,
+ "decoder": {
+ "type": "fixed",
+ "num_blocks": 2,
+ "block": {
+ "mixer": {
+ "type": "attention",
+ "heads": 8,
+ "head_groups": 4,
+ "head_size": 32,
+ "rotary": {"type": "mistral_1d", "theta": 10000.0},
+ "query_layer": {"bias": {"enabled": True}},
+ "key_layer": {"bias": {"enabled": True}},
+ "value_layer": {"bias": {"enabled": True}},
+ "dense_layer": {"bias": {"enabled": False}},
+ },
+ "mlp": {"type": "mlp", "intermediate_size": 512, "gated": True},
+ "normalization": {"type": "rms_norm", "epsilon": 1e-5},
+ },
+ },
+ }
+
+
+def make_weights_for_config(config: dict) -> dict:
+ """Create random weights matching a config's expected schema.
+
+ This is a helper function (not a fixture) for creating test weights.
+ Use it in tests that need weights for plan execution.
+
+ Args:
+ config: Complete Apriel2 config dict
+
+ Returns:
+ Dict mapping weight key strings to torch tensors
+ """
+ from fast_llm_external_models.apriel2.conversion import W
+
+ hidden = config["hidden_size"]
+ vocab = config["vocab_size"]
+ decoder = config["decoder"]
+ num_blocks = decoder["num_blocks"]
+ block = decoder["block"]
+ mixer = block["mixer"]
+ mlp = block["mlp"]
+
+ heads = mixer["heads"]
+ head_groups = mixer["head_groups"]
+ head_size = mixer["head_size"]
+ inter = mlp["intermediate_size"]
+
+ # Check bias settings
+ has_q_bias = mixer.get("query_layer", {}).get("bias", {}).get("enabled", False)
+ has_k_bias = mixer.get("key_layer", {}).get("bias", {}).get("enabled", False)
+ has_v_bias = mixer.get("value_layer", {}).get("bias", {}).get("enabled", False)
+
+ weights = {}
+ weights["model.embed_tokens.weight"] = torch.randn(vocab, hidden)
+
+ for i in range(num_blocks):
+ p = f"model.decoder.blocks.{i}"
+
+ # Attention
+ weights[f"{p}.mixer.q_proj.weight"] = torch.randn(heads * head_size, hidden)
+ weights[f"{p}.mixer.k_proj.weight"] = torch.randn(head_groups * head_size, hidden)
+ weights[f"{p}.mixer.v_proj.weight"] = torch.randn(head_groups * head_size, hidden)
+ weights[f"{p}.mixer.o_proj.weight"] = torch.randn(hidden, heads * head_size)
+
+ if has_q_bias:
+ weights[f"{p}.mixer.q_proj.bias"] = torch.randn(heads * head_size)
+ if has_k_bias:
+ weights[f"{p}.mixer.k_proj.bias"] = torch.randn(head_groups * head_size)
+ if has_v_bias:
+ weights[f"{p}.mixer.v_proj.bias"] = torch.randn(head_groups * head_size)
+
+ # MLP
+ weights[f"{p}.mlp.up_proj.weight"] = torch.randn(inter, hidden)
+ weights[f"{p}.mlp.gate_proj.weight"] = torch.randn(inter, hidden)
+ weights[f"{p}.mlp.down_proj.weight"] = torch.randn(hidden, inter)
+
+ # Norms
+ weights[f"{p}.input_layernorm.weight"] = torch.randn(hidden)
+ weights[f"{p}.post_attention_layernorm.weight"] = torch.randn(hidden)
+
+ weights["model.norm.weight"] = torch.randn(hidden)
+ weights["lm_head.weight"] = torch.randn(vocab, hidden)
+
+ return {W(k): v for k, v in weights.items()}
+
+
+# =============================================================================
+# Cache Test Fixtures - Tensor Dimensions
+# =============================================================================
+
+
+@pytest.fixture
+def batch_size():
+ """Default batch size for cache tests."""
+ return 2
+
+
+@pytest.fixture
+def num_heads():
+ """Default number of attention heads for cache tests."""
+ return 4
+
+
+@pytest.fixture
+def head_dim():
+ """Default head dimension for cache tests."""
+ return 16
+
+
+@pytest.fixture
+def make_kv(batch_size, num_heads, head_dim):
+ """Factory fixture for creating KV tensors."""
+
+ def _make_kv(seq_len):
+ return (
+ torch.randn(batch_size, num_heads, seq_len, head_dim),
+ torch.randn(batch_size, num_heads, seq_len, head_dim),
+ )
+
+ return _make_kv
+
+
+# =============================================================================
+# Cache Test Fixtures - HuggingFace Cache Layers
+# =============================================================================
+
+
+@pytest.fixture
+def hf_dynamic_layer():
+ """HuggingFace DynamicLayer for full attention contract testing."""
+ from transformers.cache_utils import DynamicLayer
+
+ return DynamicLayer()
+
+
+@pytest.fixture
+def hf_sliding_layer(window_size):
+ """HuggingFace DynamicSlidingWindowLayer for sliding window contract testing."""
+ from transformers.cache_utils import DynamicSlidingWindowLayer
+
+ return DynamicSlidingWindowLayer(sliding_window=window_size)
+
+
+# =============================================================================
+# Cache Test Fixtures - Apriel2 Low-level Caches
+# =============================================================================
+
+
+@pytest.fixture
+def apriel_attention_cache():
+ """Apriel2 attention cache without window (full attention)."""
+ return _AttentionCache(window=None)
+
+
+@pytest.fixture
+def apriel_sliding_cache(window_size):
+ """Apriel2 attention cache with sliding window."""
+ return _AttentionCache(window=window_size)
+
+
+@pytest.fixture
+def ssm_cache():
+ """Apriel2 SSM cache for Mamba/GDN/KDA layers."""
+ return _SSMCache()
+
+
+# =============================================================================
+# Cache Test Fixtures - Apriel2 Configs (Simple Versions)
+# =============================================================================
+
+
+@pytest.fixture
+def attention_config():
+ """Pure attention config (2 layers, no sliding window)."""
+ from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
+
+ return Apriel2Config(
+ vocab_size=100,
+ hidden_size=64,
+ decoder={
+ "type": "fixed",
+ "num_blocks": 2,
+ "block": {
+ "mixer": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16},
+ "mlp": {"type": "mlp", "intermediate_size": 256},
+ "normalization": {"type": "rms_norm", "epsilon": 1e-5},
+ },
+ },
+ )
+
+
+@pytest.fixture
+def swa_config():
+ """Sliding window attention config (2 layers, window=8)."""
+ from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
+
+ return Apriel2Config(
+ vocab_size=100,
+ hidden_size=64,
+ decoder={
+ "type": "fixed",
+ "num_blocks": 2,
+ "block": {
+ "mixer": {
+ "type": "attention",
+ "heads": 4,
+ "head_groups": 2,
+ "head_size": 16,
+ "window_size": 8,
+ },
+ "mlp": {"type": "mlp", "intermediate_size": 256},
+ "normalization": {"type": "rms_norm", "epsilon": 1e-5},
+ },
+ },
+ )
+
+
+@pytest.fixture
+def ssm_config():
+ """Pure SSM config (2 layers)."""
+ from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
+
+ return Apriel2Config(
+ vocab_size=100,
+ hidden_size=64,
+ decoder={
+ "type": "fixed",
+ "num_blocks": 2,
+ "block": {
+ "mixer": {"type": "mamba", "state_size": 16},
+ "mlp": {"type": "mlp", "intermediate_size": 256},
+ "normalization": {"type": "rms_norm", "epsilon": 1e-5},
+ },
+ },
+ )
+
+
+@pytest.fixture
+def stochastic_config():
+ """Stochastic mixer config with attention and mamba (2 layers)."""
+ from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
+
+ return Apriel2Config(
+ vocab_size=100,
+ hidden_size=64,
+ decoder={
+ "type": "fixed",
+ "num_blocks": 2,
+ "block": {
+ "mixer": {
+ "type": "stochastic",
+ "main_mixer_name": "attention",
+ "mixers": {
+ "attention": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16},
+ "mamba": {"type": "mamba", "state_size": 16},
+ },
+ },
+ "mlp": {"type": "mlp", "intermediate_size": 256},
+ "normalization": {"type": "rms_norm", "epsilon": 1e-5},
+ },
+ },
+ )
+
+
+# Parameterized window size fixture (used by hf_sliding_layer and apriel_sliding_cache)
+@pytest.fixture(params=[4, 8, 16, 32])
+def window_size(request):
+ """Parameterized window sizes for sliding window tests."""
+ return request.param
diff --git a/fast_llm_external_models/tests/test_apriel2/test_cache.py b/fast_llm_external_models/tests/test_apriel2/test_cache.py
deleted file mode 100644
index ca8158b4f..000000000
--- a/fast_llm_external_models/tests/test_apriel2/test_cache.py
+++ /dev/null
@@ -1,1258 +0,0 @@
-"""Comprehensive tests for Apriel2Cache.
-
-Architecture Overview
-=====================
-Apriel2Cache manages state for autoregressive generation across different mixer types:
-
-1. **Attention Cache** (_AttentionCache): Stores key/value states
- - Supports sliding window (window_size) for SWA
- - Efficient roll optimization for single-token decode
-
-2. **SSM Cache** (_SSMCache): Stores conv and recurrent states
- - Used by Mamba, GDN, KDA
- - KDA uses tuple conv states (q, k, v), others use single tensor
-
-3. **Stochastic Mixer Routing**: For layers with multiple mixer options
- - Each mixer has independent cache (no sharing)
- - active_mixer pointer routes operations to correct sub-cache
- - Switching mixers preserves each mixer's independent state
-
-Cache Invalidation Semantics
-============================
-When switching between mixers in a stochastic layer:
-- Each mixer maintains its OWN independent history
-- Switching does NOT invalidate the previous mixer's cache
-- Switching does NOT copy state between mixers
-- To invalidate: call reset() explicitly
-
-This is intentional for training with stochastic sampling where each mixer
-should learn from its own history. For inference, main_mixer_name is fixed.
-
-Test Organization
-=================
-1. CREATION & PROPERTIES - Cache initialization, config parsing
-2. ATTENTION CACHE - Updates, sliding window, concatenation
-3. SSM CACHE - Conv states, recurrent states, KDA tuples
-4. STOCHASTIC ROUTING - Active mixer, isolation, switching
-5. CACHE INVALIDATION - Reset, per-mixer reset, coherence
-6. BEAM SEARCH - batch_repeat, reorder, select
-7. HF INTEGRATION - get_mask_sizes, indexing, properties
-8. GENERATION PATTERNS - Prefill→decode, crop→continue
-9. ERROR HANDLING - Guards, bounds, invalid operations
-"""
-
-import pytest
-import torch
-
-from fast_llm_external_models.apriel2.cache import (
- Apriel2Cache,
- _AttentionCache,
- _SSMCache,
-)
-
-
-# =============================================================================
-# FIXTURES - Configs and Sample Data
-# =============================================================================
-
-
-@pytest.fixture
-def tiny_attention_config():
- """Minimal config with pure attention layers."""
- from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
-
- return Apriel2Config(
- vocab_size=100,
- hidden_size=64,
- decoder={
- "type": "fixed",
- "num_blocks": 2,
- "block": {
- "mixer": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16},
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- },
- )
-
-
-@pytest.fixture
-def swa_config():
- """Config with sliding window attention."""
- from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
-
- return Apriel2Config(
- vocab_size=100,
- hidden_size=64,
- decoder={
- "type": "fixed",
- "num_blocks": 2,
- "block": {
- "mixer": {
- "type": "attention",
- "heads": 4,
- "head_groups": 2,
- "head_size": 16,
- "window_size": 8, # Small for testing
- },
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- },
- )
-
-
-@pytest.fixture
-def ssm_config():
- """Config with pure SSM layers (mamba)."""
- from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
-
- return Apriel2Config(
- vocab_size=100,
- hidden_size=64,
- decoder={
- "type": "fixed",
- "num_blocks": 2,
- "block": {
- "mixer": {
- "type": "mamba",
- "d_inner": 128,
- "d_state": 16,
- "dt_rank": 4,
- "d_conv": 4,
- },
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- },
- )
-
-
-@pytest.fixture
-def kda_config():
- """Config with pure KDA layers."""
- from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
-
- return Apriel2Config(
- vocab_size=100,
- hidden_size=64,
- decoder={
- "type": "fixed",
- "num_blocks": 2,
- "block": {
- "mixer": {
- "type": "kda",
- "heads": 4,
- "head_dim": 16,
- "convolution_layer": {"kernel_size": 4},
- "normalization": {"epsilon": 1e-5},
- },
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- },
- )
-
-
-@pytest.fixture
-def stochastic_config():
- """Config with stochastic mixer (attention + mamba)."""
- from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
-
- return Apriel2Config(
- vocab_size=100,
- hidden_size=64,
- decoder={
- "type": "pattern",
- "num_blocks": 2,
- "pattern": ["attn", "stochastic"],
- "blocks": {
- "attn": {
- "mixer": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16},
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- "stochastic": {
- "mixer": {
- "type": "stochastic",
- "main_mixer_name": "attention",
- "mixers": {
- "attention": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16},
- "mamba": {"type": "mamba", "d_inner": 128, "d_state": 16, "dt_rank": 4, "d_conv": 4},
- },
- },
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- },
- },
- )
-
-
-@pytest.fixture
-def all_mixers_config():
- """Config with stochastic mixer containing all 5 mixer types."""
- from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
-
- return Apriel2Config(
- vocab_size=100,
- hidden_size=64,
- decoder={
- "type": "pattern",
- "num_blocks": 2,
- "pattern": ["attn", "all_mixers"],
- "blocks": {
- "attn": {
- "mixer": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16},
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- "all_mixers": {
- "mixer": {
- "type": "stochastic",
- "main_mixer_name": "attention",
- "mixers": {
- "attention": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16},
- "swa": {
- "type": "attention",
- "heads": 4,
- "head_groups": 2,
- "head_size": 16,
- "window_size": 1024,
- },
- "mamba": {"type": "mamba", "d_inner": 128, "d_state": 16, "dt_rank": 4, "d_conv": 4},
- "gdn": {
- "type": "gdn",
- "value_heads": 4,
- "key_heads": 2,
- "key_head_dim": 16,
- "value_head_dim": 16,
- "convolution_layer": {"kernel_size": 4},
- },
- "kda": {
- "type": "kda",
- "heads": 4,
- "head_dim": 16,
- "convolution_layer": {"kernel_size": 4},
- "normalization": {"epsilon": 1e-5},
- },
- },
- },
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- },
- },
- )
-
-
-@pytest.fixture
-def multi_window_config():
- """Config with multiple different window sizes."""
- from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
-
- return Apriel2Config(
- vocab_size=100,
- hidden_size=64,
- decoder={
- "type": "pattern",
- "num_blocks": 3,
- "pattern": ["full", "small_window", "large_window"],
- "blocks": {
- "full": {
- "mixer": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16},
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- "small_window": {
- "mixer": {
- "type": "attention",
- "heads": 4,
- "head_groups": 2,
- "head_size": 16,
- "window_size": 512,
- },
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- "large_window": {
- "mixer": {
- "type": "attention",
- "heads": 4,
- "head_groups": 2,
- "head_size": 16,
- "window_size": 2048,
- },
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- },
- },
- )
-
-
-@pytest.fixture
-def sample_kv():
- """Sample key/value tensors: [batch=2, heads=4, seq=10, head_dim=16]."""
- return torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16)
-
-
-@pytest.fixture
-def sample_conv_single():
- """Sample single-tensor conv state: [batch=2, d_inner=128, kernel=4]."""
- return torch.randn(2, 128, 4)
-
-
-@pytest.fixture
-def sample_conv_tuple():
- """Sample tuple conv state for KDA: (q, k, v) each [batch=2, d=64, kernel=3]."""
- return (torch.randn(2, 64, 3), torch.randn(2, 64, 3), torch.randn(2, 64, 3))
-
-
-@pytest.fixture
-def sample_recurrent():
- """Sample recurrent state: [batch=2, heads=4, head_dim=16, d_state=16]."""
- return torch.randn(2, 4, 16, 16)
-
-
-# =============================================================================
-# SECTION 1: CACHE CREATION & PROPERTIES
-# =============================================================================
-
-
-class TestCacheCreation:
- """Test cache initialization from config."""
-
- def test_attention_cache_creation(self, tiny_attention_config):
- """Create cache for pure attention config."""
- cache = Apriel2Cache(tiny_attention_config)
-
- assert len(cache) == 2
- assert cache.mixer_types == ["attention", "attention"]
- assert all(isinstance(l, _AttentionCache) for l in cache.layers)
-
- def test_ssm_cache_creation(self, ssm_config):
- """Create cache for pure SSM config."""
- cache = Apriel2Cache(ssm_config)
-
- assert len(cache) == 2
- assert cache.mixer_types == ["mamba", "mamba"]
- assert all(isinstance(l, _SSMCache) for l in cache.layers)
-
- def test_kda_cache_creation(self, kda_config):
- """Create cache for pure KDA config."""
- cache = Apriel2Cache(kda_config)
-
- assert len(cache) == 2
- assert cache.mixer_types == ["kda", "kda"]
- assert all(isinstance(l, _SSMCache) for l in cache.layers)
-
- def test_stochastic_cache_creation(self, stochastic_config):
- """Create cache for stochastic mixer config."""
- cache = Apriel2Cache(stochastic_config)
-
- assert len(cache) == 2
- # Layer 0: pure attention, Layer 1: stochastic (dict)
- assert isinstance(cache.layers[0], _AttentionCache)
- assert isinstance(cache.layers[1], dict)
- assert set(cache.layers[1].keys()) == {"attention", "mamba"}
-
- def test_swa_window_captured(self, swa_config):
- """Verify sliding window size is captured."""
- cache = Apriel2Cache(swa_config)
-
- assert cache.layers[0].window == 8
- assert cache.is_sliding == [True, True]
-
- def test_active_mixers_initialized_none(self, stochastic_config):
- """Verify active_mixers starts as None for all layers."""
- cache = Apriel2Cache(stochastic_config)
-
- assert cache.active_mixers == [None, None]
-
-
-class TestCacheProperties:
- """Test cache property accessors."""
-
- def test_empty_cache_properties(self, tiny_attention_config):
- """Test properties of uninitialized cache."""
- cache = Apriel2Cache(tiny_attention_config)
-
- assert cache.is_initialized == False
- assert cache.has_previous_state == False
- assert cache.max_batch_size is None
- assert cache.max_cache_len is None
- assert cache.is_compileable == False
-
- def test_is_initialized_attention(self, tiny_attention_config, sample_kv):
- """is_initialized detects attention cache."""
- cache = Apriel2Cache(tiny_attention_config)
- cache.update(*sample_kv, layer_idx=0)
-
- assert cache.is_initialized == True
-
- def test_is_initialized_ssm(self, ssm_config, sample_conv_single):
- """is_initialized detects SSM cache."""
- cache = Apriel2Cache(ssm_config)
- cache.conv_states[0] = sample_conv_single
-
- assert cache.is_initialized == True
-
- def test_has_previous_state_ssm_only(self, ssm_config, sample_conv_single):
- """has_previous_state only looks at SSM conv states."""
- cache = Apriel2Cache(ssm_config)
-
- assert cache.has_previous_state == False
- cache.conv_states[0] = sample_conv_single
- assert cache.has_previous_state == True
-
- def test_has_previous_state_ignores_attention(self, tiny_attention_config, sample_kv):
- """has_previous_state ignores attention cache."""
- cache = Apriel2Cache(tiny_attention_config)
- cache.update(*sample_kv, layer_idx=0)
-
- # Attention cache is set, but has_previous_state only checks SSM
- assert cache.has_previous_state == False
-
- def test_max_batch_size_from_attention(self, tiny_attention_config, sample_kv):
- """max_batch_size from attention cache."""
- cache = Apriel2Cache(tiny_attention_config)
- cache.update(*sample_kv, layer_idx=0)
-
- assert cache.max_batch_size == 2
-
- def test_max_batch_size_from_ssm(self, ssm_config, sample_conv_single):
- """max_batch_size from SSM cache."""
- cache = Apriel2Cache(ssm_config)
- cache.conv_states[0] = sample_conv_single
-
- assert cache.max_batch_size == 2
-
- def test_max_batch_size_from_kda_tuple(self, kda_config, sample_conv_tuple):
- """max_batch_size from KDA tuple conv state."""
- cache = Apriel2Cache(kda_config)
- cache.conv_states[0] = sample_conv_tuple
-
- assert cache.max_batch_size == 2
-
- def test_max_cache_len_single_window(self, swa_config):
- """max_cache_len with single window size."""
- cache = Apriel2Cache(swa_config)
- assert cache.max_cache_len == 8
-
- def test_max_cache_len_multiple_windows(self, multi_window_config):
- """max_cache_len returns minimum window."""
- cache = Apriel2Cache(multi_window_config)
- assert cache.max_cache_len == 512 # min(512, 2048)
-
- def test_max_cache_len_no_windows(self, tiny_attention_config):
- """max_cache_len is None when no windows."""
- cache = Apriel2Cache(tiny_attention_config)
- assert cache.max_cache_len is None
-
- def test_is_sliding_mixed(self, multi_window_config):
- """is_sliding reflects per-layer window presence."""
- cache = Apriel2Cache(multi_window_config)
- assert cache.is_sliding == [False, True, True]
-
-
-# =============================================================================
-# SECTION 2: ATTENTION CACHE OPERATIONS
-# =============================================================================
-
-
-class TestAttentionCacheBasics:
- """Test basic attention cache operations."""
-
- def test_update_stores_kv(self, tiny_attention_config, sample_kv):
- """update() stores key/value states."""
- cache = Apriel2Cache(tiny_attention_config)
- key, value = sample_kv
-
- k_out, v_out = cache.update(key, value, layer_idx=0)
-
- torch.testing.assert_close(k_out, key)
- torch.testing.assert_close(v_out, value)
- assert cache.get_seq_length(0) == 10
-
- def test_update_concatenates(self, tiny_attention_config, sample_kv):
- """Subsequent updates concatenate."""
- cache = Apriel2Cache(tiny_attention_config)
- key, value = sample_kv
-
- cache.update(key, value, layer_idx=0)
- k_out, v_out = cache.update(key, value, layer_idx=0)
-
- assert k_out.shape[-2] == 20
- assert cache.get_seq_length(0) == 20
-
- def test_key_value_cache_accessors(self, tiny_attention_config, sample_kv):
- """Test key_cache and value_cache accessors."""
- cache = Apriel2Cache(tiny_attention_config)
- cache.update(*sample_kv, layer_idx=0)
-
- assert cache.key_cache[0] is not None
- assert cache.value_cache[0] is not None
- torch.testing.assert_close(cache.key_cache[0], sample_kv[0])
-
-
-class TestSlidingWindowAttention:
- """Test sliding window attention behavior."""
-
- def test_initial_within_window(self, swa_config):
- """Initial sequence within window is kept."""
- cache = Apriel2Cache(swa_config)
- key = torch.randn(2, 4, 5, 16) # seq=5 < window=8
- value = torch.randn(2, 4, 5, 16)
-
- cache.update(key, value, layer_idx=0)
-
- assert cache.get_seq_length(0) == 5
-
- def test_initial_exceeds_window(self, swa_config):
- """Initial sequence > window is truncated to last window tokens."""
- cache = Apriel2Cache(swa_config)
- key = torch.arange(12).float().view(1, 1, 12, 1).expand(2, 4, 12, 16)
- value = key.clone()
-
- k_out, v_out = cache.update(key, value, layer_idx=0)
-
- assert cache.get_seq_length(0) == 8
- # Should keep tokens 4-11 (last 8)
- assert k_out[0, 0, 0, 0].item() == 4.0
-
- def test_single_token_roll_path(self, swa_config):
- """Single token decode with full window uses efficient roll."""
- cache = Apriel2Cache(swa_config)
-
- # Fill window exactly
- key1 = torch.arange(8).float().view(1, 1, 8, 1).expand(2, 4, 8, 16)
- cache.update(key1, key1.clone(), layer_idx=0)
-
- # Decode single token
- key2 = torch.full((2, 4, 1, 16), 8.0)
- k_out, _ = cache.update(key2, key2.clone(), layer_idx=0)
-
- assert cache.get_seq_length(0) == 8
- assert k_out[0, 0, 0, 0].item() == 1.0 # Token 0 rolled out
- assert k_out[0, 0, 7, 0].item() == 8.0 # New token at end
-
- def test_multi_token_cat_slice_path(self, swa_config):
- """Multiple tokens use cat+slice path."""
- cache = Apriel2Cache(swa_config)
-
- # Fill window
- key1 = torch.randn(2, 4, 8, 16)
- cache.update(key1, key1.clone(), layer_idx=0)
-
- # Add 3 tokens
- key2 = torch.randn(2, 4, 3, 16)
- k_out, _ = cache.update(key2, key2.clone(), layer_idx=0)
-
- assert cache.get_seq_length(0) == 8
- torch.testing.assert_close(k_out[..., -3:, :], key2)
-
- def test_partial_then_fill_then_overflow(self, swa_config):
- """Progressive filling: partial → full → overflow."""
- cache = Apriel2Cache(swa_config)
-
- cache.update(torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16), layer_idx=0)
- assert cache.get_seq_length(0) == 5
-
- cache.update(torch.randn(2, 4, 3, 16), torch.randn(2, 4, 3, 16), layer_idx=0)
- assert cache.get_seq_length(0) == 8
-
- cache.update(torch.randn(2, 4, 2, 16), torch.randn(2, 4, 2, 16), layer_idx=0)
- assert cache.get_seq_length(0) == 8
-
- def test_contiguous_output(self, swa_config):
- """Outputs are contiguous after windowing."""
- cache = Apriel2Cache(swa_config)
-
- cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0)
- cache.update(torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16), layer_idx=0)
-
- assert cache.layers[0].key.is_contiguous()
- assert cache.layers[0].value.is_contiguous()
-
-
-# =============================================================================
-# SECTION 3: SSM CACHE OPERATIONS
-# =============================================================================
-
-
-class TestSSMCacheBasics:
- """Test basic SSM cache operations."""
-
- def test_conv_states_accessor(self, ssm_config, sample_conv_single):
- """Test conv_states accessor."""
- cache = Apriel2Cache(ssm_config)
-
- cache.conv_states[0] = sample_conv_single
- torch.testing.assert_close(cache.conv_states[0], sample_conv_single)
-
- def test_recurrent_states_accessor(self, ssm_config, sample_recurrent):
- """Test recurrent_states accessor."""
- cache = Apriel2Cache(ssm_config)
-
- cache.recurrent_states[0] = sample_recurrent
- torch.testing.assert_close(cache.recurrent_states[0], sample_recurrent)
-
- def test_ssm_seq_length_always_zero(self, ssm_config, sample_conv_single):
- """get_seq_length returns 0 for SSM (no KV cache)."""
- cache = Apriel2Cache(ssm_config)
- cache.conv_states[0] = sample_conv_single
-
- assert cache.get_seq_length(0) == 0
-
-
-class TestKDACache:
- """Test KDA-specific cache operations with tuple conv states."""
-
- def test_tuple_conv_storage(self, kda_config, sample_conv_tuple):
- """KDA stores tuple conv states."""
- cache = Apriel2Cache(kda_config)
-
- cache.conv_states[0] = sample_conv_tuple
-
- assert isinstance(cache.conv_states[0], tuple)
- assert len(cache.conv_states[0]) == 3
- for i in range(3):
- torch.testing.assert_close(cache.conv_states[0][i], sample_conv_tuple[i])
-
- def test_tuple_with_recurrent(self, kda_config, sample_conv_tuple, sample_recurrent):
- """KDA can have both tuple conv and recurrent states."""
- cache = Apriel2Cache(kda_config)
-
- cache.conv_states[0] = sample_conv_tuple
- cache.recurrent_states[0] = sample_recurrent
-
- assert isinstance(cache.conv_states[0], tuple)
- assert cache.recurrent_states[0] is not None
-
- def test_has_previous_state_detects_tuple(self, kda_config, sample_conv_tuple):
- """has_previous_state works with tuple conv states."""
- cache = Apriel2Cache(kda_config)
-
- assert cache.has_previous_state == False
- cache.conv_states[0] = sample_conv_tuple
- assert cache.has_previous_state == True
-
-
-# =============================================================================
-# SECTION 4: STOCHASTIC ROUTING
-# =============================================================================
-
-
-class TestStochasticRouting:
- """Test stochastic mixer cache routing."""
-
- def test_set_active_mixer(self, stochastic_config):
- """set_active_mixer sets the pointer."""
- cache = Apriel2Cache(stochastic_config)
-
- cache.set_active_mixer(1, "attention")
- assert cache.active_mixers[1] == "attention"
-
- cache.set_active_mixer(1, "mamba")
- assert cache.active_mixers[1] == "mamba"
-
- def test_operations_route_to_active(self, stochastic_config, sample_kv):
- """Operations route to currently active mixer."""
- cache = Apriel2Cache(stochastic_config)
-
- cache.set_active_mixer(1, "attention")
- cache.update(*sample_kv, layer_idx=1)
- attn_len = cache.get_seq_length(1)
-
- cache.set_active_mixer(1, "mamba")
- mamba_len = cache.get_seq_length(1)
-
- assert attn_len == 10
- assert mamba_len == 0 # Mamba cache is separate and empty
-
- def test_each_mixer_independent_cache(self, stochastic_config, sample_kv, sample_conv_single):
- """Each mixer maintains independent cache."""
- cache = Apriel2Cache(stochastic_config)
-
- # Fill attention cache
- cache.set_active_mixer(1, "attention")
- cache.update(*sample_kv, layer_idx=1)
-
- # Fill mamba cache
- cache.set_active_mixer(1, "mamba")
- cache.conv_states[1] = sample_conv_single
-
- # Both preserved
- cache.set_active_mixer(1, "attention")
- assert cache.get_seq_length(1) == 10
-
- cache.set_active_mixer(1, "mamba")
- torch.testing.assert_close(cache.conv_states[1], sample_conv_single)
-
-
-class TestMixerSwitching:
- """Test behavior when switching between mixers mid-generation."""
-
- def test_switch_preserves_previous_state(self, stochastic_config, sample_kv):
- """Switching mixers preserves previous mixer's state."""
- cache = Apriel2Cache(stochastic_config)
-
- cache.set_active_mixer(1, "attention")
- cache.update(*sample_kv, layer_idx=1)
- original_key = cache.layers[1]["attention"].key.clone()
-
- # Switch to mamba, do something
- cache.set_active_mixer(1, "mamba")
- cache.conv_states[1] = torch.randn(2, 128, 4)
-
- # Switch back - attention unchanged
- cache.set_active_mixer(1, "attention")
- torch.testing.assert_close(cache.layers[1]["attention"].key, original_key)
-
- def test_switch_does_not_copy_state(self, stochastic_config, sample_kv):
- """Switching does NOT copy state between mixers."""
- cache = Apriel2Cache(stochastic_config)
-
- # Fill attention with 10 tokens
- cache.set_active_mixer(1, "attention")
- cache.update(*sample_kv, layer_idx=1)
-
- # Switch to mamba - it has NO history from attention
- cache.set_active_mixer(1, "mamba")
- assert cache.conv_states[1] is None
- assert cache.recurrent_states[1] is None
-
- def test_has_previous_state_checks_all_sub_caches(self, stochastic_config):
- """has_previous_state checks ALL sub-caches, not just active."""
- cache = Apriel2Cache(stochastic_config)
-
- cache.set_active_mixer(1, "mamba")
- cache.conv_states[1] = torch.randn(2, 128, 4)
-
- # Even if we switch away, has_previous_state still detects it
- cache.set_active_mixer(1, "attention")
- assert cache.has_previous_state == True
-
-
-class TestAllMixerTypes:
- """Test cache isolation across all 5 mixer types."""
-
- def test_all_five_mixer_types_isolated(self, all_mixers_config):
- """All 5 mixer types maintain isolated caches."""
- cache = Apriel2Cache(all_mixers_config)
- layer_idx = 1 # Stochastic layer
-
- # Fill each mixer's cache
- cache.set_active_mixer(layer_idx, "attention")
- attn_kv = (torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16))
- cache.update(*attn_kv, layer_idx=layer_idx)
-
- cache.set_active_mixer(layer_idx, "swa")
- swa_kv = (torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16))
- cache.update(*swa_kv, layer_idx=layer_idx)
-
- cache.set_active_mixer(layer_idx, "mamba")
- mamba_conv = torch.randn(2, 128, 4)
- cache.conv_states[layer_idx] = mamba_conv
-
- cache.set_active_mixer(layer_idx, "gdn")
- gdn_conv = torch.randn(2, 64, 3)
- cache.conv_states[layer_idx] = gdn_conv
-
- cache.set_active_mixer(layer_idx, "kda")
- kda_conv = (torch.randn(2, 64, 3), torch.randn(2, 64, 3), torch.randn(2, 64, 3))
- cache.conv_states[layer_idx] = kda_conv
-
- # Verify all preserved
- cache.set_active_mixer(layer_idx, "attention")
- assert cache.get_seq_length(layer_idx) == 10
-
- cache.set_active_mixer(layer_idx, "swa")
- assert cache.get_seq_length(layer_idx) == 5
-
- cache.set_active_mixer(layer_idx, "mamba")
- torch.testing.assert_close(cache.conv_states[layer_idx], mamba_conv)
-
- cache.set_active_mixer(layer_idx, "gdn")
- torch.testing.assert_close(cache.conv_states[layer_idx], gdn_conv)
-
- cache.set_active_mixer(layer_idx, "kda")
- assert isinstance(cache.conv_states[layer_idx], tuple)
-
-
-# =============================================================================
-# SECTION 5: CACHE INVALIDATION
-# =============================================================================
-
-
-class TestCacheInvalidation:
- """Test cache invalidation and reset semantics.
-
- Key principle: Each mixer maintains independent state. To invalidate:
- - reset() clears ALL caches across ALL layers and mixers
- - There is no per-mixer reset (by design - each mixer is independent)
- """
-
- def test_reset_clears_attention(self, tiny_attention_config, sample_kv):
- """reset() clears attention cache."""
- cache = Apriel2Cache(tiny_attention_config)
- cache.update(*sample_kv, layer_idx=0)
-
- cache.reset()
-
- assert cache.is_initialized == False
- assert cache.get_seq_length(0) == 0
-
- def test_reset_clears_ssm(self, ssm_config, sample_conv_single, sample_recurrent):
- """reset() clears SSM cache."""
- cache = Apriel2Cache(ssm_config)
- cache.conv_states[0] = sample_conv_single
- cache.recurrent_states[0] = sample_recurrent
-
- cache.reset()
-
- assert cache.has_previous_state == False
- assert cache.conv_states[0] is None
- assert cache.recurrent_states[0] is None
-
- def test_reset_clears_kda_tuple(self, kda_config, sample_conv_tuple):
- """reset() clears KDA tuple conv states."""
- cache = Apriel2Cache(kda_config)
- cache.conv_states[0] = sample_conv_tuple
-
- cache.reset()
-
- assert cache.conv_states[0] is None
-
- def test_reset_clears_all_stochastic_mixers(self, all_mixers_config):
- """reset() clears ALL mixer caches in stochastic layer."""
- cache = Apriel2Cache(all_mixers_config)
- layer_idx = 1
-
- # Fill all mixers
- cache.set_active_mixer(layer_idx, "attention")
- cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=layer_idx)
-
- cache.set_active_mixer(layer_idx, "mamba")
- cache.conv_states[layer_idx] = torch.randn(2, 128, 4)
-
- cache.set_active_mixer(layer_idx, "kda")
- cache.conv_states[layer_idx] = (torch.randn(2, 64, 3),) * 3
-
- cache.reset()
-
- # All cleared
- assert cache.layers[layer_idx]["attention"].key is None
- assert cache.layers[layer_idx]["mamba"].conv is None
- assert cache.layers[layer_idx]["kda"].conv is None
-
- def test_crop_truncates_attention(self, tiny_attention_config, sample_kv):
- """crop() truncates attention cache to max_length."""
- cache = Apriel2Cache(tiny_attention_config)
- cache.update(*sample_kv, layer_idx=0)
-
- cache.crop(5)
-
- assert cache.get_seq_length(0) == 5
-
- def test_crop_affects_all_layers(self, tiny_attention_config, sample_kv):
- """crop() affects all layers."""
- cache = Apriel2Cache(tiny_attention_config)
- cache.update(*sample_kv, layer_idx=0)
- cache.update(*sample_kv, layer_idx=1)
-
- cache.crop(3)
-
- assert cache.get_seq_length(0) == 3
- assert cache.get_seq_length(1) == 3
-
- def test_crop_ignores_ssm(self, ssm_config, sample_conv_single):
- """crop() only affects attention, not SSM."""
- cache = Apriel2Cache(ssm_config)
- cache.conv_states[0] = sample_conv_single
-
- cache.crop(5) # Should not crash
-
- # Conv state unchanged
- torch.testing.assert_close(cache.conv_states[0], sample_conv_single)
-
-
-# =============================================================================
-# SECTION 6: BEAM SEARCH OPERATIONS
-# =============================================================================
-
-
-class TestBatchRepeatInterleave:
- """Test batch_repeat_interleave for beam search expansion."""
-
- def test_repeat_attention(self, tiny_attention_config, sample_kv):
- """Repeat attention cache for beam search."""
- cache = Apriel2Cache(tiny_attention_config)
- cache.update(*sample_kv, layer_idx=0)
-
- cache.batch_repeat_interleave(3)
-
- assert cache.max_batch_size == 6 # 2 * 3
-
- def test_repeat_ssm(self, ssm_config, sample_conv_single, sample_recurrent):
- """Repeat SSM cache for beam search."""
- cache = Apriel2Cache(ssm_config)
- cache.conv_states[0] = sample_conv_single
- cache.recurrent_states[0] = sample_recurrent
-
- cache.batch_repeat_interleave(4)
-
- assert cache.conv_states[0].shape[0] == 8 # 2 * 4
- assert cache.recurrent_states[0].shape[0] == 8
-
- def test_repeat_kda_tuple(self, kda_config, sample_conv_tuple):
- """Repeat KDA tuple conv states."""
- cache = Apriel2Cache(kda_config)
- cache.conv_states[0] = sample_conv_tuple
-
- cache.batch_repeat_interleave(3)
-
- for c in cache.conv_states[0]:
- assert c.shape[0] == 6
-
- def test_repeat_stochastic_all_mixers(self, all_mixers_config):
- """Repeat all mixer caches in stochastic layer."""
- cache = Apriel2Cache(all_mixers_config)
- layer_idx = 1
-
- cache.set_active_mixer(layer_idx, "attention")
- cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=layer_idx)
-
- cache.set_active_mixer(layer_idx, "mamba")
- cache.conv_states[layer_idx] = torch.randn(2, 128, 4)
-
- cache.batch_repeat_interleave(2)
-
- cache.set_active_mixer(layer_idx, "attention")
- assert cache.layers[layer_idx]["attention"].key.shape[0] == 4
-
- cache.set_active_mixer(layer_idx, "mamba")
- assert cache.conv_states[layer_idx].shape[0] == 4
-
- def test_repeat_skips_none(self, tiny_attention_config):
- """Repeat gracefully skips None caches."""
- cache = Apriel2Cache(tiny_attention_config)
- # Don't fill anything
-
- cache.batch_repeat_interleave(3) # Should not crash
-
- assert cache.max_batch_size is None
-
-
-class TestReorderCache:
- """Test reorder_cache for beam search hypothesis selection."""
-
- def test_reorder_attention(self, tiny_attention_config, sample_kv):
- """Reorder attention cache."""
- cache = Apriel2Cache(tiny_attention_config)
- key, value = sample_kv
- # Make batches distinguishable
- key = torch.arange(2).float().view(2, 1, 1, 1).expand(2, 4, 10, 16)
- cache.update(key, key.clone(), layer_idx=0)
-
- beam_idx = torch.tensor([1, 0])
- cache.reorder_cache(beam_idx)
-
- assert cache.layers[0].key[0, 0, 0, 0].item() == 1.0
- assert cache.layers[0].key[1, 0, 0, 0].item() == 0.0
-
- def test_reorder_ssm(self, ssm_config):
- """Reorder SSM cache."""
- cache = Apriel2Cache(ssm_config)
- conv = torch.arange(2).float().view(2, 1, 1).expand(2, 128, 4)
- cache.conv_states[0] = conv.clone()
-
- beam_idx = torch.tensor([1, 0])
- cache.reorder_cache(beam_idx)
-
- assert cache.conv_states[0][0, 0, 0].item() == 1.0
-
- def test_reorder_kda_tuple(self, kda_config):
- """Reorder KDA tuple conv states."""
- cache = Apriel2Cache(kda_config)
- conv_q = torch.arange(2).float().view(2, 1, 1).expand(2, 64, 3)
- cache.conv_states[0] = (conv_q.clone(), conv_q.clone(), conv_q.clone())
-
- beam_idx = torch.tensor([1, 0])
- cache.reorder_cache(beam_idx)
-
- for c in cache.conv_states[0]:
- assert c[0, 0, 0].item() == 1.0
-
-
-class TestBatchSelectIndices:
- """Test batch_select_indices for beam selection."""
-
- def test_select_attention(self, tiny_attention_config, sample_kv):
- """Select subset of attention cache."""
- cache = Apriel2Cache(tiny_attention_config)
- key = torch.arange(4).float().view(4, 1, 1, 1).expand(4, 4, 10, 16)
- cache.update(key, key.clone(), layer_idx=0)
-
- indices = torch.tensor([0, 3])
- cache.batch_select_indices(indices)
-
- assert cache.max_batch_size == 2
- assert cache.layers[0].key[0, 0, 0, 0].item() == 0.0
- assert cache.layers[0].key[1, 0, 0, 0].item() == 3.0
-
- def test_select_kda_tuple(self, kda_config):
- """Select subset of KDA tuple conv states."""
- cache = Apriel2Cache(kda_config)
- conv = tuple(torch.arange(4).float().view(4, 1, 1).expand(4, 64, 3).clone() for _ in range(3))
- cache.conv_states[0] = conv
-
- indices = torch.tensor([1, 2])
- cache.batch_select_indices(indices)
-
- for c in cache.conv_states[0]:
- assert c.shape[0] == 2
- assert c[0, 0, 0].item() == 1.0
-
-
-# =============================================================================
-# SECTION 7: HUGGINGFACE INTEGRATION
-# =============================================================================
-
-
-class TestGetMaskSizes:
- """Test get_mask_sizes() for attention mask computation."""
-
- def test_empty_cache(self, tiny_attention_config):
- """Mask sizes with empty cache."""
- cache = Apriel2Cache(tiny_attention_config)
- cache_position = torch.arange(10)
-
- kv_length, kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0)
-
- assert kv_length == 10
- assert kv_offset == 0
-
- def test_with_cached_tokens(self, tiny_attention_config, sample_kv):
- """Mask sizes with cached tokens."""
- cache = Apriel2Cache(tiny_attention_config)
- cache.update(*sample_kv, layer_idx=0) # 10 tokens
-
- cache_position = torch.arange(5)
- kv_length, kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0)
-
- assert kv_length == 15 # 10 + 5
- assert kv_offset == 10
-
- def test_single_token_decode(self, tiny_attention_config, sample_kv):
- """Mask sizes for single token decode."""
- cache = Apriel2Cache(tiny_attention_config)
- cache.update(*sample_kv, layer_idx=0)
-
- cache_position = torch.arange(1)
- kv_length, kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0)
-
- assert kv_length == 11
- assert kv_offset == 10
-
- def test_ssm_returns_query_only(self, ssm_config, sample_conv_single):
- """SSM layers return query_length (no KV cache)."""
- cache = Apriel2Cache(ssm_config)
- cache.conv_states[0] = sample_conv_single
-
- cache_position = torch.arange(5)
- kv_length, kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0)
-
- assert kv_length == 5
- assert kv_offset == 0
-
-
-class TestCacheIndexing:
- """Test cache[idx] indexing."""
-
- def test_attention_returns_kv(self, tiny_attention_config, sample_kv):
- """Indexing attention layer returns (key, value)."""
- cache = Apriel2Cache(tiny_attention_config)
- cache.update(*sample_kv, layer_idx=0)
-
- result = cache[0]
-
- assert isinstance(result, tuple)
- torch.testing.assert_close(result[0], sample_kv[0])
-
- def test_empty_returns_empty_tensors(self, tiny_attention_config):
- """Indexing empty layer returns empty tensors."""
- cache = Apriel2Cache(tiny_attention_config)
-
- result = cache[0]
-
- assert result[0].numel() == 0
- assert result[1].numel() == 0
-
- def test_ssm_returns_empty(self, ssm_config, sample_conv_single):
- """Indexing SSM layer returns empty (no KV)."""
- cache = Apriel2Cache(ssm_config)
- cache.conv_states[0] = sample_conv_single
-
- result = cache[0]
-
- assert result[0].numel() == 0
-
- def test_stochastic_attention_returns_kv(self, stochastic_config, sample_kv):
- """Indexing stochastic with attention active returns KV."""
- cache = Apriel2Cache(stochastic_config)
- cache.set_active_mixer(1, "attention")
- cache.update(*sample_kv, layer_idx=1)
-
- result = cache[1]
-
- torch.testing.assert_close(result[0], sample_kv[0])
-
-
-# =============================================================================
-# SECTION 8: GENERATION PATTERNS
-# =============================================================================
-
-
-class TestGenerationPatterns:
- """Test real-world generation patterns."""
-
- def test_prefill_then_decode(self, tiny_attention_config, sample_kv):
- """Prefill with long prompt, then decode token-by-token."""
- cache = Apriel2Cache(tiny_attention_config)
- cache.update(*sample_kv, layer_idx=0) # Prefill 10 tokens
-
- for _ in range(5):
- new_kv = (torch.randn(2, 4, 1, 16), torch.randn(2, 4, 1, 16))
- cache.update(*new_kv, layer_idx=0)
-
- assert cache.get_seq_length(0) == 15
-
- def test_crop_then_continue(self, tiny_attention_config, sample_kv):
- """Crop old context, continue generation."""
- cache = Apriel2Cache(tiny_attention_config)
- cache.update(*sample_kv, layer_idx=0)
- cache.update(*sample_kv, layer_idx=0) # 20 tokens
-
- cache.crop(5) # Keep last 5
- cache.update(torch.randn(2, 4, 3, 16), torch.randn(2, 4, 3, 16), layer_idx=0)
-
- assert cache.get_seq_length(0) == 8
-
- def test_reset_between_generations(self, tiny_attention_config, sample_kv):
- """Reset between independent generations."""
- cache = Apriel2Cache(tiny_attention_config)
-
- # First generation
- cache.update(*sample_kv, layer_idx=0)
- assert cache.is_initialized == True
-
- # Reset
- cache.reset()
- assert cache.is_initialized == False
-
- # Second generation
- cache.update(*sample_kv, layer_idx=0)
- assert cache.get_seq_length(0) == 10
-
- def test_multi_layer_consistency(self, tiny_attention_config, sample_kv):
- """All layers updated consistently."""
- cache = Apriel2Cache(tiny_attention_config)
-
- for layer_idx in range(2):
- cache.update(*sample_kv, layer_idx=layer_idx)
- cache.update(torch.randn(2, 4, 1, 16), torch.randn(2, 4, 1, 16), layer_idx=layer_idx)
-
- for layer_idx in range(2):
- assert cache.get_seq_length(layer_idx) == 11
-
-
-# =============================================================================
-# SECTION 9: ERROR HANDLING
-# =============================================================================
-
-
-class TestErrorHandling:
- """Test error conditions and guards."""
-
- def test_stochastic_update_without_active_mixer(self, stochastic_config):
- """update() on stochastic without active_mixer raises."""
- cache = Apriel2Cache(stochastic_config)
-
- with pytest.raises(RuntimeError, match="needs active_mixer set"):
- cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=1)
-
- def test_stochastic_accessor_without_active_mixer(self, stochastic_config):
- """Accessing stochastic cache without active_mixer raises."""
- cache = Apriel2Cache(stochastic_config)
-
- with pytest.raises(RuntimeError, match="requires set_active_mixer"):
- _ = cache.conv_states[1]
-
- def test_accessor_error_lists_available_mixers(self, stochastic_config):
- """Error message lists available mixers."""
- cache = Apriel2Cache(stochastic_config)
-
- with pytest.raises(RuntimeError, match="Available mixers:"):
- _ = cache.key_cache[1]
-
- def test_invalid_mixer_name(self, stochastic_config):
- """Invalid mixer name raises KeyError on access."""
- cache = Apriel2Cache(stochastic_config)
- cache.set_active_mixer(1, "nonexistent")
-
- with pytest.raises(KeyError):
- cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=1)
-
- def test_layer_idx_out_of_bounds(self, tiny_attention_config):
- """Out-of-bounds layer_idx raises IndexError."""
- cache = Apriel2Cache(tiny_attention_config)
-
- with pytest.raises(IndexError):
- cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=999)
-
-
-# =============================================================================
-# SECTION 10: INTERNAL CLASSES
-# =============================================================================
-
-
-class TestAttentionCacheInternal:
- """Test internal _AttentionCache class directly."""
-
- def test_unbounded_growth(self):
- """No window allows unbounded growth."""
- cache = _AttentionCache(window=None)
-
- for _ in range(10):
- cache.update(torch.randn(2, 4, 100, 16), torch.randn(2, 4, 100, 16))
-
- assert cache.key.shape[-2] == 1000
-
- def test_window_enforced(self):
- """Window caps cache size."""
- cache = _AttentionCache(window=50)
-
- for _ in range(10):
- cache.update(torch.randn(2, 4, 100, 16), torch.randn(2, 4, 100, 16))
-
- assert cache.key.shape[-2] == 50
-
-
-class TestSSMCacheInternal:
- """Test internal _SSMCache class directly."""
-
- def test_initial_none(self):
- """Initial states are None."""
- cache = _SSMCache()
-
- assert cache.conv is None
- assert cache.recurrent is None
-
- def test_stores_tuple(self):
- """Can store tuple (for KDA)."""
- cache = _SSMCache()
- cache.conv = (torch.randn(2, 64, 3),) * 3
-
- assert isinstance(cache.conv, tuple)
diff --git a/fast_llm_external_models/tests/test_apriel2/test_cache_apriel2_specific.py b/fast_llm_external_models/tests/test_apriel2/test_cache_apriel2_specific.py
new file mode 100644
index 000000000..b45779454
--- /dev/null
+++ b/fast_llm_external_models/tests/test_apriel2/test_cache_apriel2_specific.py
@@ -0,0 +1,341 @@
+"""Tests for Apriel2-specific cache behaviors with no HuggingFace equivalent.
+
+This module tests features unique to Apriel2Cache that cannot be validated
+against upstream HF implementations:
+
+1. Stochastic mixer routing (switching between attention/SSM per layer)
+2. Multi-mixer layer support
+3. Error handling and guard rails
+4. Beam search operations (batch_repeat, reorder, select)
+5. Crop operation
+
+Fixtures used from conftest.py:
+ - stochastic_config: Stochastic mixer config with attention and mamba
+ - attention_config: Pure attention config
+ - ssm_config: Pure SSM config
+"""
+
+import pytest
+import torch
+
+from fast_llm_external_models.apriel2.cache import Apriel2Cache
+
+# =============================================================================
+# STOCHASTIC MIXER ROUTING
+# =============================================================================
+
+
+class TestStochasticMixerRouting:
+ """Test routing operations to correct sub-cache in stochastic layers."""
+
+ def test_set_active_mixer(self, stochastic_config):
+ """set_active_mixer updates routing for layer."""
+ cache = Apriel2Cache(stochastic_config)
+
+ cache.set_active_mixer(0, "attention")
+ assert cache.active_mixers[0] == "attention"
+
+ cache.set_active_mixer(0, "mamba")
+ assert cache.active_mixers[0] == "mamba"
+
+ def test_update_routes_to_active_mixer(self, stochastic_config):
+ """update() stores in correct sub-cache based on active_mixer."""
+ cache = Apriel2Cache(stochastic_config)
+
+ # Route to attention
+ cache.set_active_mixer(0, "attention")
+ cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0)
+
+ # Attention sub-cache should have data
+ assert cache.layers[0]["attention"].key is not None
+ # Mamba sub-cache should be empty
+ assert cache.layers[0]["mamba"].conv is None
+
+ def test_each_mixer_has_independent_cache(self, stochastic_config):
+ """Each mixer in a stochastic layer has its own independent state."""
+ cache = Apriel2Cache(stochastic_config)
+
+ # Store in attention
+ cache.set_active_mixer(0, "attention")
+ cache.update(torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16), layer_idx=0)
+
+ # Switch to mamba and store
+ cache.set_active_mixer(0, "mamba")
+ cache.layers[0]["mamba"].conv = torch.randn(2, 64, 4)
+
+ # Attention data should be unchanged
+ assert cache.layers[0]["attention"].cumulative_length == 5
+
+ def test_switching_preserves_all_states(self, stochastic_config):
+ """Switching active_mixer doesn't clear other mixer's state."""
+ cache = Apriel2Cache(stochastic_config)
+
+ # Build up attention state
+ cache.set_active_mixer(0, "attention")
+ cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0)
+ attn_key = cache.layers[0]["attention"].key.clone()
+
+ # Switch to mamba
+ cache.set_active_mixer(0, "mamba")
+
+ # Attention state preserved
+ torch.testing.assert_close(cache.layers[0]["attention"].key, attn_key)
+
+
+# =============================================================================
+# ERROR HANDLING
+# =============================================================================
+
+
+class TestErrorHandling:
+ """Test guard rails and error messages."""
+
+ def test_update_without_active_mixer_raises(self, stochastic_config):
+ """update() on stochastic layer without active_mixer raises RuntimeError."""
+ cache = Apriel2Cache(stochastic_config)
+
+ with pytest.raises(RuntimeError, match="needs active_mixer set"):
+ cache.update(torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16), layer_idx=0)
+
+ def test_accessor_without_active_mixer_raises(self, stochastic_config):
+ """Accessing key_cache/value_cache without active_mixer raises RuntimeError."""
+ cache = Apriel2Cache(stochastic_config)
+
+ with pytest.raises(RuntimeError, match="requires set_active_mixer"):
+ _ = cache.key_cache[0]
+
+ def test_error_message_lists_available_mixers(self, stochastic_config):
+ """Error message includes list of available mixers."""
+ cache = Apriel2Cache(stochastic_config)
+
+ with pytest.raises(RuntimeError, match="attention.*mamba|mamba.*attention"):
+ _ = cache.key_cache[0]
+
+
+# =============================================================================
+# BEAM SEARCH OPERATIONS
+# =============================================================================
+
+
+class TestBeamSearchOperations:
+ """Test batch manipulation for beam search."""
+
+ def test_batch_repeat_interleave_attention(self, attention_config):
+ """batch_repeat_interleave expands batch dimension."""
+ cache = Apriel2Cache(attention_config)
+ cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0)
+
+ cache.batch_repeat_interleave(3)
+
+ assert cache.layers[0].key.shape[0] == 6 # 2 * 3
+
+ def test_batch_repeat_interleave_ssm(self, ssm_config):
+ """batch_repeat_interleave works for SSM caches."""
+ cache = Apriel2Cache(ssm_config)
+ cache.layers[0].conv = torch.randn(2, 64, 4)
+
+ cache.batch_repeat_interleave(3)
+
+ assert cache.layers[0].conv.shape[0] == 6
+
+ def test_batch_repeat_interleave_kda_tuple(self, ssm_config):
+ """batch_repeat_interleave handles KDA tuple conv states."""
+ cache = Apriel2Cache(ssm_config)
+ cache.layers[0].conv = (torch.randn(2, 64, 4),) * 3
+
+ cache.batch_repeat_interleave(3)
+
+ assert cache.layers[0].conv[0].shape[0] == 6
+
+ def test_reorder_cache_attention(self, attention_config):
+ """reorder_cache reorders batch dimension."""
+ cache = Apriel2Cache(attention_config)
+ k = torch.arange(4).float().view(4, 1, 1, 1).expand(4, 4, 10, 16)
+ cache.update(k, k.clone(), layer_idx=0)
+
+ beam_idx = torch.tensor([3, 2, 1, 0])
+ cache.reorder_cache(beam_idx)
+
+ # Check reordering
+ assert cache.layers[0].key[0, 0, 0, 0].item() == 3.0
+ assert cache.layers[0].key[3, 0, 0, 0].item() == 0.0
+
+ def test_batch_select_indices(self, attention_config):
+ """batch_select_indices selects subset of batch."""
+ cache = Apriel2Cache(attention_config)
+ cache.update(torch.randn(4, 4, 10, 16), torch.randn(4, 4, 10, 16), layer_idx=0)
+
+ indices = torch.tensor([0, 2])
+ cache.batch_select_indices(indices)
+
+ assert cache.layers[0].key.shape[0] == 2
+
+ def test_reorder_cache_ssm_tuple(self, ssm_config):
+ """reorder_cache handles KDA tuple conv states."""
+ cache = Apriel2Cache(ssm_config)
+ # Create distinguishable tensors for each batch position
+ conv0 = torch.full((1, 64, 4), 0.0)
+ conv1 = torch.full((1, 64, 4), 1.0)
+ conv2 = torch.full((1, 64, 4), 2.0)
+ cache.layers[0].conv = (
+ torch.cat([conv0, conv1, conv2], dim=0),
+ torch.cat([conv0, conv1, conv2], dim=0),
+ torch.cat([conv0, conv1, conv2], dim=0),
+ )
+
+ beam_idx = torch.tensor([2, 1, 0])
+ cache.reorder_cache(beam_idx)
+
+ # Check reordering: batch[0] should now have value 2.0
+ assert cache.layers[0].conv[0][0, 0, 0].item() == 2.0
+ assert cache.layers[0].conv[0][2, 0, 0].item() == 0.0
+
+ def test_batch_select_indices_ssm_tuple(self, ssm_config):
+ """batch_select_indices handles KDA tuple conv states."""
+ cache = Apriel2Cache(ssm_config)
+ cache.layers[0].conv = (torch.randn(4, 64, 4),) * 3
+
+ indices = torch.tensor([0, 2])
+ cache.batch_select_indices(indices)
+
+ assert cache.layers[0].conv[0].shape[0] == 2
+ assert cache.layers[0].conv[1].shape[0] == 2
+ assert cache.layers[0].conv[2].shape[0] == 2
+
+
+# =============================================================================
+# CROP OPERATION
+# =============================================================================
+
+
+class TestCropOperation:
+ """Test cache truncation."""
+
+ def test_crop_truncates_attention(self, attention_config):
+ """crop() truncates attention cache."""
+ cache = Apriel2Cache(attention_config)
+ cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0)
+
+ cache.crop(5)
+
+ assert cache.layers[0].key.shape[-2] == 5
+ assert cache.get_seq_length(0) == 5
+
+ def test_crop_affects_all_layers(self, attention_config):
+ """crop() affects all layers."""
+ cache = Apriel2Cache(attention_config)
+ cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0)
+ cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=1)
+
+ cache.crop(3)
+
+ assert cache.layers[0].key.shape[-2] == 3
+ assert cache.layers[1].key.shape[-2] == 3
+
+ def test_crop_ignores_ssm(self, ssm_config):
+ """crop() doesn't affect SSM caches (they don't have seq dimension)."""
+ cache = Apriel2Cache(ssm_config)
+ cache.layers[0].conv = torch.randn(2, 64, 4)
+
+ # Should not raise
+ cache.crop(5)
+
+ # SSM state unchanged
+ assert cache.layers[0].conv.shape == (2, 64, 4)
+
+
+# =============================================================================
+# CACHE PROPERTIES
+# =============================================================================
+
+
+class TestCacheProperties:
+ """Test cache property methods."""
+
+ def test_is_initialized_attention(self, attention_config):
+ """is_initialized True after update."""
+ cache = Apriel2Cache(attention_config)
+ assert not cache.is_initialized
+
+ cache.update(torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16), layer_idx=0)
+ assert cache.is_initialized
+
+ def test_is_initialized_ssm(self, ssm_config):
+ """is_initialized True after setting conv state."""
+ cache = Apriel2Cache(ssm_config)
+ assert not cache.is_initialized
+
+ cache.layers[0].conv = torch.randn(2, 64, 4)
+ assert cache.is_initialized
+
+ def test_has_previous_state_ssm_only(self, ssm_config):
+ """has_previous_state checks SSM conv states."""
+ cache = Apriel2Cache(ssm_config)
+ assert not cache.has_previous_state
+
+ cache.layers[0].conv = torch.randn(2, 64, 4)
+ assert cache.has_previous_state
+
+ def test_has_previous_state_ignores_attention(self, attention_config):
+ """has_previous_state ignores attention caches."""
+ cache = Apriel2Cache(attention_config)
+ cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0)
+
+ # Attention-only cache returns False for has_previous_state
+ assert not cache.has_previous_state
+
+ def test_reset_clears_ssm_states(self, ssm_config):
+ """reset() clears SSM conv and recurrent states."""
+ cache = Apriel2Cache(ssm_config)
+ cache.layers[0].conv = torch.randn(2, 64, 4)
+ cache.layers[0].recurrent = torch.randn(2, 64, 16)
+
+ cache.reset()
+
+ assert cache.layers[0].conv is None
+ assert cache.layers[0].recurrent is None
+
+ def test_max_batch_size_from_ssm_tuple(self, ssm_config):
+ """max_batch_size works with KDA tuple conv states."""
+ cache = Apriel2Cache(ssm_config)
+ cache.layers[0].conv = (torch.randn(3, 64, 4),) * 3
+
+ assert cache.max_batch_size == 3
+
+ def test_max_batch_size(self, attention_config):
+ """max_batch_size returns batch dimension."""
+ cache = Apriel2Cache(attention_config)
+ cache.update(torch.randn(3, 4, 10, 16), torch.randn(3, 4, 10, 16), layer_idx=0)
+
+ assert cache.max_batch_size == 3
+
+ def test_len_returns_num_layers(self, attention_config):
+ """__len__ returns number of layers."""
+ cache = Apriel2Cache(attention_config)
+ assert len(cache) == 2
+
+
+# =============================================================================
+# INDEXING
+# =============================================================================
+
+
+class TestCacheIndexing:
+ """Test __getitem__ for HF compatibility."""
+
+ def test_getitem_returns_kv_tuple(self, attention_config):
+ """cache[idx] returns (key, value) tuple."""
+ cache = Apriel2Cache(attention_config)
+ cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0)
+
+ k, v = cache[0]
+ assert k.shape == (2, 4, 10, 16)
+ assert v.shape == (2, 4, 10, 16)
+
+ def test_getitem_empty_returns_empty_tensors(self, attention_config):
+ """cache[idx] on empty cache returns empty tensors."""
+ cache = Apriel2Cache(attention_config)
+
+ k, v = cache[0]
+ assert k.numel() == 0
+ assert v.numel() == 0
diff --git a/fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py b/fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py
new file mode 100644
index 000000000..8ceabfb91
--- /dev/null
+++ b/fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py
@@ -0,0 +1,591 @@
+"""Contract tests for Apriel2Cache against HuggingFace cache implementations.
+
+This module tests that Apriel2Cache components behave equivalently to their
+HuggingFace counterparts. This ensures compatibility with HF's generation
+infrastructure (mask creation, beam search, etc.).
+
+Mapping:
+ Apriel2 Component HuggingFace Equivalent
+ ----------------- ----------------------
+ _AttentionCache (no window) -> DynamicLayer
+ _AttentionCache (window) -> DynamicSlidingWindowLayer
+ _SSMCache -> MambaCache (different interface, same concept)
+
+Apriel2-specific features (stochastic routing, multi-mixer layers) are tested
+separately in test_cache_apriel2_specific.py since they have no HF equivalent.
+
+Fixtures used from conftest.py:
+ - batch_size, num_heads, head_dim: Tensor dimensions
+ - hf_dynamic_layer: HuggingFace DynamicLayer
+ - hf_sliding_layer: HuggingFace DynamicSlidingWindowLayer (parameterized by window_size)
+ - apriel_attention_cache: Apriel2 _AttentionCache (no window)
+ - apriel_sliding_cache: Apriel2 _AttentionCache (with window, parameterized)
+ - window_size: Parameterized window sizes [4, 8, 16, 32]
+ - attention_config, swa_config: Apriel2 configs
+"""
+
+import pytest
+import torch
+
+from fast_llm_external_models.apriel2.cache import Apriel2Cache, _AttentionCache
+
+# =============================================================================
+# SECTION 1: FULL ATTENTION - _AttentionCache vs DynamicLayer
+# =============================================================================
+
+
+class TestFullAttentionContract:
+ """Test _AttentionCache (no window) matches HuggingFace DynamicLayer.
+
+ DynamicLayer is the standard cache for full causal attention.
+ We test that our cache produces identical mask parameters.
+ """
+
+ # -------------------------------------------------------------------------
+ # get_seq_length: Must match exactly for generation to work
+ # -------------------------------------------------------------------------
+
+ @pytest.mark.parametrize("seq_len", [1, 5, 10, 50, 100])
+ def test_get_seq_length_after_prefill(
+ self, hf_dynamic_layer, apriel_attention_cache, batch_size, num_heads, head_dim, seq_len
+ ):
+ """After prefill, cumulative_length matches HF get_seq_length."""
+ key = torch.randn(batch_size, num_heads, seq_len, head_dim)
+ value = torch.randn(batch_size, num_heads, seq_len, head_dim)
+
+ hf_dynamic_layer.update(key.clone(), value.clone())
+ apriel_attention_cache.update(key.clone(), value.clone())
+
+ assert apriel_attention_cache.cumulative_length == hf_dynamic_layer.get_seq_length()
+
+ @pytest.mark.parametrize("prefill_len", [1, 5, 10])
+ @pytest.mark.parametrize("decode_steps", [1, 5, 10, 20])
+ def test_get_seq_length_during_decode(
+ self, hf_dynamic_layer, apriel_attention_cache, batch_size, num_heads, head_dim, prefill_len, decode_steps
+ ):
+ """During decode, cumulative_length tracks total tokens seen."""
+ # Prefill
+ key = torch.randn(batch_size, num_heads, prefill_len, head_dim)
+ value = torch.randn(batch_size, num_heads, prefill_len, head_dim)
+ hf_dynamic_layer.update(key.clone(), value.clone())
+ apriel_attention_cache.update(key.clone(), value.clone())
+
+ # Decode
+ for step in range(decode_steps):
+ key = torch.randn(batch_size, num_heads, 1, head_dim)
+ value = torch.randn(batch_size, num_heads, 1, head_dim)
+ hf_dynamic_layer.update(key.clone(), value.clone())
+ apriel_attention_cache.update(key.clone(), value.clone())
+
+ assert (
+ apriel_attention_cache.cumulative_length == hf_dynamic_layer.get_seq_length()
+ ), f"Mismatch at decode step {step}"
+
+ # -------------------------------------------------------------------------
+ # get_mask_sizes: Verify HF behavior for documentation
+ # -------------------------------------------------------------------------
+
+ @pytest.mark.parametrize("prefill_len", [1, 5, 10])
+ @pytest.mark.parametrize("decode_steps", [0, 1, 5, 10])
+ def test_hf_mask_sizes_kv_length(
+ self, hf_dynamic_layer, apriel_attention_cache, batch_size, num_heads, head_dim, prefill_len, decode_steps
+ ):
+ """Document HF's kv_length behavior and verify cumulative_length tracks correctly.
+
+ For full attention, kv_length = cumulative_length + query_length.
+ This test verifies our cache tracks tokens identically to HF.
+ """
+ # Prefill
+ key = torch.randn(batch_size, num_heads, prefill_len, head_dim)
+ value = torch.randn(batch_size, num_heads, prefill_len, head_dim)
+ hf_dynamic_layer.update(key.clone(), value.clone())
+ apriel_attention_cache.update(key.clone(), value.clone())
+
+ # Decode
+ for _ in range(decode_steps):
+ key = torch.randn(batch_size, num_heads, 1, head_dim)
+ value = torch.randn(batch_size, num_heads, 1, head_dim)
+ hf_dynamic_layer.update(key.clone(), value.clone())
+ apriel_attention_cache.update(key.clone(), value.clone())
+
+ # Verify cumulative_length matches HF
+ assert apriel_attention_cache.cumulative_length == hf_dynamic_layer.get_seq_length()
+
+ # Verify HF's kv_length follows the expected formula
+ cache_position = torch.arange(1) # Single token decode
+ hf_kv_len, hf_kv_offset = hf_dynamic_layer.get_mask_sizes(cache_position)
+ expected_kv_len = hf_dynamic_layer.get_seq_length() + cache_position.shape[0]
+ assert hf_kv_len == expected_kv_len
+
+ def test_hf_kv_offset_always_zero(self, hf_dynamic_layer, batch_size, num_heads, head_dim):
+ """Document that HF DynamicLayer always returns kv_offset=0.
+
+ For full attention, all cached KV pairs map to absolute positions
+ starting from 0, so kv_offset is always 0.
+ """
+ # Add many tokens
+ for _ in range(20):
+ key = torch.randn(batch_size, num_heads, 5, head_dim)
+ value = torch.randn(batch_size, num_heads, 5, head_dim)
+ hf_dynamic_layer.update(key.clone(), value.clone())
+
+ cache_position = torch.arange(1)
+ _, hf_kv_offset = hf_dynamic_layer.get_mask_sizes(cache_position)
+
+ assert hf_kv_offset == 0, "DynamicLayer always returns kv_offset=0"
+
+ # -------------------------------------------------------------------------
+ # update: Output shape and values must match
+ # -------------------------------------------------------------------------
+
+ @pytest.mark.parametrize("seq_len", [1, 5, 10])
+ def test_update_returns_same_shape(
+ self, hf_dynamic_layer, apriel_attention_cache, batch_size, num_heads, head_dim, seq_len
+ ):
+ """update() returns tensors with matching shapes."""
+ key = torch.randn(batch_size, num_heads, seq_len, head_dim)
+ value = torch.randn(batch_size, num_heads, seq_len, head_dim)
+
+ hf_k, hf_v = hf_dynamic_layer.update(key.clone(), value.clone())
+ apr_k, apr_v = apriel_attention_cache.update(key.clone(), value.clone())
+
+ assert hf_k.shape == apr_k.shape
+ assert hf_v.shape == apr_v.shape
+
+ def test_update_concatenates_identically(
+ self, hf_dynamic_layer, apriel_attention_cache, batch_size, num_heads, head_dim
+ ):
+ """Multiple updates produce identical concatenated states."""
+ # Use deterministic values for comparison
+ k1 = torch.arange(10).float().view(1, 1, 10, 1).expand(batch_size, num_heads, 10, head_dim)
+ v1 = k1.clone()
+
+ hf_dynamic_layer.update(k1.clone(), v1.clone())
+ apriel_attention_cache.update(k1.clone(), v1.clone())
+
+ k2 = torch.arange(10, 15).float().view(1, 1, 5, 1).expand(batch_size, num_heads, 5, head_dim)
+ v2 = k2.clone()
+
+ hf_k, hf_v = hf_dynamic_layer.update(k2.clone(), v2.clone())
+ apr_k, apr_v = apriel_attention_cache.update(k2.clone(), v2.clone())
+
+ torch.testing.assert_close(hf_k, apr_k)
+ torch.testing.assert_close(hf_v, apr_v)
+
+
+# =============================================================================
+# SECTION 2: SLIDING WINDOW - _AttentionCache vs DynamicSlidingWindowLayer
+# =============================================================================
+
+
+class TestSlidingWindowContract:
+ """Test _AttentionCache (with window) matches HuggingFace DynamicSlidingWindowLayer.
+
+ DynamicSlidingWindowLayer is used for sliding window attention (e.g., Mistral).
+ Critical behaviors:
+ - cumulative_length tracks ALL tokens seen (not just cached)
+ - kv_offset increases once window is exceeded
+ - kv_length is capped at window size
+
+ Uses fixtures from conftest.py:
+ - window_size: parameterized [4, 8, 16, 32]
+ - hf_sliding_layer: DynamicSlidingWindowLayer
+ - apriel_sliding_cache: _AttentionCache with window
+ """
+
+ # -------------------------------------------------------------------------
+ # cumulative_length: Must track total tokens, not cached tokens
+ # -------------------------------------------------------------------------
+
+ @pytest.mark.parametrize("prefill_len", [1, 3, 5, 10, 20])
+ def test_cumulative_length_matches_after_prefill(
+ self, hf_sliding_layer, apriel_sliding_cache, batch_size, num_heads, head_dim, prefill_len
+ ):
+ """cumulative_length matches HF get_seq_length after prefill."""
+ key = torch.randn(batch_size, num_heads, prefill_len, head_dim)
+ value = torch.randn(batch_size, num_heads, prefill_len, head_dim)
+
+ hf_sliding_layer.update(key.clone(), value.clone())
+ apriel_sliding_cache.update(key.clone(), value.clone())
+
+ assert apriel_sliding_cache.cumulative_length == hf_sliding_layer.get_seq_length()
+
+ def test_cumulative_length_continues_past_window(
+ self, hf_sliding_layer, apriel_sliding_cache, window_size, batch_size, num_heads, head_dim
+ ):
+ """cumulative_length keeps growing even after window is full."""
+ total_tokens = window_size * 3 # Way past window
+
+ for i in range(total_tokens):
+ key = torch.randn(batch_size, num_heads, 1, head_dim)
+ value = torch.randn(batch_size, num_heads, 1, head_dim)
+
+ hf_sliding_layer.update(key.clone(), value.clone())
+ apriel_sliding_cache.update(key.clone(), value.clone())
+
+ expected = i + 1
+ assert apriel_sliding_cache.cumulative_length == expected
+ assert hf_sliding_layer.get_seq_length() == expected
+
+ # -------------------------------------------------------------------------
+ # get_mask_sizes: kv_offset must increase once window is exceeded
+ # -------------------------------------------------------------------------
+
+ def test_kv_offset_zero_before_window_full(
+ self, hf_sliding_layer, apriel_sliding_cache, window_size, batch_size, num_heads, head_dim
+ ):
+ """kv_offset is 0 while cumulative < window.
+
+ Before the window is full, kv_offset should be 0 because all cached tokens
+ correspond to absolute positions starting from 0.
+ """
+ # Add tokens up to window-1
+ for i in range(window_size - 1):
+ key = torch.randn(batch_size, num_heads, 1, head_dim)
+ value = torch.randn(batch_size, num_heads, 1, head_dim)
+
+ hf_sliding_layer.update(key.clone(), value.clone())
+ apriel_sliding_cache.update(key.clone(), value.clone())
+
+ cache_position = torch.arange(1)
+ hf_kv_len, hf_kv_offset = hf_sliding_layer.get_mask_sizes(cache_position)
+
+ # Verify HF returns 0 offset before window full
+ assert hf_kv_offset == 0, f"HF offset should be 0 at step {i}"
+ # Verify Apriel cache tracks cumulative correctly
+ assert apriel_sliding_cache.cumulative_length == i + 1
+
+ def test_kv_offset_increases_after_window_full(
+ self, hf_sliding_layer, apriel_sliding_cache, window_size, batch_size, num_heads, head_dim
+ ):
+ """kv_offset increases once cumulative >= window.
+
+ Once the window is full, the cache discards oldest tokens. kv_offset tracks
+ which absolute position KV[0] corresponds to.
+ """
+ # Fill to exactly window
+ for _ in range(window_size):
+ key = torch.randn(batch_size, num_heads, 1, head_dim)
+ value = torch.randn(batch_size, num_heads, 1, head_dim)
+ hf_sliding_layer.update(key.clone(), value.clone())
+ apriel_sliding_cache.update(key.clone(), value.clone())
+
+ cache_position = torch.arange(1)
+ hf_kv_len, hf_kv_offset = hf_sliding_layer.get_mask_sizes(cache_position)
+
+ # At window boundary, offset should be 1
+ assert hf_kv_offset == 1, "HF offset should be 1 at window boundary"
+ assert apriel_sliding_cache.cumulative_length == window_size
+
+ # Add more tokens and verify offset keeps increasing with HF
+ for i in range(5):
+ key = torch.randn(batch_size, num_heads, 1, head_dim)
+ value = torch.randn(batch_size, num_heads, 1, head_dim)
+ hf_sliding_layer.update(key.clone(), value.clone())
+ apriel_sliding_cache.update(key.clone(), value.clone())
+
+ hf_kv_len, hf_kv_offset = hf_sliding_layer.get_mask_sizes(cache_position)
+
+ expected_offset = i + 2
+ assert hf_kv_offset == expected_offset
+ assert apriel_sliding_cache.cumulative_length == window_size + i + 1
+
+ def test_kv_length_capped_at_window(
+ self, hf_sliding_layer, apriel_sliding_cache, window_size, batch_size, num_heads, head_dim
+ ):
+ """kv_length is capped at window size once exceeded.
+
+ For a query of length 1 after the window is full, kv_length = window
+ (window-1 cached tokens + 1 query token).
+ """
+ # Way past window
+ for _ in range(window_size * 2):
+ key = torch.randn(batch_size, num_heads, 1, head_dim)
+ value = torch.randn(batch_size, num_heads, 1, head_dim)
+ hf_sliding_layer.update(key.clone(), value.clone())
+ apriel_sliding_cache.update(key.clone(), value.clone())
+
+ cache_position = torch.arange(1)
+ hf_kv_len, _ = hf_sliding_layer.get_mask_sizes(cache_position)
+
+ # HF returns window (window-1 cached + 1 query)
+ assert hf_kv_len == window_size
+ # Verify our cache tracked cumulative correctly
+ assert apriel_sliding_cache.cumulative_length == window_size * 2
+
+ # -------------------------------------------------------------------------
+ # Full sequence length tracking through generation
+ # -------------------------------------------------------------------------
+
+ @pytest.mark.parametrize("prefill_len", [1, 3, 5, 10, 20])
+ def test_cumulative_length_tracks_all_tokens(
+ self, hf_sliding_layer, apriel_sliding_cache, window_size, batch_size, num_heads, head_dim, prefill_len
+ ):
+ """cumulative_length tracks total tokens seen through prefill + decode.
+
+ This is the foundation for correct mask size computation. We verify that
+ our _AttentionCache tracks tokens identically to HuggingFace's DynamicSlidingWindowLayer.
+ The actual get_mask_sizes computation is tested in TestApriel2CacheIntegration.
+ """
+ # Prefill
+ key = torch.randn(batch_size, num_heads, prefill_len, head_dim)
+ value = torch.randn(batch_size, num_heads, prefill_len, head_dim)
+ hf_sliding_layer.update(key.clone(), value.clone())
+ apriel_sliding_cache.update(key.clone(), value.clone())
+
+ assert apriel_sliding_cache.cumulative_length == hf_sliding_layer.get_seq_length()
+
+ # Decode past window
+ for i in range(window_size + 10):
+ key = torch.randn(batch_size, num_heads, 1, head_dim)
+ value = torch.randn(batch_size, num_heads, 1, head_dim)
+ hf_sliding_layer.update(key.clone(), value.clone())
+ apriel_sliding_cache.update(key.clone(), value.clone())
+
+ assert (
+ apriel_sliding_cache.cumulative_length == hf_sliding_layer.get_seq_length()
+ ), f"cumulative_length mismatch at step {i}"
+
+
+# =============================================================================
+# SECTION 3: SSM CACHE - _SSMCache vs MambaCache concept
+# =============================================================================
+
+
+class TestSSMCacheContract:
+ """Document _SSMCache interface and verify basic contract.
+
+ Unlike attention caches which have HF equivalents (DynamicLayer, DynamicSlidingWindowLayer),
+ SSM caches have no direct HF counterpart with matching interface. HF's MambaCache uses
+ different methods (update_conv_state, update_ssm_state), so we can't do direct comparison.
+
+ These tests document the interface contract:
+ 1. `conv` and `recurrent` attributes for storing states
+ 2. Both support None (lazy initialization)
+ 3. `conv` can be tuple (for KDA which has separate q/k/v conv states)
+
+ Higher-level operations (reorder, batch_repeat, reset) are tested in
+ TestBeamSearchOperations in test_cache_apriel2_specific.py.
+ """
+
+ def test_conv_state_storage(self, ssm_cache):
+ """conv attribute stores conv states (batch, intermediate, kernel_size)."""
+ conv = torch.randn(2, 64, 4)
+ ssm_cache.conv = conv
+ torch.testing.assert_close(ssm_cache.conv, conv)
+
+ def test_recurrent_state_storage(self, ssm_cache):
+ """recurrent attribute stores SSM states (batch, intermediate, state_size)."""
+ recurrent = torch.randn(2, 64, 16)
+ ssm_cache.recurrent = recurrent
+ torch.testing.assert_close(ssm_cache.recurrent, recurrent)
+
+ def test_conv_state_tuple_for_kda(self, ssm_cache):
+ """conv can be tuple for KDA's separate q/k/v convolutions."""
+ conv_tuple = (torch.randn(2, 64, 4), torch.randn(2, 64, 4), torch.randn(2, 64, 4))
+ ssm_cache.conv = conv_tuple
+ assert isinstance(ssm_cache.conv, tuple)
+ assert len(ssm_cache.conv) == 3
+
+ def test_initial_states_none(self, ssm_cache):
+ """States are None initially (lazy initialization pattern)."""
+ assert ssm_cache.conv is None
+ assert ssm_cache.recurrent is None
+
+ def test_states_independent(self, ssm_cache):
+ """conv and recurrent states are independent."""
+ ssm_cache.conv = torch.randn(2, 64, 4)
+ assert ssm_cache.recurrent is None # recurrent unchanged
+
+ ssm_cache.recurrent = torch.randn(2, 64, 16)
+ assert ssm_cache.conv is not None # conv unchanged
+
+
+# =============================================================================
+# SECTION 4: APRIEL2CACHE INTEGRATION
+# =============================================================================
+
+
+class TestApriel2CacheIntegration:
+ """Test Apriel2Cache correctly delegates to underlying caches.
+
+ Uses fixtures from conftest.py:
+ - attention_config: Pure attention config
+ - swa_config: Sliding window attention config (window=8)
+ """
+
+ def test_get_seq_length_matches_dynamic_layer(self, attention_config):
+ """Apriel2Cache.get_seq_length matches DynamicLayer for full attention."""
+ from transformers.cache_utils import DynamicLayer
+
+ cache = Apriel2Cache(attention_config)
+ hf_layer = DynamicLayer()
+
+ key = torch.randn(2, 4, 10, 16)
+ value = torch.randn(2, 4, 10, 16)
+
+ cache.update(key.clone(), value.clone(), layer_idx=0)
+ hf_layer.update(key.clone(), value.clone())
+
+ assert cache.get_seq_length(0) == hf_layer.get_seq_length()
+
+ def test_get_mask_sizes_matches_dynamic_layer(self, attention_config):
+ """Apriel2Cache.get_mask_sizes matches DynamicLayer."""
+ from transformers.cache_utils import DynamicLayer
+
+ cache = Apriel2Cache(attention_config)
+ hf_layer = DynamicLayer()
+
+ key = torch.randn(2, 4, 10, 16)
+ value = torch.randn(2, 4, 10, 16)
+
+ cache.update(key.clone(), value.clone(), layer_idx=0)
+ hf_layer.update(key.clone(), value.clone())
+
+ cache_position = torch.arange(1)
+ hf_kv_len, hf_kv_offset = hf_layer.get_mask_sizes(cache_position)
+ apr_kv_len, apr_kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0)
+
+ assert apr_kv_len == hf_kv_len
+ assert apr_kv_offset == hf_kv_offset
+
+ def test_get_mask_sizes_matches_sliding_layer(self, swa_config):
+ """Apriel2Cache.get_mask_sizes matches DynamicSlidingWindowLayer."""
+ from transformers.cache_utils import DynamicSlidingWindowLayer
+
+ cache = Apriel2Cache(swa_config)
+ hf_layer = DynamicSlidingWindowLayer(sliding_window=8)
+
+ # Fill past window
+ for _ in range(15):
+ key = torch.randn(2, 4, 1, 16)
+ value = torch.randn(2, 4, 1, 16)
+ cache.update(key.clone(), value.clone(), layer_idx=0)
+ hf_layer.update(key.clone(), value.clone())
+
+ cache_position = torch.arange(1)
+ hf_kv_len, hf_kv_offset = hf_layer.get_mask_sizes(cache_position)
+ apr_kv_len, apr_kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0)
+
+ assert apr_kv_len == hf_kv_len
+ assert apr_kv_offset == hf_kv_offset
+
+ def test_reset_clears_cumulative_length(self, attention_config):
+ """reset() clears cumulative_length (matches DynamicLayer.reset)."""
+ cache = Apriel2Cache(attention_config)
+
+ cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0)
+ assert cache.get_seq_length(0) == 10
+
+ cache.reset()
+ assert cache.get_seq_length(0) == 0
+
+
+# =============================================================================
+# SECTION 5: MASK CORRECTNESS (SEMANTIC TESTS)
+# =============================================================================
+
+
+class TestMaskCorrectness:
+ """Test that mask parameters produce semantically correct masks.
+
+ These tests verify the END RESULT: masks created with our parameters
+ allow the correct attention patterns.
+ """
+
+ def test_full_attention_decode_can_attend_to_all(self):
+ """During decode, query can attend to all cached positions."""
+ from transformers.masking_utils import causal_mask_function, sdpa_mask
+
+ cache = _AttentionCache(window=None)
+
+ # Prefill + decode
+ for _ in range(10):
+ cache.update(torch.randn(1, 1, 1, 16), torch.randn(1, 1, 1, 16))
+
+ # Mask for decode step
+ cache_position = torch.tensor([10]) # Position of new token
+ kv_length = cache.cumulative_length + 1
+ kv_offset = 0
+
+ mask = sdpa_mask(
+ batch_size=1,
+ cache_position=cache_position,
+ kv_length=kv_length,
+ kv_offset=kv_offset,
+ mask_function=causal_mask_function,
+ )
+
+ if mask is not None:
+ # Query at position 10 should attend to positions 0-10
+ query_mask = mask[0, 0, 0, :]
+ for kv_idx in range(kv_length):
+ assert query_mask[kv_idx].item() == True, f"Should attend to position {kv_idx}"
+
+ @pytest.mark.parametrize("window_size", [4, 8, 16])
+ def test_sliding_window_decode_respects_window(self, window_size):
+ """During decode, query only attends within sliding window."""
+ from transformers.masking_utils import sdpa_mask, sliding_window_causal_mask_function
+
+ cache = _AttentionCache(window=window_size)
+
+ # Fill way past window
+ total_tokens = window_size * 2
+ for _ in range(total_tokens):
+ cache.update(torch.randn(1, 1, 1, 16), torch.randn(1, 1, 1, 16))
+
+ # Mask for decode step
+ cache_position = torch.tensor([total_tokens])
+ cumulative = cache.cumulative_length
+ kv_offset = max(cumulative - window_size + 1, 0)
+ kv_length = window_size - 1 + 1 # cached + query
+
+ mask = sdpa_mask(
+ batch_size=1,
+ cache_position=cache_position,
+ kv_length=kv_length,
+ kv_offset=kv_offset,
+ mask_function=sliding_window_causal_mask_function(window_size),
+ )
+
+ if mask is not None:
+ query_mask = mask[0, 0, 0, :]
+ query_pos = cache_position[0].item()
+
+ for kv_idx in range(kv_length):
+ abs_pos = kv_offset + kv_idx
+ in_window = abs_pos > query_pos - window_size
+ causal = abs_pos <= query_pos
+ expected = in_window and causal
+
+ assert (
+ query_mask[kv_idx].item() == expected
+ ), f"Position {abs_pos}: expected {expected}, got {query_mask[kv_idx].item()}"
+
+ def test_prefill_has_causal_pattern(self):
+ """During prefill, mask has proper causal (lower triangular) pattern."""
+ from transformers.masking_utils import causal_mask_function, sdpa_mask
+
+ cache = _AttentionCache(window=None)
+ cache.update(torch.randn(1, 1, 5, 16), torch.randn(1, 1, 5, 16))
+
+ cache_position = torch.arange(5)
+ kv_length = cache.cumulative_length
+ kv_offset = 0
+
+ mask = sdpa_mask(
+ batch_size=1,
+ cache_position=cache_position,
+ kv_length=kv_length,
+ kv_offset=kv_offset,
+ mask_function=causal_mask_function,
+ allow_is_causal_skip=False, # Force mask creation
+ )
+
+ if mask is not None:
+ # Check causal pattern
+ for q_idx in range(5):
+ for kv_idx in range(5):
+ expected = kv_idx <= q_idx
+ actual = mask[0, 0, q_idx, kv_idx].item()
+ assert actual == expected, f"q={q_idx}, kv={kv_idx}: expected {expected}"
diff --git a/fast_llm_external_models/tests/test_apriel2/test_causal_conv1d.py b/fast_llm_external_models/tests/test_apriel2/test_causal_conv1d.py
index ec6abc1d2..0567cd76e 100644
--- a/fast_llm_external_models/tests/test_apriel2/test_causal_conv1d.py
+++ b/fast_llm_external_models/tests/test_apriel2/test_causal_conv1d.py
@@ -24,7 +24,6 @@
from fast_llm_external_models.apriel2.modeling_apriel2 import CausalConv1d, _causal_conv1d_fn
-
# =============================================================================
# Fixtures
# =============================================================================
@@ -63,6 +62,7 @@ def kernel_size():
def to_device(conv: CausalConv1d, device: str) -> CausalConv1d:
"""Create a copy of conv on the specified device."""
import copy
+
return copy.deepcopy(conv).to(device)
@@ -71,7 +71,9 @@ def prefill(conv: CausalConv1d, x: torch.Tensor, state: torch.Tensor = None) ->
return conv(x, conv_state=state, return_final_state=True)
-def decode_sequence(conv: CausalConv1d, tokens: torch.Tensor, state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+def decode_sequence(
+ conv: CausalConv1d, tokens: torch.Tensor, state: torch.Tensor
+) -> tuple[torch.Tensor, torch.Tensor]:
"""Decode multiple tokens one-by-one, return (stacked_outputs, final_state).
Args:
@@ -223,7 +225,7 @@ def test_chunked_prefill_cpu(self, conv, dim, total_len, chunk_size):
outputs = []
state = None
for start in range(0, total_len, chunk_size):
- chunk = x[:, :, start:start + chunk_size]
+ chunk = x[:, :, start : start + chunk_size]
out, state = prefill(conv, chunk, state)
outputs.append(out)
@@ -248,7 +250,7 @@ def test_chunked_prefill_cuda(self, conv, dim, total_len, chunk_size):
outputs = []
state = None
for start in range(0, total_len, chunk_size):
- chunk = x[:, :, start:start + chunk_size].cuda()
+ chunk = x[:, :, start : start + chunk_size].cuda()
out, state = prefill(conv_cuda, chunk, state)
outputs.append(out)
@@ -329,7 +331,7 @@ def test_all_cpu_paths_match(self, conv, dim):
outputs = []
state = None
for start in range(0, total_len, chunk_size):
- chunk = x[:, :, start:start + chunk_size]
+ chunk = x[:, :, start : start + chunk_size]
out, state = prefill(conv, chunk, state)
outputs.append(out)
path1 = torch.cat(outputs, dim=-1)
@@ -374,7 +376,7 @@ def test_all_paths_match_cross_device(self, conv, dim):
# CPU chunked
outputs, state = [], None
for start in range(0, total_len, chunk_size):
- out, state = prefill(conv, x[:, :, start:start + chunk_size], state)
+ out, state = prefill(conv, x[:, :, start : start + chunk_size], state)
outputs.append(out)
results["cpu_chunked"] = torch.cat(outputs, dim=-1)
@@ -393,7 +395,7 @@ def test_all_paths_match_cross_device(self, conv, dim):
# CUDA chunked
outputs, state = [], None
for start in range(0, total_len, chunk_size):
- out, state = prefill(conv_cuda, x[:, :, start:start + chunk_size].cuda(), state)
+ out, state = prefill(conv_cuda, x[:, :, start : start + chunk_size].cuda(), state)
outputs.append(out.cpu())
results["cuda_chunked"] = torch.cat(outputs, dim=-1)
@@ -431,8 +433,7 @@ def test_all_paths_match_cross_device(self, conv, dim):
for name, result in results.items():
tol = tolerances[name]
torch.testing.assert_close(
- result, reference, atol=tol, rtol=tol,
- msg=f"Path '{name}' diverged from reference"
+ result, reference, atol=tol, rtol=tol, msg=f"Path '{name}' diverged from reference"
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
@@ -468,8 +469,8 @@ def test_long_decode_no_drift(self, conv, dim):
# Check no systematic drift (errors shouldn't consistently increase)
decode_errors = errors[prefill_len:]
- first_half = decode_errors[:len(decode_errors)//2].mean()
- second_half = decode_errors[len(decode_errors)//2:].mean()
+ first_half = decode_errors[: len(decode_errors) // 2].mean()
+ second_half = decode_errors[len(decode_errors) // 2 :].mean()
assert second_half < first_half * 2, "Errors growing over decode steps (drift detected)"
diff --git a/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py b/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py
index 0bd6ac88d..3413b9d25 100644
--- a/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py
+++ b/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py
@@ -20,7 +20,7 @@
import yaml
from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
-from fast_llm_external_models.apriel2.conversion.config import apply_surgery, compose_configs
+from fast_llm_external_models.apriel2.conversion.config import compose_configs
class TestComposeConfigsLaws:
@@ -75,14 +75,10 @@ def source_config(self):
},
}
- def test_identity_empty_surgery(self, source_config):
- """Law 1: compose_configs(config, {}) == config"""
- result = compose_configs(source_config, {})
- assert result == source_config
-
- def test_identity_none_surgery(self, source_config):
- """Law 1: compose_configs(config, None) == config"""
- result = compose_configs(source_config, None)
+ @pytest.mark.parametrize("empty_surgery", [{}, None])
+ def test_identity(self, source_config, empty_surgery):
+ """Law 1: compose_configs(config, empty) == config for empty in [{}, None]"""
+ result = compose_configs(source_config, empty_surgery)
assert result == source_config
def test_override_explicit_values(self, source_config):
@@ -114,7 +110,7 @@ def test_same_type_inheritance(self, source_config):
assert mixer["head_size"] == 32 # Inherited
assert mixer["rope_theta"] == 10000.0 # Inherited
assert mixer["window_size"] == 512 # Added
- assert "init" not in mixer # Stripped by apply_surgery
+ # init is preserved for plan_surgery to see (stripped only at final output)
def test_cross_type_attention_to_gdn(self, source_config):
"""Law 5: attention→gdn derives GDN dims from attention geometry."""
@@ -239,8 +235,14 @@ def test_null_deletion(self, source_config):
assert "vision_encoder" not in result
- def test_init_stripped_from_result(self, source_config):
- """Verify `init` keys are stripped from final result."""
+ def test_init_preserved_for_plan_surgery(self, source_config):
+ """Verify `init` keys are preserved so plan_surgery can see them.
+
+ The `init` field controls weight initialization (transfer vs random).
+ It's preserved through composition and only stripped at final output.
+ """
+ from fast_llm_external_models.apriel2.conversion.config import strip_init_fields
+
surgery = {
"decoder": {
"block": {
@@ -252,20 +254,20 @@ def test_init_stripped_from_result(self, source_config):
"gdn": {"type": "gdn", "init": "random", "convolution_layer": {"kernel_size": 4}},
},
},
- "mlp": {"init": "transfer"},
- "normalization": {"init": "transfer"},
},
},
}
result = compose_configs(source_config, surgery)
- def check_no_init(d, path=""):
- assert "init" not in d, f"Found 'init' key at {path}"
- for k, v in d.items():
- if isinstance(v, dict):
- check_no_init(v, f"{path}.{k}")
+ # init is preserved in composed config
+ mixers = result["decoder"]["block"]["mixer"]["mixers"]
+ assert mixers["attention"].get("init") == "transfer"
+ assert mixers["gdn"].get("init") == "random"
- check_no_init(result)
+ # strip_init_fields removes them for final output
+ stripped = strip_init_fields(result)
+ assert "init" not in stripped["decoder"]["block"]["mixer"]["mixers"]["attention"]
+ assert "init" not in stripped["decoder"]["block"]["mixer"]["mixers"]["gdn"]
def test_init_random_still_inherits_config(self, source_config):
"""init: random is for weights only - config params still inherited."""
@@ -287,6 +289,212 @@ def test_init_random_still_inherits_config(self, source_config):
assert mixer["head_groups"] == 4
assert mixer["window_size"] == 512
+ # =========================================================================
+ # Monoid Laws: compose_configs forms a monoid action on configs
+ # =========================================================================
+
+ def test_surgery_monoid_associativity(self):
+ """MONOID: merge(merge(A, B), C) == merge(A, merge(B, C)) for partial configs."""
+ surgery_a = {"decoder": {"block": {"mixer": {"type": "stochastic", "main_mixer_name": "attention"}}}}
+ surgery_b = {"decoder": {"block": {"mixer": {"mixers": {"sliding_window": {"window_size": 512}}}}}}
+ surgery_c = {"decoder": {"block": {"mixer": {"mixers": {"gdn": {"type": "gdn"}}}}}}
+
+ # Left-associated: (A ∘ B) ∘ C
+ ab_c = compose_configs(compose_configs(surgery_a, surgery_b), surgery_c)
+ # Right-associated: A ∘ (B ∘ C)
+ a_bc = compose_configs(surgery_a, compose_configs(surgery_b, surgery_c))
+
+ assert ab_c == a_bc, "Surgery monoid should be associative"
+
+ @pytest.mark.parametrize("num_surgeries", [2, 3])
+ def test_monoid_action_compatibility(self, source_config, num_surgeries):
+ """MONOID ACTION: apply(apply(c, A), B) == apply(c, merge(A, B))
+
+ This is the key law: applying surgeries sequentially equals merging first.
+ Parameterized to test with 2 and 3 surgeries.
+ """
+ surgeries = [
+ {
+ "decoder": {
+ "block": {
+ "mixer": {"type": "stochastic", "main_mixer_name": "attention", "mixers": {"attention": {}}}
+ }
+ }
+ },
+ {"decoder": {"block": {"mixer": {"mixers": {"sliding_window": {"window_size": 512}}}}}},
+ {"decoder": {"block": {"mixer": {"mixers": {"gdn": {"type": "gdn"}}}}}},
+ ][:num_surgeries]
+
+ # Sequential: ((c ⊳ A) ⊳ B) ⊳ ...
+ result_sequential = source_config
+ for s in surgeries:
+ result_sequential = compose_configs(result_sequential, s)
+
+ # Merged: c ⊳ (A ∘ B ∘ ...)
+ merged = surgeries[0]
+ for s in surgeries[1:]:
+ merged = compose_configs(merged, s)
+ result_merged = compose_configs(source_config, merged)
+
+ assert result_sequential == result_merged, f"Monoid action compatibility failed for {num_surgeries} surgeries"
+
+
+class TestBiasConfigInheritance:
+ """Test per-layer bias inheritance through surgery composition.
+
+ These tests verify that the per-layer bias configuration (mirroring Fast-LLM's
+ AffineLinearConfig) is correctly inherited through surgery operations:
+ - query_layer.bias.enabled, key_layer.bias.enabled, etc. for attention
+ - layer_1.bias.enabled, layer_2.bias.enabled for MLP
+ """
+
+ @pytest.fixture
+ def source_config_with_bias(self):
+ """Source config with Qwen-style bias (QKV enabled, O disabled)."""
+ return {
+ "model_type": "apriel2",
+ "architectures": ["Apriel2ForCausalLM"],
+ "hidden_size": 256,
+ "vocab_size": 1000,
+ "decoder": {
+ "type": "fixed",
+ "num_blocks": 4,
+ "block": {
+ "mixer": {
+ "type": "attention",
+ "heads": 8,
+ "head_groups": 4,
+ "head_size": 32,
+ "rotary": {"type": "mistral_1d", "theta": 10000.0},
+ # Qwen-style per-layer bias
+ "query_layer": {"bias": {"enabled": True}},
+ "key_layer": {"bias": {"enabled": True}},
+ "value_layer": {"bias": {"enabled": True}},
+ "dense_layer": {"bias": {"enabled": False}},
+ },
+ "mlp": {
+ "type": "mlp",
+ "intermediate_size": 512,
+ "gated": False,
+ # Per-layer MLP bias
+ "layer_1": {"bias": {"enabled": True}},
+ "layer_2": {"bias": {"enabled": False}},
+ },
+ "normalization": {"type": "rms_norm", "epsilon": 1e-5},
+ },
+ },
+ }
+
+ def test_same_type_inherits_attention_bias(self, source_config_with_bias):
+ """Same-type surgery inherits per-layer attention bias settings."""
+ surgery = {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "window_size": 512, # Add sliding window behavior
+ },
+ },
+ },
+ }
+ result = compose_configs(source_config_with_bias, surgery)
+
+ mixer = result["decoder"]["block"]["mixer"]
+ assert mixer["query_layer"]["bias"]["enabled"] is True
+ assert mixer["key_layer"]["bias"]["enabled"] is True
+ assert mixer["value_layer"]["bias"]["enabled"] is True
+ assert mixer["dense_layer"]["bias"]["enabled"] is False
+
+ def test_same_type_inherits_mlp_bias(self, source_config_with_bias):
+ """Same-type surgery inherits per-layer MLP bias settings."""
+ surgery = {
+ "decoder": {
+ "block": {
+ "mlp": {
+ "intermediate_size": 1024, # Change size
+ },
+ },
+ },
+ }
+ result = compose_configs(source_config_with_bias, surgery)
+
+ mlp = result["decoder"]["block"]["mlp"]
+ assert mlp["layer_1"]["bias"]["enabled"] is True
+ assert mlp["layer_2"]["bias"]["enabled"] is False
+ assert mlp["intermediate_size"] == 1024
+
+ def test_cross_type_attention_to_sliding_window_preserves_bias(self, source_config_with_bias):
+ """attention→sliding_window cross-type preserves per-layer bias."""
+ surgery = {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "type": "sliding_window", # Cross-type derivation
+ "window_size": 512,
+ },
+ },
+ },
+ }
+ result = compose_configs(source_config_with_bias, surgery)
+
+ mixer = result["decoder"]["block"]["mixer"]
+ assert mixer["type"] == "sliding_window"
+ # Bias settings preserved through cross-type
+ assert mixer["query_layer"]["bias"]["enabled"] is True
+ assert mixer["key_layer"]["bias"]["enabled"] is True
+ assert mixer["value_layer"]["bias"]["enabled"] is True
+ assert mixer["dense_layer"]["bias"]["enabled"] is False
+
+ def test_stochastic_wrapper_inherits_bias(self, source_config_with_bias):
+ """Wrapping in stochastic inherits bias settings to all sub-mixers."""
+ surgery = {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "type": "stochastic",
+ "main_mixer_name": "attention",
+ "mixers": {
+ "attention": {"init": "transfer"},
+ "sliding_window": {
+ "type": "sliding_window",
+ "window_size": 512,
+ "init": "transfer",
+ },
+ },
+ },
+ },
+ },
+ }
+ result = compose_configs(source_config_with_bias, surgery)
+
+ mixers = result["decoder"]["block"]["mixer"]["mixers"]
+
+ # Attention sub-mixer inherits bias
+ assert mixers["attention"]["query_layer"]["bias"]["enabled"] is True
+ assert mixers["attention"]["dense_layer"]["bias"]["enabled"] is False
+
+ # Sliding window sub-mixer also inherits bias
+ assert mixers["sliding_window"]["query_layer"]["bias"]["enabled"] is True
+ assert mixers["sliding_window"]["dense_layer"]["bias"]["enabled"] is False
+
+ def test_surgery_can_override_bias(self, source_config_with_bias):
+ """Surgery can explicitly override inherited bias settings."""
+ surgery = {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "dense_layer": {"bias": {"enabled": True}}, # Override O bias
+ },
+ },
+ },
+ }
+ result = compose_configs(source_config_with_bias, surgery)
+
+ mixer = result["decoder"]["block"]["mixer"]
+ # Q/K/V unchanged
+ assert mixer["query_layer"]["bias"]["enabled"] is True
+ # O bias overridden
+ assert mixer["dense_layer"]["bias"]["enabled"] is True
+
class TestComposeConfigsRealYAML:
"""Test compose_configs with real YAML surgery files."""
@@ -398,160 +606,12 @@ def test_build_plan_returns_complete_config(self, llava_pixtral_checkpoint):
mixer = config.decoder["block"]["mixer"]
assert mixer["type"] == "stochastic"
- # Each sub-mixer should have complete config (no init keys)
+ # Each sub-mixer should have complete config
+ # (init is preserved for plan_surgery, stripped only at final output)
for name, sub_mixer in mixer["mixers"].items():
- assert "init" not in sub_mixer, f"Sub-mixer {name} still has 'init' key"
assert "type" in sub_mixer
-class TestMonoidLaws:
- """Test the algebraic laws of compose_configs.
-
- Surgery specs form a MONOID under deep-merge:
- - Identity: {}
- - Operation: deep merge (overlay wins)
- - Associativity: merge(merge(A, B), C) == merge(A, merge(B, C))
-
- compose_configs is a MONOID ACTION on configs:
- - Identity action: apply(config, {}) == config
- - Compatibility: apply(apply(c, A), B) == apply(c, merge(A, B))
- """
-
- @pytest.fixture
- def complete_config(self):
- """A complete Apriel2 config."""
- return {
- "model_type": "apriel2",
- "architectures": ["Apriel2ForConditionalGeneration"],
- "hidden_size": 256,
- "vocab_size": 1000,
- "bos_token_id": 1,
- "eos_token_id": 2,
- "tie_word_embeddings": False,
- "image_token_index": 100,
- "decoder": {
- "type": "fixed",
- "num_blocks": 4,
- "block": {
- "mixer": {
- "type": "attention",
- "heads": 8,
- "head_groups": 4,
- "head_size": 32,
- "rope_theta": 10000.0,
- },
- "mlp": {"type": "mlp", "intermediate_size": 512},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- },
- }
-
- @pytest.fixture
- def surgery_a(self):
- """First surgery: wrap in stochastic with attention."""
- return {
- "decoder": {
- "block": {
- "mixer": {
- "type": "stochastic",
- "main_mixer_name": "attention",
- "mixers": {
- "attention": {"init": "transfer"},
- },
- },
- },
- },
- }
-
- @pytest.fixture
- def surgery_b(self):
- """Second surgery: add sliding window mixer."""
- return {
- "decoder": {
- "block": {
- "mixer": {
- "mixers": {
- "sliding_window": {"init": "transfer", "window_size": 512},
- },
- },
- },
- },
- }
-
- def test_identity_action(self, complete_config):
- """apply(config, {}) == config"""
- result = compose_configs(complete_config, {})
- assert result == complete_config
-
- def test_surgery_monoid_associativity(self, surgery_a, surgery_b):
- """merge(merge(A, B), C) == merge(A, merge(B, C)) for partial configs."""
- surgery_c = {
- "decoder": {
- "block": {
- "mixer": {
- "mixers": {
- "gdn": {"type": "gdn", "init": "transfer", "convolution_layer": {"kernel_size": 4}},
- },
- },
- },
- },
- }
-
- # Left-associated: (A ∘ B) ∘ C
- ab = compose_configs(surgery_a, surgery_b)
- ab_c = compose_configs(ab, surgery_c)
-
- # Right-associated: A ∘ (B ∘ C)
- bc = compose_configs(surgery_b, surgery_c)
- a_bc = compose_configs(surgery_a, bc)
-
- assert ab_c == a_bc, "Surgery monoid should be associative"
-
- def test_monoid_action_compatibility(self, complete_config, surgery_a, surgery_b):
- """apply(apply(c, A), B) == apply(c, merge(A, B))
-
- This is the key law: applying surgeries sequentially should equal
- merging the surgeries first, then applying once.
- """
- # Sequential application: (c ⊳ A) ⊳ B
- result_sequential = compose_configs(compose_configs(complete_config, surgery_a), surgery_b)
-
- # Merged application: c ⊳ (A ∘ B)
- merged_surgery = compose_configs(surgery_a, surgery_b)
- result_merged = compose_configs(complete_config, merged_surgery)
-
- # These should be equivalent
- assert result_sequential == result_merged, "Monoid action should satisfy compatibility law"
-
- def test_three_way_compatibility(self, complete_config, surgery_a, surgery_b):
- """Test with three surgeries for stronger confidence."""
- surgery_c = {
- "decoder": {
- "block": {
- "mixer": {
- "mixers": {
- "gdn": {"type": "gdn", "init": "transfer", "convolution_layer": {"kernel_size": 4}},
- },
- },
- },
- },
- }
-
- # Sequential: ((c ⊳ A) ⊳ B) ⊳ C
- seq = compose_configs(
- compose_configs(compose_configs(complete_config, surgery_a), surgery_b),
- surgery_c
- )
-
- # Merged: c ⊳ ((A ∘ B) ∘ C)
- merged = compose_configs(
- complete_config,
- compose_configs(compose_configs(surgery_a, surgery_b), surgery_c)
- )
-
- assert seq == merged, "Three-way monoid action should satisfy compatibility"
-
-
class TestCompositionTortureTest:
"""Comprehensive stress test for config composition.
@@ -650,19 +710,29 @@ def test_final_config_structure(self, complete_config, additive_surgery_chain):
assert mixer["mixers"]["sliding_window"]["window_size"] == 512
assert mixer["mixers"]["gdn"]["value_heads"] == 16
- def test_no_init_keys_in_result(self, complete_config, additive_surgery_chain):
- """Verify no 'init' keys leak through."""
+ def test_init_keys_preserved_for_planning(self, complete_config, additive_surgery_chain):
+ """Verify 'init' keys are preserved for plan_surgery to see.
- def check_no_init(d, path=""):
- if isinstance(d, dict):
- assert "init" not in d, f"Found 'init' key at {path}"
- for k, v in d.items():
- check_no_init(v, f"{path}.{k}")
+ The `init` field is metadata for weight initialization. It's preserved
+ through composition and only stripped when saving final output.
+ """
+ from fast_llm_external_models.apriel2.conversion.config import strip_init_fields
result = complete_config
for i, surgery in enumerate(additive_surgery_chain):
result = compose_configs(result, surgery)
- check_no_init(result, f"step_{i+1}")
+
+ # init should be in the composed config
+ mixer = result["decoder"]["block"]["mixer"]
+ if "mixers" in mixer:
+ has_init = any("init" in m for m in mixer["mixers"].values())
+ assert has_init, "init should be preserved in composed config"
+
+ # strip_init_fields removes them
+ stripped = strip_init_fields(result)
+ mixer = stripped["decoder"]["block"]["mixer"]
+ if "mixers" in mixer:
+ assert all("init" not in m for m in mixer["mixers"].values())
def test_full_torture_chain(self, complete_config, torture_surgery_chain):
"""Test the full 10-step torture chain produces valid configs."""
diff --git a/fast_llm_external_models/tests/test_apriel2/test_plan_composition_torture.py b/fast_llm_external_models/tests/test_apriel2/test_conversion_e2e.py
similarity index 78%
rename from fast_llm_external_models/tests/test_apriel2/test_plan_composition_torture.py
rename to fast_llm_external_models/tests/test_apriel2/test_conversion_e2e.py
index 3b4adc7f5..b91fb7e51 100644
--- a/fast_llm_external_models/tests/test_apriel2/test_plan_composition_torture.py
+++ b/fast_llm_external_models/tests/test_apriel2/test_conversion_e2e.py
@@ -1,6 +1,6 @@
-"""End-to-end torture test for plan composition.
+"""test_conversion_e2e.py - End-to-end conversion integration tests.
-This tests the FULL pipeline at every step of a surgery chain:
+Tests the FULL pipeline at every step of a surgery chain:
1. Config composition produces valid configs
2. Plan building works for each surgery
3. Plan execution produces valid weights
@@ -16,21 +16,12 @@
import pytest
import torch
-from fast_llm_external_models.tests.test_apriel2.conftest import requires_cuda
-
from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
-from fast_llm_external_models.apriel2.conversion import (
- compose,
- compose_configs,
- execute,
- plan_surgery,
-)
-from fast_llm_external_models.apriel2.conversion.llava import (
- convert_config as convert_llava_config,
- plan_llava_to_apriel2,
-)
+from fast_llm_external_models.apriel2.conversion import compose, compose_configs, execute, plan_surgery
+from fast_llm_external_models.apriel2.conversion.llava import convert_config as convert_llava_config
+from fast_llm_external_models.apriel2.conversion.llava import plan_llava_to_apriel2
from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForConditionalGeneration
-
+from fast_llm_external_models.tests.test_apriel2.conftest import requires_cuda
# =============================================================================
# Cycling Surgery Generation
@@ -87,40 +78,20 @@ def generate_cycling_surgeries(config: dict) -> list[tuple[dict, str]]:
if sub_name != main_mixer:
# Build surgery path based on block_path
if block_path == "block":
- surgery = {
- "decoder": {
- "block": {"mixer": {"main_mixer_name": sub_name}}
- }
- }
+ surgery = {"decoder": {"block": {"mixer": {"main_mixer_name": sub_name}}}}
else:
# block_path is "blocks.block_name"
block_name = block_path.split(".")[1]
- surgery = {
- "decoder": {
- "blocks": {
- block_name: {"mixer": {"main_mixer_name": sub_name}}
- }
- }
- }
+ surgery = {"decoder": {"blocks": {block_name: {"mixer": {"main_mixer_name": sub_name}}}}}
surgeries.append((surgery, f"cycle {block_path} to {sub_name}"))
# Restore original main_mixer_name
if any(sub_name != main_mixer for sub_name in sub_mixer_names):
if block_path == "block":
- restore = {
- "decoder": {
- "block": {"mixer": {"main_mixer_name": main_mixer}}
- }
- }
+ restore = {"decoder": {"block": {"mixer": {"main_mixer_name": main_mixer}}}}
else:
block_name = block_path.split(".")[1]
- restore = {
- "decoder": {
- "blocks": {
- block_name: {"mixer": {"main_mixer_name": main_mixer}}
- }
- }
- }
+ restore = {"decoder": {"blocks": {block_name: {"mixer": {"main_mixer_name": main_mixer}}}}}
surgeries.append((restore, f"restore {block_path} to {main_mixer}"))
return surgeries
@@ -194,9 +165,7 @@ def source_config(self, llava_pixtral_checkpoint):
with open(llava_pixtral_checkpoint / "config.json") as f:
return json.load(f)
- def test_initial_conversion_produces_working_model(
- self, source_config, source_weights
- ):
+ def test_initial_conversion_produces_working_model(self, source_config, source_weights):
"""Test that Llava → Apriel2 conversion produces a working model."""
# Convert config
apriel2_config_dict = convert_llava_config(source_config)
@@ -219,9 +188,7 @@ def test_initial_conversion_produces_working_model(
assert outputs.logits.shape == (1, 8, config.vocab_size)
- def test_each_surgery_step_produces_working_model(
- self, source_config, source_weights, additive_surgery_chain
- ):
+ def test_each_surgery_step_produces_working_model(self, source_config, source_weights, additive_surgery_chain):
"""Test that each surgery step produces a model that can forward pass.
Key insight: Surgery plans reference Apriel2 keys, so we must COMPOSE
@@ -290,9 +257,7 @@ def test_each_surgery_step_produces_working_model(
except Exception as e:
pytest.fail(f"Step {i+1}: Forward pass failed - {e}")
- def test_all_stochastic_submixers_via_cycling(
- self, source_config, source_weights, additive_surgery_chain
- ):
+ def test_all_stochastic_submixers_via_cycling(self, source_config, source_weights, additive_surgery_chain):
"""Test ALL sub-mixers in stochastic blocks, not just the main mixer.
Problem: Forward pass only exercises the main_mixer_name. Other sub-mixers
@@ -312,9 +277,7 @@ def test_all_stochastic_submixers_via_cycling(
conversion_plan = plan_llava_to_apriel2(source_config)
# Expand surgery chain with cycling
- expanded_chain = expand_surgery_chain_with_cycling(
- additive_surgery_chain, apriel2_config
- )
+ expanded_chain = expand_surgery_chain_with_cycling(additive_surgery_chain, apriel2_config)
# Build cumulative plan: conversion | surgery_1 | cycling_1a | ... | restore_1 | surgery_2 | ...
current_plan = conversion_plan
@@ -359,9 +322,7 @@ def test_all_stochastic_submixers_via_cycling(
except Exception as e:
pytest.fail(f"{desc}: Forward pass failed - {e}")
- def test_composed_plan_equals_sequential_execution(
- self, source_config, source_weights, additive_surgery_chain
- ):
+ def test_composed_plan_equals_sequential_execution(self, source_config, source_weights, additive_surgery_chain):
"""Test that composing plans gives same result as sequential execution.
This verifies plan composition associativity:
@@ -399,13 +360,9 @@ def test_composed_plan_equals_sequential_execution(
# Compare weights
for key in seq_weights:
if key in composed_weights:
- assert torch.allclose(
- seq_weights[key], composed_weights[key], atol=1e-5
- ), f"Weight mismatch for {key}"
+ assert torch.allclose(seq_weights[key], composed_weights[key], atol=1e-5), f"Weight mismatch for {key}"
- def test_final_model_structure(
- self, source_config, source_weights, additive_surgery_chain
- ):
+ def test_final_model_structure(self, source_config, source_weights, additive_surgery_chain):
"""Verify the final model has the expected structure."""
# Initial conversion
current_config = convert_llava_config(source_config)
@@ -504,9 +461,7 @@ def base_setup(self, llava_pixtral_checkpoint):
"""Set up base config and weights after Llava conversion."""
from safetensors.torch import load_file
- from fast_llm_external_models.apriel2.conversion.llava import (
- convert_config as convert_llava_config,
- )
+ from fast_llm_external_models.apriel2.conversion.llava import convert_config as convert_llava_config
# Load source config and weights
with open(llava_pixtral_checkpoint / "config.json") as f:
@@ -534,9 +489,7 @@ def _merge_surgeries(self, surgeries: list[dict]) -> dict:
result = _deep_merge(result, s)
return result
- def _build_incremental_plans(
- self, base_config: dict, surgeries: list[dict]
- ) -> tuple[list, list[dict]]:
+ def _build_incremental_plans(self, base_config: dict, surgeries: list[dict]) -> tuple[list, list[dict]]:
"""Build incremental plans for each surgery step.
Returns (plans, configs) where configs[i] is the config after surgery i.
@@ -552,9 +505,7 @@ def _build_incremental_plans(
config = target_config
return plans, configs
- def test_incremental_equals_direct_full_chain(
- self, base_setup, additive_surgery_chain
- ):
+ def test_incremental_equals_direct_full_chain(self, base_setup, additive_surgery_chain):
"""Test that composing all incremental plans equals one direct plan.
compose(P1, P2, ..., Pn) ≡ plan_surgery(base, final)
@@ -575,9 +526,7 @@ def test_incremental_equals_direct_full_chain(
direct_plan = plan_surgery(base_config, final_config)
# Verify same target keys
- assert set(composed_plan.mappings.keys()) == set(
- direct_plan.mappings.keys()
- ), "Plan keys should match"
+ assert set(composed_plan.mappings.keys()) == set(direct_plan.mappings.keys()), "Plan keys should match"
# Execute both and compare weights
composed_weights = execute(composed_plan, base_weights, seed=0)
@@ -611,9 +560,7 @@ def test_every_prefix_consistency(self, base_setup, additive_surgery_chain):
direct = plan_surgery(base_config, configs[k])
# Verify keys match
- assert set(composed.mappings.keys()) == set(
- direct.mappings.keys()
- ), f"Prefix {k}: keys don't match"
+ assert set(composed.mappings.keys()) == set(direct.mappings.keys()), f"Prefix {k}: keys don't match"
# Execute and compare
composed_weights = execute(composed, base_weights, seed=0)
@@ -781,9 +728,7 @@ def torture_setup(self, llava_pixtral_checkpoint):
"""Set up for comprehensive torture tests."""
from safetensors.torch import load_file
- from fast_llm_external_models.apriel2.conversion.llava import (
- convert_config as convert_llava_config,
- )
+ from fast_llm_external_models.apriel2.conversion.llava import convert_config as convert_llava_config
# Load source
with open(llava_pixtral_checkpoint / "config.json") as f:
@@ -801,9 +746,7 @@ def torture_setup(self, llava_pixtral_checkpoint):
return base_config, base_weights
- def test_each_step_produces_valid_config(
- self, torture_setup, comprehensive_torture_chain
- ):
+ def test_each_step_produces_valid_config(self, torture_setup, comprehensive_torture_chain):
"""Test that each surgery step produces a valid config."""
base_config, _ = torture_setup
@@ -818,9 +761,7 @@ def test_each_step_produces_valid_config(
pytest.fail(f"Step {i+1} produced invalid config: {e}")
@requires_cuda
- def test_each_step_produces_working_model(
- self, torture_setup, comprehensive_torture_chain
- ):
+ def test_each_step_produces_working_model(self, torture_setup, comprehensive_torture_chain):
"""Test that each surgery step produces a model that can forward pass.
This is the ultimate integration test - config composition + plan building
@@ -875,9 +816,7 @@ def test_each_step_produces_working_model(
current_weights = new_weights
@requires_cuda
- def test_final_supernet_structure(
- self, torture_setup, comprehensive_torture_chain
- ):
+ def test_final_supernet_structure(self, torture_setup, comprehensive_torture_chain):
"""Verify the final architecture has supernet blocks with all 4 mixer types."""
base_config, base_weights = torture_setup
@@ -914,9 +853,7 @@ def test_final_supernet_structure(
assert outputs.logits.shape == (1, 8, config.vocab_size)
@requires_cuda
- def test_plan_config_consistency_comprehensive(
- self, torture_setup, comprehensive_torture_chain
- ):
+ def test_plan_config_consistency_comprehensive(self, torture_setup, comprehensive_torture_chain):
"""Test that incremental plan composition works for the comprehensive chain.
Note: We cannot compare to a "direct plan" because the comprehensive chain
@@ -1083,66 +1020,6 @@ def mamba_config(self):
},
}
- def test_config_composition_identical_regardless_of_init_mode(self, base_config):
- """Config composition produces same structure with init: transfer vs init: random."""
- # Surgery with init: transfer
- surgery_transfer = {
- "decoder": {
- "block": {
- "mixer": {
- "type": "stochastic",
- "main_mixer_name": "attention",
- "mixers": {
- "attention": {"type": "attention", "init": "transfer"},
- "swa": {
- "type": "attention",
- "init": "transfer",
- "sliding_window": 512,
- },
- },
- },
- },
- },
- }
-
- # Surgery with init: random
- surgery_random = {
- "decoder": {
- "block": {
- "mixer": {
- "type": "stochastic",
- "main_mixer_name": "attention",
- "mixers": {
- "attention": {"type": "attention", "init": "random"},
- "swa": {
- "type": "attention",
- "init": "random",
- "sliding_window": 512,
- },
- },
- },
- },
- },
- }
-
- # Compose configs
- result_transfer = compose_configs(base_config, surgery_transfer)
- result_random = compose_configs(base_config, surgery_random)
-
- # Both should produce identical structure (init is stripped)
- assert result_transfer == result_random, (
- "Config composition should produce identical structure regardless of init mode"
- )
-
- # Verify the structure is correct
- mixer = result_transfer["decoder"]["block"]["mixer"]
- assert mixer["type"] == "stochastic"
- assert "attention" in mixer["mixers"]
- assert "swa" in mixer["mixers"]
- # init should be stripped
- assert "init" not in mixer["mixers"]["attention"]
- assert "init" not in mixer["mixers"]["swa"]
-
def test_plan_surgery_random_succeeds_for_any_type_pair(self, mamba_config):
"""plan_surgery with init: random should succeed even for mamba -> attention."""
# This surgery changes mamba to attention with random init
@@ -1166,7 +1043,7 @@ def test_plan_surgery_random_succeeds_for_any_type_pair(self, mamba_config):
plan = plan_surgery(mamba_config, surgery)
# Verify the plan has the expected target keys
- target_keys = set(str(k) for k in plan.mappings.keys())
+ target_keys = {str(k) for k in plan.mappings.keys()}
assert any("mixer.q_proj" in k for k in target_keys)
def test_plan_surgery_transfer_fails_for_unsupported_type_pair(self, mamba_config):
@@ -1219,7 +1096,7 @@ def test_plan_surgery_transfer_succeeds_for_supported_type_pair(self, base_confi
plan = plan_surgery(base_config, surgery)
# Verify the plan has mamba target keys
- target_keys = set(str(k) for k in plan.mappings.keys())
+ target_keys = {str(k) for k in plan.mappings.keys()}
assert any("mixer.in_proj" in k for k in target_keys)
def test_stochastic_init_random_succeeds_for_any_submixer_type(self, mamba_config):
@@ -1259,7 +1136,7 @@ def test_stochastic_init_random_succeeds_for_any_submixer_type(self, mamba_confi
plan = plan_surgery(mamba_config, surgery)
# Verify both sub-mixers have target keys
- target_keys = set(str(k) for k in plan.mappings.keys())
+ target_keys = {str(k) for k in plan.mappings.keys()}
assert any("mixers.attention.q_proj" in k for k in target_keys)
assert any("mixers.swa.q_proj" in k for k in target_keys)
@@ -1294,7 +1171,7 @@ def test_mixed_init_modes_in_stochastic(self, base_config):
plan = plan_surgery(base_config, surgery)
# Verify both sub-mixers have target keys
- target_keys = set(str(k) for k in plan.mappings.keys())
+ target_keys = {str(k) for k in plan.mappings.keys()}
assert any("mixers.attention.q_proj" in k for k in target_keys)
assert any("mixers.gdn.in_proj_qkvz" in k for k in target_keys)
@@ -1313,8 +1190,8 @@ class TestMarkovianProperty:
"""
@pytest.fixture
- def attention_config(self):
- """Base config with attention."""
+ def attention_config_dict(self):
+ """Base config dict with attention mixer for compose_configs tests."""
return {
"model_type": "apriel2",
"hidden_size": 256,
@@ -1335,43 +1212,7 @@ def attention_config(self):
},
}
- @pytest.fixture
- def stochastic_config(self):
- """Config with stochastic mixer."""
- return {
- "model_type": "apriel2",
- "hidden_size": 256,
- "vocab_size": 1000,
- "decoder": {
- "type": "fixed",
- "num_blocks": 2,
- "block": {
- "mixer": {
- "type": "stochastic",
- "main_mixer_name": "attention",
- "mixers": {
- "attention": {
- "type": "attention",
- "heads": 8,
- "head_groups": 4,
- "head_size": 32,
- },
- "swa": {
- "type": "sliding_window",
- "heads": 8,
- "head_groups": 4,
- "head_size": 32,
- "window_size": 512,
- },
- },
- },
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- },
- }
-
- def test_different_paths_same_config_same_plan(self, attention_config):
+ def test_different_paths_same_config_same_plan(self, attention_config_dict):
"""Two different paths to the same config produce identical plans.
Path A: attention -> stochastic{att, swa}
@@ -1398,7 +1239,7 @@ def test_different_paths_same_config_same_plan(self, attention_config):
},
},
}
- config_a = compose_configs(attention_config, surgery_a)
+ config_a = compose_configs(attention_config_dict, surgery_a)
# Path B: First add attention only, then add swa
surgery_b1 = {
@@ -1414,7 +1255,7 @@ def test_different_paths_same_config_same_plan(self, attention_config):
},
},
}
- intermediate_config = compose_configs(attention_config, surgery_b1)
+ intermediate_config = compose_configs(attention_config_dict, surgery_b1)
surgery_b2 = {
"decoder": {
@@ -1465,11 +1306,11 @@ def test_different_paths_same_config_same_plan(self, attention_config):
plan_from_b = plan_surgery(config_b, final_surgery)
# Compare plan mappings
- keys_a = set(str(k) for k in plan_from_a.mappings.keys())
- keys_b = set(str(k) for k in plan_from_b.mappings.keys())
+ keys_a = {str(k) for k in plan_from_a.mappings.keys()}
+ keys_b = {str(k) for k in plan_from_b.mappings.keys()}
assert keys_a == keys_b, "Plans from same config via different paths should be identical"
- def test_init_in_source_config_does_not_affect_plan(self, attention_config):
+ def test_init_in_source_config_does_not_affect_plan(self, attention_config_dict):
"""Manually injecting init into source config doesn't change the plan.
This tests that plan_surgery reads init from surgery, not source.
@@ -1479,8 +1320,8 @@ def test_init_in_source_config_does_not_affect_plan(self, attention_config):
import copy
# Create two copies of the config
- config_with_init = copy.deepcopy(attention_config)
- config_without_init = copy.deepcopy(attention_config)
+ config_with_init = copy.deepcopy(attention_config_dict)
+ config_without_init = copy.deepcopy(attention_config_dict)
# Manually inject init into one (bypassing compose_configs)
config_with_init["decoder"]["block"]["mixer"]["init"] = "random"
@@ -1504,238 +1345,12 @@ def test_init_in_source_config_does_not_affect_plan(self, attention_config):
plan_with = plan_surgery(config_with_init, surgery)
plan_without = plan_surgery(config_without_init, surgery)
- keys_with = set(str(k) for k in plan_with.mappings.keys())
- keys_without = set(str(k) for k in plan_without.mappings.keys())
+ keys_with = {str(k) for k in plan_with.mappings.keys()}
+ keys_without = {str(k) for k in plan_without.mappings.keys()}
# Plans should be identical - source's init field is ignored
assert keys_with == keys_without, "Plan should not depend on init in source config"
- def test_associativity_of_surgery_composition(self, attention_config):
- """Verify associativity: (A ∘ B) ∘ C == A ∘ (B ∘ C) for surgery specs.
-
- This tests that composing surgeries is associative, which is
- equivalent to Markovianity for plan creation.
- """
- surgery_a = {
- "decoder": {
- "block": {
- "mixer": {
- "type": "stochastic",
- "main_mixer_name": "attention",
- "mixers": {
- "attention": {"type": "attention", "init": "transfer"},
- },
- },
- },
- },
- }
-
- surgery_b = {
- "decoder": {
- "block": {
- "mixer": {
- "mixers": {
- "swa": {
- "type": "sliding_window",
- "init": "transfer",
- "window_size": 512,
- },
- },
- },
- },
- },
- }
-
- surgery_c = {
- "decoder": {
- "block": {
- "mixer": {
- "mixers": {
- "gdn": {
- "type": "gdn",
- "init": "random",
- "value_heads": 8,
- "key_heads": 4,
- "key_head_dim": 32,
- "value_head_dim": 32,
- "convolution_layer": {"kernel_size": 4},
- },
- },
- },
- },
- },
- }
-
- # Left association: ((attention_config ∘ A) ∘ B) ∘ C
- left_1 = compose_configs(attention_config, surgery_a)
- left_2 = compose_configs(left_1, surgery_b)
- left_result = compose_configs(left_2, surgery_c)
-
- # Right association: (attention_config ∘ A) ∘ (B ∘ C)
- # Note: B ∘ C is partial ∘ partial = deep merge of surgery specs
- bc_merged = compose_configs(surgery_b, surgery_c)
- right_1 = compose_configs(attention_config, surgery_a)
- right_result = compose_configs(right_1, bc_merged)
-
- assert left_result == right_result, "Surgery composition should be associative"
-
- def test_complete_configs_have_no_init_fields(self, attention_config):
- """Verify that compose_configs strips init from complete configs.
-
- This is the key invariant that enables Markovianity:
- - Complete configs (states) have no init fields
- - Surgery specs (transitions) have init fields
- - Plans read init from surgery, not state
- """
- surgery_with_init = {
- "decoder": {
- "block": {
- "mixer": {
- "type": "stochastic",
- "main_mixer_name": "attention",
- "mixers": {
- "attention": {"type": "attention", "init": "transfer"},
- "swa": {"type": "sliding_window", "init": "random", "window_size": 512},
- },
- },
- },
- },
- }
-
- result = compose_configs(attention_config, surgery_with_init)
-
- # Recursively check for init fields
- def has_init(obj):
- if isinstance(obj, dict):
- if "init" in obj:
- return True
- return any(has_init(v) for v in obj.values())
- if isinstance(obj, list):
- return any(has_init(v) for v in obj)
- return False
-
- assert not has_init(result), "Complete configs should have no init fields"
-
- def test_monoid_action_law_additive_surgeries(self):
- """Monoid action law HOLDS for additive surgeries.
-
- Additive surgeries (no type: declaration) support:
- apply(apply(s, t1), t2) == apply(s, t1 ∘ t2)
-
- This is because additive operations commute nicely:
- "add {a}" then "add {b}" == "add {a, b}"
- """
- # Start with stochastic (additive surgery target)
- s = {
- "model_type": "apriel2",
- "hidden_size": 256,
- "vocab_size": 1000,
- "decoder": {
- "type": "fixed",
- "num_blocks": 2,
- "block": {
- "mixer": {
- "type": "stochastic",
- "main_mixer_name": "attention",
- "mixers": {
- "attention": {"type": "attention", "heads": 8, "head_groups": 4, "head_size": 32},
- },
- },
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- },
- }
-
- # Additive surgeries (no type: declaration)
- t1 = {"decoder": {"block": {"mixer": {"mixers": {"swa": {"type": "sliding_window", "window_size": 512}}}}}}
- t2 = {"decoder": {"block": {"mixer": {"mixers": {"mamba": {"type": "mamba", "d_inner": 512}}}}}}
-
- # Path A: Sequential
- s_prime = compose_configs(s, t1)
- s_double_prime_A = compose_configs(s_prime, t2)
-
- # Path B: Composed
- t1_t2 = compose_configs(t1, t2)
- s_double_prime_B = compose_configs(s, t1_t2)
-
- assert s_double_prime_A == s_double_prime_B, "Monoid action law should hold for additive surgeries"
-
- def test_monoid_action_law_replacement_surgeries_fails(self):
- """Monoid action law FAILS for replacement surgeries (by design).
-
- Replacement surgeries (type: stochastic declared) have:
- apply(apply(s, t1), t2) != apply(s, t1 ∘ t2)
-
- This is FUNDAMENTAL, not a bug:
- - Sequential: "set to {a}" then "set to {b}" → {b} (second wins)
- - Composed: merge({a}, {b}) = {a,b}, then apply → {a,b}
-
- These are genuinely different semantics. The failure documents
- the distinction between declarative composition (merge) and
- operational composition (function application).
- """
- s = {
- "model_type": "apriel2",
- "hidden_size": 256,
- "vocab_size": 1000,
- "decoder": {
- "type": "fixed",
- "num_blocks": 2,
- "block": {
- "mixer": {"type": "attention", "heads": 8, "head_groups": 4, "head_size": 32},
- "mlp": {"type": "mlp", "intermediate_size": 256},
- "normalization": {"type": "rms_norm", "epsilon": 1e-5},
- },
- },
- }
-
- # Replacement surgeries (both declare type: stochastic)
- t1 = {
- "decoder": {
- "block": {
- "mixer": {
- "type": "stochastic",
- "main_mixer_name": "attention",
- "mixers": {"attention": {"type": "attention"}},
- }
- }
- }
- }
- t2 = {
- "decoder": {
- "block": {
- "mixer": {
- "type": "stochastic",
- "main_mixer_name": "swa",
- "mixers": {"swa": {"type": "sliding_window", "window_size": 512}},
- }
- }
- }
- }
-
- # Path A: Sequential (second replacement wins)
- s_prime = compose_configs(s, t1)
- s_double_prime_A = compose_configs(s_prime, t2)
-
- # Path B: Composed (declarations merged)
- t1_t2 = compose_configs(t1, t2)
- s_double_prime_B = compose_configs(s, t1_t2)
-
- # They should be DIFFERENT (law fails)
- assert s_double_prime_A != s_double_prime_B, (
- "Monoid action law should FAIL for replacement surgeries"
- )
-
- # Verify the specific difference:
- # Sequential: only swa (second replacement wins)
- # Composed: both attention and swa (merged declarations)
- mixers_A = set(s_double_prime_A["decoder"]["block"]["mixer"]["mixers"].keys())
- mixers_B = set(s_double_prime_B["decoder"]["block"]["mixer"]["mixers"].keys())
-
- assert mixers_A == {"swa"}, "Sequential: second replacement wins"
- assert mixers_B == {"attention", "swa"}, "Composed: declarations merged"
-
class TestCyclingSurgeryGeneration:
"""Tests for the cycling surgery generation functions.
@@ -1936,7 +1551,7 @@ def test_expand_surgery_chain_adds_cycling(self):
# Verify restore flag
assert expanded[0][2] is False # surgery - not restore
assert expanded[1][2] is False # cycle - not restore
- assert expanded[2][2] is True # restore
+ assert expanded[2][2] is True # restore
def test_expand_surgery_chain_preserves_invariant(self):
"""Test that cycling leaves the chain state invariant."""
@@ -1980,3 +1595,151 @@ def test_expand_surgery_chain_preserves_invariant(self):
# After cycling and restore, we should be back to the same state
assert current_config == config_after_original
+
+
+class TestBiasSurgeryChain:
+ """Torture tests for per-layer bias inheritance through surgery operations.
+
+ Uses apriel2_config_with_bias + bias_surgery_chain to test that:
+ - Qwen-style per-layer attention bias (QKV enabled, O disabled) survives surgery
+ - Non-gated MLP per-layer bias (layer_1 enabled, layer_2 disabled) survives surgery
+ - Bias settings are correctly inherited by new sub-mixers
+ - Bias is correctly tracked in surgery plans
+ """
+
+ @pytest.fixture
+ def bias_source_config(self, apriel2_config_with_bias):
+ """Convert Apriel2Config to dict for surgery operations."""
+ return apriel2_config_with_bias.to_dict()
+
+ def test_bias_survives_stochastic_wrapper(self, bias_source_config, bias_surgery_chain):
+ """Test that bias settings survive wrapping in stochastic mixer."""
+ # Apply first surgery (wrap in stochastic)
+ result = compose_configs(bias_source_config, bias_surgery_chain[0])
+
+ # Check attention sub-mixer inherited bias settings
+ mixer = result["decoder"]["block"]["mixer"]
+ assert mixer["type"] == "stochastic"
+
+ attn_mixer = mixer["mixers"]["attention"]
+ assert attn_mixer["query_layer"]["bias"]["enabled"] is True
+ assert attn_mixer["key_layer"]["bias"]["enabled"] is True
+ assert attn_mixer["value_layer"]["bias"]["enabled"] is True
+ assert attn_mixer["dense_layer"]["bias"]["enabled"] is False
+
+ # Check MLP bias survived
+ mlp = result["decoder"]["block"]["mlp"]
+ assert mlp["layer_1"]["bias"]["enabled"] is True
+ assert mlp["layer_2"]["bias"]["enabled"] is False
+
+ def test_new_submixer_inherits_bias(self, bias_source_config, bias_surgery_chain):
+ """Test that new sub-mixers inherit bias from source attention."""
+ # Apply S1 + S2 (wrap in stochastic, add sliding_window)
+ config = bias_source_config
+ for surgery in bias_surgery_chain[:2]:
+ config = compose_configs(config, surgery)
+
+ # sliding_window should inherit bias from source attention
+ mixer = config["decoder"]["block"]["mixer"]
+ sw_mixer = mixer["mixers"]["sliding_window"]
+
+ assert sw_mixer["query_layer"]["bias"]["enabled"] is True
+ assert sw_mixer["key_layer"]["bias"]["enabled"] is True
+ assert sw_mixer["value_layer"]["bias"]["enabled"] is True
+ assert sw_mixer["dense_layer"]["bias"]["enabled"] is False
+
+ def test_full_bias_chain_produces_valid_config(self, bias_source_config, bias_surgery_chain):
+ """Test that full bias surgery chain produces valid config."""
+ config = bias_source_config
+ for surgery in bias_surgery_chain:
+ config = compose_configs(config, surgery)
+
+ # Verify final config structure
+ mixer = config["decoder"]["block"]["mixer"]
+ assert mixer["type"] == "stochastic"
+ assert "attention" in mixer["mixers"]
+ assert "sliding_window" in mixer["mixers"]
+ assert "full_bias_attn" in mixer["mixers"]
+
+ # attention and sliding_window inherit Qwen-style bias
+ for name in ["attention", "sliding_window"]:
+ sub = mixer["mixers"][name]
+ assert sub["query_layer"]["bias"]["enabled"] is True
+ assert sub["dense_layer"]["bias"]["enabled"] is False
+
+ # full_bias_attn has add_linear_biases=True but per-layer settings inherited from
+ # source take precedence, so O bias is still disabled
+ full_bias = mixer["mixers"]["full_bias_attn"]
+ assert full_bias.get("add_linear_biases") is True
+ # Per-layer dense_layer.bias.enabled=False inherited from source takes precedence
+ assert full_bias["dense_layer"]["bias"]["enabled"] is False
+
+ def test_bias_plan_has_correct_mappings(self, bias_source_config, bias_surgery_chain):
+ """Test that surgery plan correctly includes/excludes bias weight mappings."""
+ # Compose config first to get full target config with inherited bias settings
+ target_config = compose_configs(bias_source_config, bias_surgery_chain[0])
+ plan = plan_surgery(bias_source_config, target_config)
+ mapping_strs = [str(k) for k in plan.mappings.keys()]
+
+ # Should have q_proj.bias (enabled)
+ q_bias = [m for m in mapping_strs if "q_proj.bias" in m]
+ assert len(q_bias) > 0, "Should have q_proj.bias mappings"
+
+ # Should NOT have o_proj.bias (disabled)
+ o_bias = [m for m in mapping_strs if "o_proj.bias" in m]
+ assert len(o_bias) == 0, "Should not have o_proj.bias mappings"
+
+ # Should have up_proj.bias (layer_1 enabled)
+ up_bias = [m for m in mapping_strs if "up_proj.bias" in m]
+ assert len(up_bias) > 0, "Should have up_proj.bias mappings"
+
+ # Should NOT have down_proj.bias (layer_2 disabled)
+ down_bias = [m for m in mapping_strs if "down_proj.bias" in m]
+ assert len(down_bias) == 0, "Should not have down_proj.bias mappings"
+
+ def test_bias_chain_produces_working_model(self, bias_source_config, bias_surgery_chain):
+ """Test that bias surgery chain produces a working model."""
+ from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM
+
+ # Apply full chain
+ config = bias_source_config
+ for surgery in bias_surgery_chain:
+ config = compose_configs(config, surgery)
+
+ # Create model
+ apriel_config = Apriel2Config(**config)
+ model = Apriel2ForCausalLM(apriel_config)
+ model.eval()
+
+ # Verify model structure has correct biases
+ block = model.model.decoder.blocks[0]
+
+ # attention sub-mixer should have QKV bias, no O bias
+ attn = block.mixer.mixers["attention"]
+ assert attn.q_proj.bias is not None
+ assert attn.k_proj.bias is not None
+ assert attn.v_proj.bias is not None
+ assert attn.o_proj.bias is None
+
+ # sliding_window should also inherit bias settings
+ sw = block.mixer.mixers["sliding_window"]
+ assert sw.q_proj.bias is not None
+ assert sw.o_proj.bias is None
+
+ # full_bias_attn inherits per-layer bias from source (even with add_linear_biases=True,
+ # per-layer settings take precedence in same-type inheritance)
+ full_bias = block.mixer.mixers["full_bias_attn"]
+ assert full_bias.q_proj.bias is not None
+ # O bias is disabled because inherited per-layer dense_layer.bias.enabled=False
+ # takes precedence over add_linear_biases=True
+ assert full_bias.o_proj.bias is None
+
+ # MLP should have layer_1 bias, no layer_2 bias
+ assert block.mlp.up_proj.bias is not None
+ assert block.mlp.down_proj.bias is None
+
+ # Forward pass should work
+ input_ids = torch.randint(0, config["vocab_size"], (1, 10))
+ with torch.no_grad():
+ outputs = model(input_ids, use_cache=False)
+ assert outputs.logits.shape == (1, 10, config["vocab_size"])
diff --git a/fast_llm_external_models/tests/test_apriel2/test_convert_from_llava.py b/fast_llm_external_models/tests/test_apriel2/test_convert_from_llava.py
index a437f920d..f96f5ac40 100644
--- a/fast_llm_external_models/tests/test_apriel2/test_convert_from_llava.py
+++ b/fast_llm_external_models/tests/test_apriel2/test_convert_from_llava.py
@@ -14,23 +14,15 @@
"""
import json
-from pathlib import Path
-import pytest
import torch
from safetensors import safe_open
-from safetensors.torch import save_file
from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
-from fast_llm_external_models.apriel2.conversion import (
- convert_llava_config as convert_config,
- execute,
- plan_llava_to_apriel2,
- plan_surgery,
-)
+from fast_llm_external_models.apriel2.conversion import convert_llava_config as convert_config
+from fast_llm_external_models.apriel2.conversion import execute, plan_llava_to_apriel2, plan_surgery
from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForConditionalGeneration
-
# =============================================================================
# Config Conversion Tests
# =============================================================================
@@ -330,9 +322,9 @@ def test_plan_keys_match_model_state_dict(self, llava_pixtral_checkpoint):
extra_in_plan = plan_keys - model_keys
# Filter out expected missing keys (caches, positions, etc.)
- missing_in_plan = {k for k in missing_in_plan if not any(
- skip in k.lower() for skip in ["cache", "position", "mask"]
- )}
+ missing_in_plan = {
+ k for k in missing_in_plan if not any(skip in k.lower() for skip in ["cache", "position", "mask"])
+ }
assert not missing_in_plan, f"Model keys not in plan: {sorted(missing_in_plan)[:10]}"
assert not extra_in_plan, f"Plan keys not in model: {sorted(extra_in_plan)[:10]}"
diff --git a/fast_llm_external_models/tests/test_apriel2/test_equivalence.py b/fast_llm_external_models/tests/test_apriel2/test_equivalence.py
index c59ed2000..9b3eb4efe 100644
--- a/fast_llm_external_models/tests/test_apriel2/test_equivalence.py
+++ b/fast_llm_external_models/tests/test_apriel2/test_equivalence.py
@@ -23,9 +23,6 @@
import torch
from transformers import LlavaForConditionalGeneration
-from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForConditionalGeneration
-
-
# =============================================================================
# Input Configuration
# =============================================================================
@@ -487,8 +484,10 @@ def test_batch_processing_behavior(self, model_pair):
batch_tgt = target.get_image_features(pixel_values).view(-1, batch_src.shape[-1])
# Sequential processing
- singles_src = [get_pixtral_vision_features(source, pixel_values[i:i+1]) for i in range(3)]
- singles_tgt = [target.get_image_features(pixel_values[i:i+1]).view(-1, batch_src.shape[-1]) for i in range(3)]
+ singles_src = [get_pixtral_vision_features(source, pixel_values[i : i + 1]) for i in range(3)]
+ singles_tgt = [
+ target.get_image_features(pixel_values[i : i + 1]).view(-1, batch_src.shape[-1]) for i in range(3)
+ ]
single_concat_src = torch.cat(singles_src, dim=0)
single_concat_tgt = torch.cat(singles_tgt, dim=0)
@@ -500,9 +499,9 @@ def test_batch_processing_behavior(self, model_pair):
print(f"Apriel2 batch vs sequential: {tgt_diff:.6f}")
# Both should have the same behavior (within FP tolerance)
- assert abs(src_diff - tgt_diff) < 1e-6, (
- f"Batch processing behavior differs: src={src_diff:.6f}, tgt={tgt_diff:.6f}"
- )
+ assert (
+ abs(src_diff - tgt_diff) < 1e-6
+ ), f"Batch processing behavior differs: src={src_diff:.6f}, tgt={tgt_diff:.6f}"
if __name__ == "__main__":
diff --git a/fast_llm_external_models/tests/test_apriel2/test_expr_plan.py b/fast_llm_external_models/tests/test_apriel2/test_expr_plan.py
index c487ab3a3..2dccac5ad 100644
--- a/fast_llm_external_models/tests/test_apriel2/test_expr_plan.py
+++ b/fast_llm_external_models/tests/test_apriel2/test_expr_plan.py
@@ -1,15 +1,13 @@
"""Tests for the expression-based plan system."""
import json
+
import pytest
import torch
-from fast_llm_external_models.tests.test_apriel2.conftest import requires_cuda
-
from fast_llm_external_models.apriel2.conversion import (
Concat,
EvalKwargs,
- Expr,
ExprAdapter,
ExprPlan,
Init,
@@ -18,10 +16,9 @@
Slice,
StreamingExecutor,
W,
- compose,
execute,
- fuse,
full_slice,
+ fuse,
make_slice,
plan_dil_attention_to_gdn,
plan_kil_attention_to_kda,
@@ -31,6 +28,7 @@
slice_spec,
substitute,
)
+from fast_llm_external_models.tests.test_apriel2.conftest import requires_cuda
def make_eval_kwargs(
@@ -219,10 +217,13 @@ def test_substitute_init_unchanged(self):
def test_substitute_complex(self):
"""Substitute handles complex nested expressions."""
# Concat of Slice(Ref) and Init
- expr = Concat(exprs=(
- Slice(expr=Ref(key=W("a")), slices=((0, 5, None),)),
- Init(shape=(5,), init_type="zeros"),
- ), dim=0)
+ expr = Concat(
+ exprs=(
+ Slice(expr=Ref(key=W("a")), slices=((0, 5, None),)),
+ Init(shape=(5,), init_type="zeros"),
+ ),
+ dim=0,
+ )
bindings = {W("a"): Ref(key=W("source"))}
result = substitute(expr, bindings)
@@ -238,7 +239,13 @@ class TestFuse:
def test_fuse_flatten_concat(self):
"""Fuse flattens nested Concat with same dim."""
inner = Concat(exprs=(Ref(key=W("a")), Ref(key=W("b"))), dim=0)
- outer = Concat(exprs=(inner, Ref(key=W("c")),), dim=0)
+ outer = Concat(
+ exprs=(
+ inner,
+ Ref(key=W("c")),
+ ),
+ dim=0,
+ )
result = fuse(outer)
assert isinstance(result, Concat)
@@ -250,7 +257,13 @@ def test_fuse_flatten_concat(self):
def test_fuse_no_flatten_different_dim(self):
"""Fuse doesn't flatten Concat with different dim."""
inner = Concat(exprs=(Ref(key=W("a")), Ref(key=W("b"))), dim=1)
- outer = Concat(exprs=(inner, Ref(key=W("c")),), dim=0)
+ outer = Concat(
+ exprs=(
+ inner,
+ Ref(key=W("c")),
+ ),
+ dim=0,
+ )
result = fuse(outer)
assert isinstance(result, Concat)
@@ -340,28 +353,34 @@ class TestExprPlan:
def test_plan_define_and_access(self):
"""Plan stores and retrieves expressions."""
- plan = ExprPlan(mappings={
- W("target"): Ref(key=W("source")),
- })
+ plan = ExprPlan(
+ mappings={
+ W("target"): Ref(key=W("source")),
+ }
+ )
assert W("target") in plan
assert isinstance(plan[W("target")], Ref)
def test_plan_source_keys(self):
"""Plan identifies all source references."""
- plan = ExprPlan(mappings={
- W("a"): Ref(key=W("x")),
- W("b"): Concat(exprs=(Ref(key=W("y")), Ref(key=W("z"))), dim=0),
- W("c"): Init(shape=(10,), init_type="zeros"),
- })
+ plan = ExprPlan(
+ mappings={
+ W("a"): Ref(key=W("x")),
+ W("b"): Concat(exprs=(Ref(key=W("y")), Ref(key=W("z"))), dim=0),
+ W("c"): Init(shape=(10,), init_type="zeros"),
+ }
+ )
assert plan.source_keys() == {W("x"), W("y"), W("z")}
def test_plan_target_keys(self):
"""Plan identifies all target keys."""
- plan = ExprPlan(mappings={
- W("a"): Ref(key=W("x")),
- W("b"): Ref(key=W("y")),
- })
+ plan = ExprPlan(
+ mappings={
+ W("a"): Ref(key=W("x")),
+ W("b"): Ref(key=W("y")),
+ }
+ )
assert plan.target_keys() == {W("a"), W("b")}
@@ -386,9 +405,17 @@ def test_plan_summary(self):
def test_plan_fuse(self):
"""Plan fuse applies optimizations."""
inner = Concat(exprs=(Ref(key=W("a")), Ref(key=W("b"))), dim=0)
- plan = ExprPlan(mappings={
- W("out"): Concat(exprs=(inner, Ref(key=W("c")),), dim=0),
- })
+ plan = ExprPlan(
+ mappings={
+ W("out"): Concat(
+ exprs=(
+ inner,
+ Ref(key=W("c")),
+ ),
+ dim=0,
+ ),
+ }
+ )
fused = plan.fuse()
assert isinstance(fused[W("out")], Concat)
@@ -532,9 +559,11 @@ class TestStreamingExecution:
def test_execute_simple(self):
"""Execute simple plan."""
- plan = ExprPlan(mappings={
- W("out"): Ref(key=W("in")),
- })
+ plan = ExprPlan(
+ mappings={
+ W("out"): Ref(key=W("in")),
+ }
+ )
sources = {W("in"): torch.tensor([1.0, 2.0, 3.0])}
result = execute(plan, sources, seed=42)
@@ -544,9 +573,11 @@ def test_execute_simple(self):
def test_execute_concat(self):
"""Execute plan with Concat."""
- plan = ExprPlan(mappings={
- W("combined"): Concat(exprs=(Ref(key=W("a")), Ref(key=W("b"))), dim=0),
- })
+ plan = ExprPlan(
+ mappings={
+ W("combined"): Concat(exprs=(Ref(key=W("a")), Ref(key=W("b"))), dim=0),
+ }
+ )
sources = {
W("a"): torch.ones(2, 3),
@@ -559,14 +590,19 @@ def test_execute_concat(self):
def test_execute_mil_like(self):
"""Execute MIL-like Concat of Slices and Init."""
# Simulated MIL: in_proj = [z, x, B, C]
- plan = ExprPlan(mappings={
- W("in_proj"): Concat(exprs=(
- Init(shape=(4, 8), init_type="zeros"), # z
- Slice(expr=Ref(key=W("v")), slices=((0, 2, None), (None, None, None))), # x
- Slice(expr=Ref(key=W("k")), slices=((0, 2, None), (None, None, None))), # B
- Slice(expr=Ref(key=W("q")), slices=((0, 4, None), (None, None, None))), # C
- ), dim=0),
- })
+ plan = ExprPlan(
+ mappings={
+ W("in_proj"): Concat(
+ exprs=(
+ Init(shape=(4, 8), init_type="zeros"), # z
+ Slice(expr=Ref(key=W("v")), slices=((0, 2, None), (None, None, None))), # x
+ Slice(expr=Ref(key=W("k")), slices=((0, 2, None), (None, None, None))), # B
+ Slice(expr=Ref(key=W("q")), slices=((0, 4, None), (None, None, None))), # C
+ ),
+ dim=0,
+ ),
+ }
+ )
sources = {
W("q"): torch.ones(4, 8),
@@ -583,11 +619,13 @@ def test_execute_mil_like(self):
def test_streaming_execution(self):
"""Streaming executor processes all targets."""
- plan = ExprPlan(mappings={
- W("out1"): Ref(key=W("shared")),
- W("out2"): Ref(key=W("shared")),
- W("out3"): Ref(key=W("unique")),
- })
+ plan = ExprPlan(
+ mappings={
+ W("out1"): Ref(key=W("shared")),
+ W("out2"): Ref(key=W("shared")),
+ W("out3"): Ref(key=W("unique")),
+ }
+ )
load_calls = []
@@ -858,25 +896,23 @@ def test_plan_dil_execution(self):
key_dim = 64
value_dim = 64
- head_k_dim = 16
- head_v_dim = 16
conv_dim = 2 * key_dim + value_dim # 192
# Create attention weights with per-head distinctive values
# Q: each head gets value (head_idx + 1)
q_weight = torch.zeros(64, 64)
for h in range(4):
- q_weight[h*16:(h+1)*16, :] = float(h + 1)
+ q_weight[h * 16 : (h + 1) * 16, :] = float(h + 1)
# K: each head gets value (head_idx + 1) * 10
k_weight = torch.zeros(64, 64)
for h in range(4):
- k_weight[h*16:(h+1)*16, :] = float((h + 1) * 10)
+ k_weight[h * 16 : (h + 1) * 16, :] = float((h + 1) * 10)
# V: each head gets value (head_idx + 1) * 100
v_weight = torch.zeros(64, 64)
for h in range(4):
- v_weight[h*16:(h+1)*16, :] = float((h + 1) * 100)
+ v_weight[h * 16 : (h + 1) * 16, :] = float((h + 1) * 100)
sources = {
W("attn.q_proj.weight"): q_weight,
@@ -894,30 +930,23 @@ def test_plan_dil_execution(self):
# Q_all (rows 0-63): heads 0,1,2,3 concatenated
for h in range(4):
- assert torch.allclose(
- in_proj_qkvz[h*16:(h+1)*16],
- torch.full((16, 64), float(h + 1))
- )
+ assert torch.allclose(in_proj_qkvz[h * 16 : (h + 1) * 16], torch.full((16, 64), float(h + 1)))
# K_all (rows 64-127): heads 0,1,2,3 concatenated
for h in range(4):
assert torch.allclose(
- in_proj_qkvz[key_dim + h*16:key_dim + (h+1)*16],
- torch.full((16, 64), float((h + 1) * 10))
+ in_proj_qkvz[key_dim + h * 16 : key_dim + (h + 1) * 16], torch.full((16, 64), float((h + 1) * 10))
)
# V_all (rows 128-191): heads 0,1,2,3 concatenated
for h in range(4):
assert torch.allclose(
- in_proj_qkvz[2*key_dim + h*16:2*key_dim + (h+1)*16],
- torch.full((16, 64), float((h + 1) * 100))
+ in_proj_qkvz[2 * key_dim + h * 16 : 2 * key_dim + (h + 1) * 16],
+ torch.full((16, 64), float((h + 1) * 100)),
)
# Z_all (rows 192-255): zeros
- assert torch.allclose(
- in_proj_qkvz[2*key_dim + value_dim:],
- torch.zeros(value_dim, 64)
- )
+ assert torch.allclose(in_proj_qkvz[2 * key_dim + value_dim :], torch.zeros(value_dim, 64))
# in_proj_ba should be zeros
in_proj_ba = result[W("in_proj_ba.weight")]
@@ -971,17 +1000,17 @@ def test_plan_dil_execution_gqa(self):
# Q: 4 heads, each with value (head_idx + 1)
q_weight = torch.zeros(64, 64)
for h in range(4):
- q_weight[h*16:(h+1)*16, :] = float(h + 1)
+ q_weight[h * 16 : (h + 1) * 16, :] = float(h + 1)
# K: 2 kv_heads, each with value (head_idx + 1) * 10
k_weight = torch.zeros(32, 64)
for h in range(2):
- k_weight[h*16:(h+1)*16, :] = float((h + 1) * 10)
+ k_weight[h * 16 : (h + 1) * 16, :] = float((h + 1) * 10)
# V: 2 kv_heads, each with value (head_idx + 1) * 100
v_weight = torch.zeros(32, 64)
for h in range(2):
- v_weight[h*16:(h+1)*16, :] = float((h + 1) * 100)
+ v_weight[h * 16 : (h + 1) * 16, :] = float((h + 1) * 100)
sources = {
W("attn.q_proj.weight"): q_weight,
@@ -1007,22 +1036,22 @@ def test_plan_dil_execution_gqa(self):
# K_all (rows 32-63): k_heads 0,1 (maps to source K heads 0,1 via modulo)
# k_head 0 → source K head 0 (value 10)
- assert torch.allclose(in_proj_qkvz[key_dim:key_dim+16], torch.full((16, 64), 10.0))
+ assert torch.allclose(in_proj_qkvz[key_dim : key_dim + 16], torch.full((16, 64), 10.0))
# k_head 1 → source K head 1 (value 20)
- assert torch.allclose(in_proj_qkvz[key_dim+16:key_dim+32], torch.full((16, 64), 20.0))
+ assert torch.allclose(in_proj_qkvz[key_dim + 16 : key_dim + 32], torch.full((16, 64), 20.0))
# V_all (rows 64-127): 4 v_heads, tiled from 2 source KV heads via modulo
# v_head 0 → src_v_head 0 (value 100)
- assert torch.allclose(in_proj_qkvz[2*key_dim:2*key_dim+16], torch.full((16, 64), 100.0))
+ assert torch.allclose(in_proj_qkvz[2 * key_dim : 2 * key_dim + 16], torch.full((16, 64), 100.0))
# v_head 1 → src_v_head 1 (value 200)
- assert torch.allclose(in_proj_qkvz[2*key_dim+16:2*key_dim+32], torch.full((16, 64), 200.0))
+ assert torch.allclose(in_proj_qkvz[2 * key_dim + 16 : 2 * key_dim + 32], torch.full((16, 64), 200.0))
# v_head 2 → src_v_head 0 (value 100, tiled)
- assert torch.allclose(in_proj_qkvz[2*key_dim+32:2*key_dim+48], torch.full((16, 64), 100.0))
+ assert torch.allclose(in_proj_qkvz[2 * key_dim + 32 : 2 * key_dim + 48], torch.full((16, 64), 100.0))
# v_head 3 → src_v_head 1 (value 200, tiled)
- assert torch.allclose(in_proj_qkvz[2*key_dim+48:2*key_dim+64], torch.full((16, 64), 200.0))
+ assert torch.allclose(in_proj_qkvz[2 * key_dim + 48 : 2 * key_dim + 64], torch.full((16, 64), 200.0))
# Z_all (rows 128-191): zeros
- assert torch.allclose(in_proj_qkvz[2*key_dim+value_dim:], torch.zeros(value_dim, 64))
+ assert torch.allclose(in_proj_qkvz[2 * key_dim + value_dim :], torch.zeros(value_dim, 64))
def test_plan_kil_attention_to_kda(self):
"""AIK plan produces correct structure for attention → KDA conversion."""
@@ -1188,6 +1217,7 @@ def test_compose_llava_to_mamba(self, llava_pixtral_config, apriel2_config_stoch
# Build surgery plan (need intermediate config)
from fast_llm_external_models.apriel2.conversion.llava import convert_config
+
intermediate_config = convert_config(llava_pixtral_config)
target_config = apriel2_config_stochastic.to_dict()
surgery_plan = plan_surgery(intermediate_config, target_config)
@@ -1210,6 +1240,7 @@ def test_execute_composed_pipeline(self, llava_pixtral_checkpoint):
"""
import json
from pathlib import Path
+
from safetensors.torch import load_file
# Load config
@@ -1448,10 +1479,9 @@ def test_comprehensive_conversion_all_mixer_types(self, llava_pixtral_checkpoint
the conversion produced correct keys and shapes.
"""
import json
- from pathlib import Path
from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
- from fast_llm_external_models.apriel2.convert import build_plan, convert
+ from fast_llm_external_models.apriel2.convert import convert
from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForConditionalGeneration
# Load LLaVA config
@@ -1477,11 +1507,11 @@ def test_comprehensive_conversion_all_mixer_types(self, llava_pixtral_checkpoint
"type": "pattern",
"num_blocks": 5,
"pattern": [
- "attn", # 0: attention → attention (passthrough)
- "mamba", # 1: attention → mamba (MIL)
- "gdn", # 2: attention → gated_delta_net (DIL)
- "stoch_am", # 3: attention → stochastic(attention + mamba)
- "stoch_sg", # 4: attention → stochastic(swa + gdn)
+ "attn", # 0: attention → attention (passthrough)
+ "mamba", # 1: attention → mamba (MIL)
+ "gdn", # 2: attention → gated_delta_net (DIL)
+ "stoch_am", # 3: attention → stochastic(attention + mamba)
+ "stoch_sg", # 4: attention → stochastic(swa + gdn)
],
"blocks": {
# Pure attention (passthrough from source)
@@ -1609,7 +1639,8 @@ def test_comprehensive_conversion_all_mixer_types(self, llava_pixtral_checkpoint
"type": "attention",
"heads": llava_config["vision_config"]["num_attention_heads"],
"head_groups": llava_config["vision_config"]["num_attention_heads"],
- "head_size": llava_config["vision_config"]["hidden_size"] // llava_config["vision_config"]["num_attention_heads"],
+ "head_size": llava_config["vision_config"]["hidden_size"]
+ // llava_config["vision_config"]["num_attention_heads"],
"add_linear_biases": False,
"causal": False,
"rotary": {
@@ -1688,7 +1719,6 @@ def test_conversion_plan_targets_match_model_state_dict(self, llava_pixtral_conf
This test validates the plan WITHOUT executing it, by comparing
plan target keys against what the model expects.
"""
- import json
from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
from fast_llm_external_models.apriel2.convert import build_plan
@@ -1703,7 +1733,7 @@ def test_conversion_plan_targets_match_model_state_dict(self, llava_pixtral_conf
expected_keys = set(model.state_dict().keys())
# Get plan target keys
- plan_target_keys = set(str(k) for k in plan.target_keys())
+ plan_target_keys = {str(k) for k in plan.target_keys()}
# Compare
missing_from_plan = expected_keys - plan_target_keys
@@ -1711,3 +1741,214 @@ def test_conversion_plan_targets_match_model_state_dict(self, llava_pixtral_conf
assert not missing_from_plan, f"Plan missing keys that model expects: {sorted(missing_from_plan)}"
assert not extra_in_plan, f"Plan has extra keys model doesn't expect: {sorted(extra_in_plan)}"
+
+
+class TestBiasPlanGeneration:
+ """Test that surgery plans correctly handle per-layer bias configurations.
+
+ These tests verify that plan_surgery correctly includes/excludes bias
+ weight mappings based on the per-layer bias settings:
+ - query_layer.bias.enabled, key_layer.bias.enabled, etc. for attention
+ - layer_1.bias.enabled, layer_2.bias.enabled for MLP
+ """
+
+ @pytest.fixture
+ def source_config_with_bias(self):
+ """Source config with Qwen-style bias (QKV enabled, O disabled)."""
+ return {
+ "model_type": "apriel2",
+ "hidden_size": 256,
+ "vocab_size": 1000,
+ "decoder": {
+ "type": "fixed",
+ "num_blocks": 2,
+ "block": {
+ "mixer": {
+ "type": "attention",
+ "heads": 8,
+ "head_groups": 4,
+ "head_size": 32,
+ "rotary": {"type": "mistral_1d", "theta": 10000.0},
+ # Qwen-style: QKV bias enabled, O bias disabled
+ "query_layer": {"bias": {"enabled": True}},
+ "key_layer": {"bias": {"enabled": True}},
+ "value_layer": {"bias": {"enabled": True}},
+ "dense_layer": {"bias": {"enabled": False}},
+ },
+ "mlp": {
+ "type": "mlp",
+ "intermediate_size": 512,
+ "gated": False,
+ # Per-layer MLP bias: layer_1 enabled, layer_2 disabled
+ "layer_1": {"bias": {"enabled": True}},
+ "layer_2": {"bias": {"enabled": False}},
+ },
+ "normalization": {"type": "rms_norm", "epsilon": 1e-5},
+ },
+ },
+ }
+
+ def test_plan_includes_enabled_attention_biases(self, source_config_with_bias):
+ """Surgery plan includes bias mappings for enabled attention biases."""
+ from fast_llm_external_models.apriel2.conversion.config import compose_configs
+ from fast_llm_external_models.apriel2.conversion.converters import plan_surgery
+
+ target_config = compose_configs(
+ source_config_with_bias,
+ {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "type": "stochastic",
+ "main_mixer_name": "attention",
+ "mixers": {
+ "attention": {"init": "transfer"},
+ },
+ },
+ "mlp": {"init": "transfer"},
+ },
+ },
+ },
+ )
+
+ plan = plan_surgery(source_config_with_bias, target_config)
+ mapping_strs = [str(k) for k in plan.mappings.keys()]
+
+ # Should have q_proj.bias, k_proj.bias, v_proj.bias mappings
+ q_bias = [m for m in mapping_strs if "q_proj.bias" in m]
+ k_bias = [m for m in mapping_strs if "k_proj.bias" in m]
+ v_bias = [m for m in mapping_strs if "v_proj.bias" in m]
+
+ assert len(q_bias) > 0, "Should have q_proj.bias mappings"
+ assert len(k_bias) > 0, "Should have k_proj.bias mappings"
+ assert len(v_bias) > 0, "Should have v_proj.bias mappings"
+
+ def test_plan_excludes_disabled_attention_biases(self, source_config_with_bias):
+ """Surgery plan excludes bias mappings for disabled attention biases."""
+ from fast_llm_external_models.apriel2.conversion.config import compose_configs
+ from fast_llm_external_models.apriel2.conversion.converters import plan_surgery
+
+ target_config = compose_configs(
+ source_config_with_bias,
+ {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "type": "stochastic",
+ "main_mixer_name": "attention",
+ "mixers": {
+ "attention": {"init": "transfer"},
+ },
+ },
+ "mlp": {"init": "transfer"},
+ },
+ },
+ },
+ )
+
+ plan = plan_surgery(source_config_with_bias, target_config)
+ mapping_strs = [str(k) for k in plan.mappings.keys()]
+
+ # Should NOT have o_proj.bias mappings (disabled)
+ o_bias = [m for m in mapping_strs if "o_proj.bias" in m]
+ assert len(o_bias) == 0, f"Should not have o_proj.bias mappings, found: {o_bias}"
+
+ def test_plan_includes_enabled_mlp_biases(self, source_config_with_bias):
+ """Surgery plan includes bias mappings for enabled MLP biases."""
+ from fast_llm_external_models.apriel2.conversion.config import compose_configs
+ from fast_llm_external_models.apriel2.conversion.converters import plan_surgery
+
+ target_config = compose_configs(
+ source_config_with_bias,
+ {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "type": "stochastic",
+ "main_mixer_name": "attention",
+ "mixers": {
+ "attention": {"init": "transfer"},
+ },
+ },
+ "mlp": {"init": "transfer"},
+ },
+ },
+ },
+ )
+
+ plan = plan_surgery(source_config_with_bias, target_config)
+ mapping_strs = [str(k) for k in plan.mappings.keys()]
+
+ # Should have up_proj.bias (layer_1) mappings
+ up_bias = [m for m in mapping_strs if "up_proj.bias" in m]
+ assert len(up_bias) > 0, "Should have up_proj.bias mappings"
+
+ def test_plan_excludes_disabled_mlp_biases(self, source_config_with_bias):
+ """Surgery plan excludes bias mappings for disabled MLP biases."""
+ from fast_llm_external_models.apriel2.conversion.config import compose_configs
+ from fast_llm_external_models.apriel2.conversion.converters import plan_surgery
+
+ target_config = compose_configs(
+ source_config_with_bias,
+ {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "type": "stochastic",
+ "main_mixer_name": "attention",
+ "mixers": {
+ "attention": {"init": "transfer"},
+ },
+ },
+ "mlp": {"init": "transfer"},
+ },
+ },
+ },
+ )
+
+ plan = plan_surgery(source_config_with_bias, target_config)
+ mapping_strs = [str(k) for k in plan.mappings.keys()]
+
+ # Should NOT have down_proj.bias (layer_2) mappings
+ down_bias = [m for m in mapping_strs if "down_proj.bias" in m]
+ assert len(down_bias) == 0, f"Should not have down_proj.bias mappings, found: {down_bias}"
+
+ def test_plan_random_init_creates_init_expressions_for_bias(self, source_config_with_bias):
+ """Random init creates Init expressions for bias weights."""
+ from fast_llm_external_models.apriel2.conversion.converters import plan_surgery
+
+ # Surgery spec - pass directly to plan_surgery (NOT composed, to preserve init)
+ surgery = {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "type": "stochastic",
+ "main_mixer_name": "attention",
+ "mixers": {
+ "attention": {"init": "transfer"},
+ "new_attention": {
+ "type": "attention",
+ "init": "random", # This triggers random init
+ "heads": 8,
+ "head_groups": 4,
+ "head_size": 32,
+ "rotary": {"type": "mistral_1d", "theta": 10000.0},
+ "add_linear_biases": True, # All biases enabled
+ },
+ },
+ },
+ },
+ },
+ }
+
+ # Pass surgery spec directly - init fields are preserved
+ plan = plan_surgery(source_config_with_bias, surgery)
+
+ # Check that new_attention biases use Init expressions
+ new_mixer_bias_keys = [k for k in plan.mappings.keys() if "new_attention" in str(k) and "bias" in str(k)]
+
+ assert len(new_mixer_bias_keys) > 0, "Should have bias mappings for new_attention"
+
+ for key in new_mixer_bias_keys:
+ expr = plan.mappings[key]
+ assert isinstance(expr, Init), f"{key} should be Init, got {type(expr)}"
diff --git a/fast_llm_external_models/tests/test_apriel2/test_integration.py b/fast_llm_external_models/tests/test_apriel2/test_integration.py
new file mode 100644
index 000000000..e84fa06ef
--- /dev/null
+++ b/fast_llm_external_models/tests/test_apriel2/test_integration.py
@@ -0,0 +1,330 @@
+"""Integration tests for Qwen2 -> Apriel2 -> Fast-LLM conversion pipeline.
+
+Tests verify the full conversion chain:
+1. Qwen2 -> Apriel2 (external module conversion)
+2. Apriel2 + Surgery -> Supernet (stochastic mixer creation)
+3. Supernet -> Fast-LLM -> Supernet (roundtrip through training format)
+
+Test Strategy:
+- Use real HuggingFace model (Qwen2.5-0.5B) for meaningful validation
+- Separate config preservation tests from numerical equivalence tests
+- Parameterize both conversion stages AND input variations
+- Single test implementation applied across all stages
+"""
+
+import json
+import tempfile
+from pathlib import Path
+
+import pytest
+import torch
+
+from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
+from fast_llm_external_models.apriel2.conversion import compose, compose_configs, execute, plan_surgery
+from fast_llm_external_models.apriel2.conversion.expr import W
+from fast_llm_external_models.apriel2.conversion.qwen2.config import convert_config as convert_qwen2_config
+from fast_llm_external_models.apriel2.conversion.qwen2.plan import plan_qwen2_to_apriel2
+from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM
+
+from .conftest import requires_fastllm
+
+# =============================================================================
+# Test Input Variations
+# =============================================================================
+
+TEST_INPUTS = pytest.mark.parametrize(
+ "prompts,max_new_tokens",
+ [
+ pytest.param(["Hello world"], 10, id="single_short"),
+ pytest.param(["Hi", "The quick brown fox jumps over the lazy dog"], 20, id="batch_varied"),
+ pytest.param(["Once upon a time"], 50, id="long_generation"),
+ ],
+)
+
+
+# =============================================================================
+# Conversion Fixtures
+# =============================================================================
+
+
+@pytest.fixture(scope="module")
+def qwen2_source():
+ """Load Qwen2.5-0.5B as the source/reference model."""
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+
+ model_name = "Qwen/Qwen2.5-0.5B"
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32, trust_remote_code=True)
+ config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+ model.eval()
+
+ if tokenizer.pad_token is None:
+ tokenizer.pad_token = tokenizer.eos_token
+ tokenizer.padding_side = "left"
+
+ return {
+ "model": model,
+ "tokenizer": tokenizer,
+ "config_dict": config.to_dict(),
+ "state_dict": model.state_dict(),
+ }
+
+
+@pytest.fixture(scope="module")
+def apriel2_converted(qwen2_source):
+ """Stage 1: Qwen2 -> Apriel2."""
+ config_dict = convert_qwen2_config(qwen2_source["config_dict"])
+ plan = plan_qwen2_to_apriel2(qwen2_source["config_dict"])
+ weights = execute(plan, {W(k): v for k, v in qwen2_source["state_dict"].items()}, seed=42)
+
+ config = Apriel2Config(**config_dict)
+ model = Apriel2ForCausalLM(config)
+ model.load_state_dict({str(k): v for k, v in weights.items()}, strict=False)
+ model.eval()
+
+ return {"model": model, "config_dict": config_dict, "plan": plan, "name": "Apriel2"}
+
+
+@pytest.fixture(scope="module")
+def supernet_converted(qwen2_source, apriel2_converted):
+ """Stage 2: Apriel2 + Surgery -> Supernet."""
+ surgery_spec = {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "type": "stochastic",
+ "main_mixer_name": "attention",
+ "mixers": {
+ "attention": {"type": "attention", "init": "transfer"},
+ "sliding_window": {
+ "type": "attention",
+ "init": "transfer",
+ "window_size": 4096,
+ },
+ },
+ },
+ },
+ },
+ }
+
+ apriel_config = apriel2_converted["config_dict"]
+ supernet_config = compose_configs(apriel_config, surgery_spec)
+
+ full_plan = compose(
+ apriel2_converted["plan"],
+ plan_surgery(apriel_config, supernet_config),
+ )
+
+ weights = execute(full_plan, {W(k): v for k, v in qwen2_source["state_dict"].items()}, seed=42)
+
+ config = Apriel2Config(**supernet_config)
+ model = Apriel2ForCausalLM(config)
+ model.load_state_dict({str(k): v for k, v in weights.items()}, strict=False)
+ model.eval()
+
+ return {"model": model, "config_dict": supernet_config, "name": "Supernet"}
+
+
+@pytest.fixture(scope="module")
+def roundtrip_converted(supernet_converted, qwen2_source):
+ """Stage 3: Supernet -> Fast-LLM -> Supernet."""
+ if not torch.cuda.is_available():
+ pytest.skip("Roundtrip conversion requires CUDA (integration tests need realistic hardware)")
+
+ from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, CheckpointSaveConfig, FastLLMCheckpointFormat
+ from fast_llm.engine.checkpoint.convert import ConvertConfig
+ from fast_llm.models.gpt.config import GPTModelConfig
+ from fast_llm.models.gpt.conversion.config import Apriel2TextCheckpointFormat
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ tmpdir = Path(tmpdir)
+ supernet_path = tmpdir / "supernet"
+ fastllm_path = tmpdir / "fastllm"
+ roundtrip_path = tmpdir / "roundtrip"
+
+ supernet_converted["model"].save_pretrained(supernet_path)
+ qwen2_source["tokenizer"].save_pretrained(supernet_path)
+
+ ConvertConfig(
+ model=GPTModelConfig,
+ input=CheckpointLoadConfig(path=supernet_path, format=Apriel2TextCheckpointFormat),
+ output=CheckpointSaveConfig(path=fastllm_path, format=FastLLMCheckpointFormat),
+ ).run()
+
+ ConvertConfig(
+ model=GPTModelConfig,
+ input=CheckpointLoadConfig(path=fastllm_path, format=FastLLMCheckpointFormat),
+ output=CheckpointSaveConfig(path=roundtrip_path, format=Apriel2TextCheckpointFormat),
+ ).run()
+
+ model = Apriel2ForCausalLM.from_pretrained(roundtrip_path)
+ model.eval()
+
+ with open(roundtrip_path / "config.json") as f:
+ config_dict = json.load(f)
+
+ yield {"model": model, "config_dict": config_dict, "name": "Roundtrip"}
+
+
+# =============================================================================
+# Parameterized Fixture: All Conversion Stages
+# =============================================================================
+
+
+@pytest.fixture(params=["apriel2", "supernet", "roundtrip"])
+def converted_model(request, apriel2_converted, supernet_converted):
+ """Parameterized fixture providing each conversion stage for testing.
+
+ This allows a single test to run against all stages automatically.
+ """
+ if request.param == "roundtrip":
+ pytest.importorskip("fast_llm")
+ if not torch.cuda.is_available():
+ pytest.skip("Roundtrip tests require CUDA (integration tests need realistic hardware)")
+ # Lazy-load to avoid fixture evaluation when CUDA unavailable
+ roundtrip_converted = request.getfixturevalue("roundtrip_converted")
+ return roundtrip_converted
+
+ return {
+ "apriel2": apriel2_converted,
+ "supernet": supernet_converted,
+ }[request.param]
+
+
+# =============================================================================
+# Config Preservation Tests
+# =============================================================================
+
+
+@pytest.mark.slow
+class TestConfigPreservation:
+ """Verify configs are correctly preserved through the conversion chain."""
+
+ def test_apriel2_structure(self, qwen2_source, apriel2_converted):
+ """Qwen2 -> Apriel2 preserves model dimensions."""
+ qwen = qwen2_source["config_dict"]
+ apriel = apriel2_converted["config_dict"]
+
+ assert apriel["hidden_size"] == qwen["hidden_size"]
+ assert apriel["vocab_size"] == qwen["vocab_size"]
+ assert apriel["decoder"]["num_blocks"] == qwen["num_hidden_layers"]
+
+ def test_apriel2_bias_pattern(self, apriel2_converted):
+ """Qwen2 -> Apriel2 preserves Qwen-style bias (QKV yes, O no)."""
+ mixer = apriel2_converted["config_dict"]["decoder"]["block"]["mixer"]
+
+ assert mixer["query_layer"]["bias"]["enabled"] is True
+ assert mixer["key_layer"]["bias"]["enabled"] is True
+ assert mixer["value_layer"]["bias"]["enabled"] is True
+ assert mixer["dense_layer"]["bias"]["enabled"] is False
+
+ def test_supernet_structure(self, supernet_converted):
+ """Surgery creates correct stochastic mixer structure."""
+ mixer = supernet_converted["config_dict"]["decoder"]["block"]["mixer"]
+
+ assert mixer["type"] == "stochastic"
+ assert mixer["main_mixer_name"] == "attention"
+ assert set(mixer["mixers"].keys()) == {"attention", "sliding_window"}
+
+ def test_supernet_bias_inheritance(self, supernet_converted):
+ """Submixers inherit bias settings from source."""
+ mixer = supernet_converted["config_dict"]["decoder"]["block"]["mixer"]
+
+ for name in ["attention", "sliding_window"]:
+ assert mixer["mixers"][name]["query_layer"]["bias"]["enabled"] is True
+ assert mixer["mixers"][name]["dense_layer"]["bias"]["enabled"] is False
+
+ @requires_fastllm
+ def test_roundtrip_structure(self, roundtrip_converted):
+ """Fast-LLM roundtrip preserves stochastic mixer structure."""
+ mixer = roundtrip_converted["config_dict"]["decoder"]["block"]["mixer"]
+
+ assert mixer["type"] == "stochastic"
+ assert mixer["main_mixer_name"] == "attention"
+ assert set(mixer["mixers"].keys()) == {"attention", "sliding_window"}
+
+ @requires_fastllm
+ def test_roundtrip_bias_preservation(self, roundtrip_converted):
+ """Fast-LLM roundtrip preserves per-layer bias settings."""
+ mixer = roundtrip_converted["config_dict"]["decoder"]["block"]["mixer"]
+
+ for name in ["attention", "sliding_window"]:
+ assert mixer["mixers"][name]["query_layer"]["bias"]["enabled"] is True
+ assert mixer["mixers"][name]["dense_layer"]["bias"]["enabled"] is False
+
+
+# =============================================================================
+# Numerical Equivalence Tests
+# =============================================================================
+
+
+@pytest.mark.slow
+class TestNumericalEquivalence:
+ """Verify all conversion stages produce numerically identical outputs.
+
+ Uses parameterized fixtures to test all stages with all input variations,
+ giving us 3 stages × 3 inputs = 9 test cases from a single test function.
+ """
+
+ @TEST_INPUTS
+ def test_logits_match(self, qwen2_source, converted_model, prompts, max_new_tokens):
+ """Converted model produces identical logits to source."""
+ tokenizer = qwen2_source["tokenizer"]
+ ref_model = qwen2_source["model"]
+ test_model = converted_model["model"]
+ stage = converted_model["name"]
+
+ inputs = tokenizer(prompts, return_tensors="pt", padding=True)
+ ref_device = next(ref_model.parameters()).device
+ test_device = next(test_model.parameters()).device
+
+ with torch.no_grad():
+ ref_logits = ref_model(
+ input_ids=inputs.input_ids.to(ref_device),
+ attention_mask=inputs.attention_mask.to(ref_device),
+ ).logits.cpu()
+
+ test_logits = test_model(
+ input_ids=inputs.input_ids.to(test_device),
+ attention_mask=inputs.attention_mask.to(test_device),
+ ).logits.cpu()
+
+ max_diff = (ref_logits - test_logits).abs().max().item()
+ assert torch.allclose(
+ ref_logits, test_logits, rtol=1e-4, atol=1e-4
+ ), f"{stage} logits mismatch: max diff = {max_diff:.6f}"
+
+ @TEST_INPUTS
+ def test_generation_match(self, qwen2_source, converted_model, prompts, max_new_tokens):
+ """Converted model produces identical generation to source."""
+ tokenizer = qwen2_source["tokenizer"]
+ ref_model = qwen2_source["model"]
+ test_model = converted_model["model"]
+ stage = converted_model["name"]
+
+ inputs = tokenizer(prompts, return_tensors="pt", padding=True)
+ ref_device = next(ref_model.parameters()).device
+ test_device = next(test_model.parameters()).device
+
+ with torch.no_grad():
+ ref_gen = ref_model.generate(
+ input_ids=inputs.input_ids.to(ref_device),
+ attention_mask=inputs.attention_mask.to(ref_device),
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ pad_token_id=tokenizer.pad_token_id,
+ ).cpu()
+
+ test_gen = test_model.generate(
+ input_ids=inputs.input_ids.to(test_device),
+ attention_mask=inputs.attention_mask.to(test_device),
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ pad_token_id=tokenizer.pad_token_id,
+ ).cpu()
+
+ assert torch.equal(ref_gen, test_gen), (
+ f"{stage} generation mismatch:\n"
+ f" Reference: {tokenizer.batch_decode(ref_gen, skip_special_tokens=True)}\n"
+ f" Test: {tokenizer.batch_decode(test_gen, skip_special_tokens=True)}"
+ )
diff --git a/fast_llm_external_models/tests/test_apriel2/test_mixer_equivalence.py b/fast_llm_external_models/tests/test_apriel2/test_mixer_equivalence.py
index 1aa8a56d9..c6f3337e8 100644
--- a/fast_llm_external_models/tests/test_apriel2/test_mixer_equivalence.py
+++ b/fast_llm_external_models/tests/test_apriel2/test_mixer_equivalence.py
@@ -28,15 +28,7 @@
import torch
import torch.nn as nn
-from fast_llm_external_models.apriel2.conversion import (
- Concat,
- ExprPlan,
- Ref,
- Slice,
- W,
- execute,
-)
-
+from fast_llm_external_models.apriel2.conversion import Concat, ExprPlan, Ref, Slice, W, execute
# =============================================================================
# Shared Fixtures
@@ -69,10 +61,10 @@ def hidden_size(request):
@pytest.fixture(
params=[
- pytest.param((8, 8, 32), id="mha-8h-32d"), # MHA: 8 heads, 8 kv heads, 32 head_dim
- pytest.param((8, 4, 32), id="gqa-8h4kv-32d"), # GQA: 8 heads, 4 kv heads, 32 head_dim
- pytest.param((8, 2, 64), id="gqa-8h2kv-64d"), # GQA: 8 heads, 2 kv heads, 64 head_dim
- pytest.param((4, 1, 64), id="mqa-4h1kv-64d"), # MQA: 4 heads, 1 kv head, 64 head_dim
+ pytest.param((8, 8, 32), id="mha-8h-32d"), # MHA: 8 heads, 8 kv heads, 32 head_dim
+ pytest.param((8, 4, 32), id="gqa-8h4kv-32d"), # GQA: 8 heads, 4 kv heads, 32 head_dim
+ pytest.param((8, 2, 64), id="gqa-8h2kv-64d"), # GQA: 8 heads, 2 kv heads, 64 head_dim
+ pytest.param((4, 1, 64), id="mqa-4h1kv-64d"), # MQA: 4 heads, 1 kv head, 64 head_dim
]
)
def attention_config(request):
@@ -90,7 +82,7 @@ def attention_config(request):
params=[
pytest.param((8, 4, 32, 32), id="8v-4k-32d"), # 8 value heads, 4 key heads, symmetric dims
pytest.param((8, 2, 64, 64), id="8v-2k-64d"), # 8 value heads, 2 key heads, larger dims
- pytest.param((4, 2, 32, 64), id="4v-2k-asym"), # Asymmetric key/value dims
+ pytest.param((4, 2, 32, 64), id="4v-2k-asym"), # Asymmetric key/value dims
]
)
def gdn_config(request):
@@ -100,9 +92,9 @@ def gdn_config(request):
@pytest.fixture(
params=[
- pytest.param((4, 8), id="4h-8d"), # 4 heads, 8 head_dim (small)
- pytest.param((8, 16), id="8h-16d"), # 8 heads, 16 head_dim (medium)
- pytest.param((4, 32), id="4h-32d"), # 4 heads, 32 head_dim (large head_dim)
+ pytest.param((4, 8), id="4h-8d"), # 4 heads, 8 head_dim (small)
+ pytest.param((8, 16), id="8h-16d"), # 8 heads, 16 head_dim (medium)
+ pytest.param((4, 32), id="4h-32d"), # 4 heads, 32 head_dim (large head_dim)
]
)
def kda_config(request):
@@ -283,9 +275,21 @@ def plan_qwen3next_gdn_to_apriel2(
for g in range(num_k_heads):
base = g * group_size
q_slices.append(Slice(expr=qkvz_ref, slices=((base, base + head_k_dim, None), (None, None, None))))
- k_slices.append(Slice(expr=qkvz_ref, slices=((base + head_k_dim, base + 2 * head_k_dim, None), (None, None, None))))
- v_slices.append(Slice(expr=qkvz_ref, slices=((base + 2 * head_k_dim, base + 2 * head_k_dim + v_per_group, None), (None, None, None))))
- z_slices.append(Slice(expr=qkvz_ref, slices=((base + 2 * head_k_dim + v_per_group, base + group_size, None), (None, None, None))))
+ k_slices.append(
+ Slice(expr=qkvz_ref, slices=((base + head_k_dim, base + 2 * head_k_dim, None), (None, None, None)))
+ )
+ v_slices.append(
+ Slice(
+ expr=qkvz_ref,
+ slices=((base + 2 * head_k_dim, base + 2 * head_k_dim + v_per_group, None), (None, None, None)),
+ )
+ )
+ z_slices.append(
+ Slice(
+ expr=qkvz_ref,
+ slices=((base + 2 * head_k_dim + v_per_group, base + group_size, None), (None, None, None)),
+ )
+ )
in_proj_qkvz_expr = Concat(
exprs=(
@@ -304,8 +308,15 @@ def plan_qwen3next_gdn_to_apriel2(
b_slices, a_slices = [], []
for g in range(num_k_heads):
base = g * ba_per_group
- b_slices.append(Slice(expr=ba_ref, slices=((base, base + num_v_heads // num_k_heads, None), (None, None, None))))
- a_slices.append(Slice(expr=ba_ref, slices=((base + num_v_heads // num_k_heads, base + ba_per_group, None), (None, None, None))))
+ b_slices.append(
+ Slice(expr=ba_ref, slices=((base, base + num_v_heads // num_k_heads, None), (None, None, None)))
+ )
+ a_slices.append(
+ Slice(
+ expr=ba_ref,
+ slices=((base + num_v_heads // num_k_heads, base + ba_per_group, None), (None, None, None)),
+ )
+ )
in_proj_ba_expr = Concat(
exprs=(Concat(exprs=tuple(b_slices), dim=0), Concat(exprs=tuple(a_slices), dim=0)),
@@ -565,6 +576,7 @@ def test_causal_vs_mistral(
):
"""Verify Apriel2Attention (causal) matches MistralAttention output."""
from transformers.models.mistral.modeling_mistral import MistralAttention, MistralRotaryEmbedding
+
from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2Attention
mixer_config = apriel2_config.decoder["block"]["mixer"]
@@ -593,13 +605,20 @@ def test_causal_vs_mistral(
apriel2_attn.eval()
with torch.no_grad():
- mistral_out = mistral_attn(hidden_states, position_embeddings=position_embeddings, attention_mask=causal_mask)[0]
- apriel2_out = apriel2_attn(hidden_states, attention_mask=causal_mask, position_embeddings=position_embeddings)[0]
+ mistral_out = mistral_attn(
+ hidden_states, position_embeddings=position_embeddings, attention_mask=causal_mask
+ )[0]
+ apriel2_out = apriel2_attn(
+ hidden_states, attention_mask=causal_mask, position_embeddings=position_embeddings
+ )[0]
rtol, atol = tolerance
assert_close(
- apriel2_out, mistral_out, rtol=rtol, atol=atol,
- msg=f"Apriel2Attention vs MistralAttention (batch={batch_size}, seq={seq_len}, hidden={hidden_size})"
+ apriel2_out,
+ mistral_out,
+ rtol=rtol,
+ atol=atol,
+ msg=f"Apriel2Attention vs MistralAttention (batch={batch_size}, seq={seq_len}, hidden={hidden_size})",
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA")
@@ -613,8 +632,9 @@ def test_noncausal_vs_pixtral(
tolerance,
):
"""Verify Apriel2Attention (non-causal) matches PixtralAttention output."""
- from transformers.models.pixtral.modeling_pixtral import PixtralAttention, PixtralRotaryEmbedding
from transformers.models.pixtral.configuration_pixtral import PixtralVisionConfig
+ from transformers.models.pixtral.modeling_pixtral import PixtralAttention, PixtralRotaryEmbedding
+
from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2TextConfig
from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2Attention
@@ -689,8 +709,11 @@ def test_noncausal_vs_pixtral(
rtol, atol = tolerance
assert_close(
- apriel2_out, pixtral_out, rtol=rtol, atol=atol,
- msg=f"Apriel2Attention (non-causal) vs PixtralAttention (batch={batch_size}, seq={seq_len})"
+ apriel2_out,
+ pixtral_out,
+ rtol=rtol,
+ atol=atol,
+ msg=f"Apriel2Attention (non-causal) vs PixtralAttention (batch={batch_size}, seq={seq_len})",
)
@@ -737,6 +760,7 @@ def test_vs_qwen3next(
):
"""Verify Apriel2GatedDeltaNet matches Qwen3NextGatedDeltaNet output."""
from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextGatedDeltaNet
+
from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2GatedDeltaNet
value_heads, key_heads, key_head_dim, value_head_dim = gdn_config
@@ -758,8 +782,10 @@ def test_vs_qwen3next(
# Transfer weights
plan = plan_qwen3next_gdn_to_apriel2(
- num_k_heads=key_heads, num_v_heads=value_heads,
- head_k_dim=key_head_dim, head_v_dim=value_head_dim,
+ num_k_heads=key_heads,
+ num_v_heads=value_heads,
+ head_k_dim=key_head_dim,
+ head_v_dim=value_head_dim,
)
source_weights = extract_module_weights(qwen_gdn)
target_weights = execute(plan, source_weights, seed=seed)
@@ -778,8 +804,11 @@ def test_vs_qwen3next(
rtol, atol = tolerance
assert_close(
- apriel2_out, qwen_out, rtol=rtol, atol=atol,
- msg=f"Apriel2GatedDeltaNet vs Qwen3NextGatedDeltaNet (batch={batch_size}, seq={seq_len})"
+ apriel2_out,
+ qwen_out,
+ rtol=rtol,
+ atol=atol,
+ msg=f"Apriel2GatedDeltaNet vs Qwen3NextGatedDeltaNet (batch={batch_size}, seq={seq_len})",
)
@@ -803,6 +832,7 @@ def test_vs_fla(
):
"""Verify Apriel2 KimiDeltaAttention matches FLA KimiDeltaAttention output."""
from fla.layers.kda import KimiDeltaAttention as FLA_KDA
+
from fast_llm_external_models.apriel2.modeling_apriel2 import KimiDeltaAttention as Apriel2_KDA
num_heads, head_dim = kda_config
@@ -853,8 +883,11 @@ def test_vs_fla(
rtol, atol = tolerance
assert_close(
- apriel2_out, fla_out, rtol=rtol, atol=atol,
- msg=f"Apriel2 KDA vs FLA KDA (batch={batch_size}, seq={seq_len}, hidden={hidden_size})"
+ apriel2_out,
+ fla_out,
+ rtol=rtol,
+ atol=atol,
+ msg=f"Apriel2 KDA vs FLA KDA (batch={batch_size}, seq={seq_len}, hidden={hidden_size})",
)
@@ -913,7 +946,4 @@ def test_gdn_fast_vs_slow(self, gdn_config, batch_size):
slow_out = model(hidden_states)[0].clone()
# Looser tolerance for kernel vs reference comparison
- assert_close(
- fast_out, slow_out, rtol=1e-3, atol=1e-3,
- msg="GDN fast path (CUDA) vs slow path (PyTorch)"
- )
+ assert_close(fast_out, slow_out, rtol=1e-3, atol=1e-3, msg="GDN fast path (CUDA) vs slow path (PyTorch)")
diff --git a/fast_llm_external_models/tests/test_apriel2/test_model_structure.py b/fast_llm_external_models/tests/test_apriel2/test_model_structure.py
index 23856be30..56d2bc6a6 100644
--- a/fast_llm_external_models/tests/test_apriel2/test_model_structure.py
+++ b/fast_llm_external_models/tests/test_apriel2/test_model_structure.py
@@ -1,9 +1,9 @@
"""Tests for Apriel2 model structure and architecture validation."""
-import pytest
import torch
-from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM
+
from fast_llm_external_models.apriel2.cache import Apriel2Cache, _AttentionCache, _SSMCache
+from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM
class TestStochasticMixerStructure:
@@ -14,20 +14,27 @@ def test_all_submixers_present(self, apriel2_config_all_mixers):
model = Apriel2ForCausalLM(apriel2_config_all_mixers)
stochastic_layer = model.model.decoder.blocks[1] # Layer 1 is the "all_mixers" layer
- assert hasattr(stochastic_layer.mixer, 'mixers'), "Stochastic mixer should have 'mixers' attribute"
+ assert hasattr(stochastic_layer.mixer, "mixers"), "Stochastic mixer should have 'mixers' attribute"
assert set(stochastic_layer.mixer.mixers.keys()) == {
- 'attention', 'swa', 'mamba', 'gdn'
+ "attention",
+ "swa",
+ "mamba",
+ "gdn",
}, "Stochastic mixer should contain all 4 configured mixer types"
# Verify each mixer is the correct type
from fast_llm_external_models.apriel2.modeling_apriel2 import (
- Apriel2Attention, Apriel2Mamba, Apriel2GatedDeltaNet
+ Apriel2Attention,
+ Apriel2GatedDeltaNet,
+ Apriel2Mamba,
)
- assert isinstance(stochastic_layer.mixer.mixers['attention'], Apriel2Attention)
- assert isinstance(stochastic_layer.mixer.mixers['swa'], Apriel2Attention) # SWA is Apriel2Attention with sliding_window
- assert isinstance(stochastic_layer.mixer.mixers['mamba'], Apriel2Mamba)
- assert isinstance(stochastic_layer.mixer.mixers['gdn'], Apriel2GatedDeltaNet)
+ assert isinstance(stochastic_layer.mixer.mixers["attention"], Apriel2Attention)
+ assert isinstance(
+ stochastic_layer.mixer.mixers["swa"], Apriel2Attention
+ ) # SWA is Apriel2Attention with sliding_window
+ assert isinstance(stochastic_layer.mixer.mixers["mamba"], Apriel2Mamba)
+ assert isinstance(stochastic_layer.mixer.mixers["gdn"], Apriel2GatedDeltaNet)
def test_main_mixer_is_configured(self, apriel2_config_all_mixers):
"""Verify main_mixer_name is set correctly."""
@@ -44,7 +51,10 @@ def test_cache_has_all_submixer_slots(self, apriel2_config_all_mixers):
assert isinstance(layer_cache, dict), "Stochastic layer cache should be a dict"
assert set(layer_cache.keys()) == {
- 'attention', 'swa', 'mamba', 'gdn'
+ "attention",
+ "swa",
+ "mamba",
+ "gdn",
}, "Cache should have slots for all 4 mixers"
def test_attention_mixers_use_attention_cache(self, apriel2_config_all_mixers):
@@ -53,12 +63,12 @@ def test_attention_mixers_use_attention_cache(self, apriel2_config_all_mixers):
layer_cache = cache.layers[1]
# Attention-based mixers use AttentionCache
- assert isinstance(layer_cache['attention'], _AttentionCache)
- assert isinstance(layer_cache['swa'], _AttentionCache)
+ assert isinstance(layer_cache["attention"], _AttentionCache)
+ assert isinstance(layer_cache["swa"], _AttentionCache)
# SSM-based mixers use SSMCache
- assert isinstance(layer_cache['mamba'], _SSMCache)
- assert isinstance(layer_cache['gdn'], _SSMCache)
+ assert isinstance(layer_cache["mamba"], _SSMCache)
+ assert isinstance(layer_cache["gdn"], _SSMCache)
def test_parameter_counts_differ_by_config(self):
"""Different configs create models with different parameter counts."""
@@ -74,8 +84,10 @@ def test_parameter_counts_differ_by_config(self):
}
config_tiny = Apriel2Config(
- vocab_size=100, hidden_size=64,
- num_attention_heads=4, num_key_value_heads=2,
+ vocab_size=100,
+ hidden_size=64,
+ num_attention_heads=4,
+ num_key_value_heads=2,
decoder={
"type": "fixed",
"num_blocks": 2,
@@ -88,8 +100,10 @@ def test_parameter_counts_differ_by_config(self):
)
config_stochastic = Apriel2Config(
- vocab_size=100, hidden_size=64,
- num_attention_heads=4, num_key_value_heads=2,
+ vocab_size=100,
+ hidden_size=64,
+ num_attention_heads=4,
+ num_key_value_heads=2,
decoder={
"type": "pattern",
"num_blocks": 2,
@@ -106,14 +120,14 @@ def test_parameter_counts_differ_by_config(self):
"main_mixer_name": "attention",
"mixers": {
"attention": attn_config,
- "mamba": {"type": "mamba", "conv_bias": True, "dt_proj_bias": True}
- }
+ "mamba": {"type": "mamba", "conv_bias": True, "dt_proj_bias": True},
+ },
},
"mlp": {"type": "mlp", "intermediate_size": 256, "gated": True},
"normalization": {"type": "rms_norm"},
- }
- }
- }
+ },
+ },
+ },
)
model_tiny = Apriel2ForCausalLM(config_tiny)
@@ -122,8 +136,9 @@ def test_parameter_counts_differ_by_config(self):
params_tiny = sum(p.numel() for p in model_tiny.parameters())
params_stochastic = sum(p.numel() for p in model_stochastic.parameters())
- assert params_stochastic > params_tiny, \
- "Stochastic mixer should have more parameters (has both attention and mamba)"
+ assert (
+ params_stochastic > params_tiny
+ ), "Stochastic mixer should have more parameters (has both attention and mamba)"
def test_weights_are_initialized(self, apriel2_config_all_mixers):
"""Verify model weights are initialized (not all zeros/constant)."""
@@ -136,9 +151,7 @@ def test_weights_are_initialized(self, apriel2_config_all_mixers):
# Basic sanity: at least some parameters should be non-zero
non_zero_params = sum(
- not torch.all(p == 0)
- for mixer in stochastic_layer.mixer.mixers.values()
- for p in mixer.parameters()
+ not torch.all(p == 0) for mixer in stochastic_layer.mixer.mixers.values() for p in mixer.parameters()
)
assert non_zero_params > 0, "At least some mixer parameters should be non-zero"
diff --git a/fast_llm_external_models/tests/test_apriel2/test_modeling.py b/fast_llm_external_models/tests/test_apriel2/test_modeling.py
index 5dbd36159..8e2f610bb 100644
--- a/fast_llm_external_models/tests/test_apriel2/test_modeling.py
+++ b/fast_llm_external_models/tests/test_apriel2/test_modeling.py
@@ -2,18 +2,23 @@
import pytest
import torch
+
from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM
class TestApriel2Modeling:
"""End-to-end tests for Apriel2 model with different configurations."""
- @pytest.mark.parametrize("config_name", [
- "apriel2_config_tiny",
- "apriel2_config_stochastic",
- "apriel2_config_multi_mixer",
- "apriel2_config_all_mixers" # Tests all 4 mixer types
- ])
+ @pytest.mark.parametrize(
+ "config_name",
+ [
+ "apriel2_config_tiny",
+ "apriel2_config_stochastic",
+ "apriel2_config_multi_mixer",
+ "apriel2_config_all_mixers", # Tests all 4 mixer types
+ "apriel2_config_with_bias", # Tests per-layer bias and non-gated MLP
+ ],
+ )
def test_model_end_to_end(self, config_name, request):
"""Test instantiation, forward pass, cache correctness, and generation.
@@ -42,7 +47,7 @@ def test_model_end_to_end(self, config_name, request):
# 2. Forward pass - basic shape validation
outputs = model(input_ids, use_cache=False)
assert outputs.logits.shape == (2, seq_len, config.vocab_size)
- assert hasattr(outputs, 'logits')
+ assert hasattr(outputs, "logits")
# 3. Verify cache is actually being used (not dormant)
split_pos = 30
@@ -52,28 +57,23 @@ def test_model_end_to_end(self, config_name, request):
assert outputs_part1.past_key_values is not None
outputs_correct_cache = model(
- input_ids[:, split_pos:split_pos+1],
- past_key_values=outputs_part1.past_key_values,
- use_cache=True
+ input_ids[:, split_pos : split_pos + 1], past_key_values=outputs_part1.past_key_values, use_cache=True
)
# Test 1: Empty cache should give different results than filled cache
# This verifies cache is being used at all
from fast_llm_external_models.apriel2.cache import Apriel2Cache, _AttentionCache
+
empty_cache = Apriel2Cache(config)
outputs_empty_cache = model(
- input_ids[:, split_pos:split_pos+1],
- past_key_values=empty_cache,
- use_cache=True
+ input_ids[:, split_pos : split_pos + 1], past_key_values=empty_cache, use_cache=True
)
- cache_affects_output = not torch.allclose(
- outputs_correct_cache.logits,
- outputs_empty_cache.logits,
- atol=1e-3
- )
- assert cache_affects_output, f"Cache appears dormant for {config_name} - empty cache gives same results as filled cache"
+ cache_affects_output = not torch.allclose(outputs_correct_cache.logits, outputs_empty_cache.logits, atol=1e-3)
+ assert (
+ cache_affects_output
+ ), f"Cache appears dormant for {config_name} - empty cache gives same results as filled cache"
# Test 2: Corrupted cache (zeros) should give different results than correct cache
# This verifies the actual cache VALUES are being used
@@ -98,17 +98,15 @@ def test_model_end_to_end(self, config_name, request):
corrupted_layer[name].value = torch.zeros_like(correct_sub.value)
outputs_corrupted_cache = model(
- input_ids[:, split_pos:split_pos+1],
- past_key_values=corrupted_cache,
- use_cache=True
+ input_ids[:, split_pos : split_pos + 1], past_key_values=corrupted_cache, use_cache=True
)
cache_values_matter = not torch.allclose(
- outputs_correct_cache.logits,
- outputs_corrupted_cache.logits,
- atol=1e-3
+ outputs_correct_cache.logits, outputs_corrupted_cache.logits, atol=1e-3
)
- assert cache_values_matter, f"Cache values not used for {config_name} - zeroed cache gives same results as correct cache"
+ assert (
+ cache_values_matter
+ ), f"Cache values not used for {config_name} - zeroed cache gives same results as correct cache"
# 4. Cache correctness - validate cache produces same results as no-cache
# Compute full sequence without cache
@@ -117,18 +115,14 @@ def test_model_end_to_end(self, config_name, request):
# Compute in two steps with cache
outputs_part1 = model(input_ids[:, :split_pos], use_cache=True)
outputs_part2 = model(
- input_ids[:, split_pos:split_pos+1],
- past_key_values=outputs_part1.past_key_values,
- use_cache=True
+ input_ids[:, split_pos : split_pos + 1], past_key_values=outputs_part1.past_key_values, use_cache=True
)
# Logits should match between cached and non-cached
# Note: GPU execution with bfloat16/float16 has lower precision than CPU float32,
# so we use a looser tolerance here.
assert torch.allclose(
- outputs_full.logits[:, split_pos, :],
- outputs_part2.logits[:, 0, :],
- atol=1e-3
+ outputs_full.logits[:, split_pos, :], outputs_part2.logits[:, 0, :], atol=1e-3
), f"Cache correctness failed for {config_name}: cached and non-cached logits differ"
# 5. Generation - end-to-end validation
diff --git a/fast_llm_external_models/tests/test_apriel2/test_plan_execution.py b/fast_llm_external_models/tests/test_apriel2/test_plan_execution.py
new file mode 100644
index 000000000..ca0c8739f
--- /dev/null
+++ b/fast_llm_external_models/tests/test_apriel2/test_plan_execution.py
@@ -0,0 +1,598 @@
+"""test_plan_execution.py - Plan execution and algebraic composition laws.
+
+This module provides rigorous, parameterized tests for the mathematical properties
+that the conversion system must satisfy. Each test class corresponds to one
+algebraic structure, and each test method verifies one specific law.
+
+Conceptual Types
+================
+
+The conversion system operates on three conceptual types (all ``dict`` at runtime):
+
+- **S (State)**: Complete config without ``init`` fields
+- **P (Partial Surgery)**: Incomplete config, may have ``init`` fields
+- **T (Transition Spec)**: Complete config WITH ``init`` fields
+
+Algebraic Structures
+====================
+
+1. **Partial Surgeries (P)** form a **Monoid** under deep merge::
+
+ compose_configs : P × P → P
+ Identity: {}
+ Associativity: (p1 ∘ p2) ∘ p3 = p1 ∘ (p2 ∘ p3)
+
+2. **Surgeries act on States** to produce Transition Specs::
+
+ compose_configs : S × P → T
+ compose_configs : T × P → T
+
+ Action law (additive surgeries): (s · p1) · p2 = s · (p1 ∘ p2)
+
+3. **Plans** form a **Category** with composition::
+
+ compose : Plan(A→B) × Plan(B→C) → Plan(A→C)
+ Associativity: (P1 ∘ P2) ∘ P3 = P1 ∘ (P2 ∘ P3)
+
+4. **plan_surgery is a Functor** from config pairs to plans::
+
+ plan_surgery : S × T → Plan
+ Functoriality: compose(plan(S,T1), plan(T1,T2)) ≡ plan(S,T2)
+
+ This is semantic equivalence: both produce identical weights when executed.
+
+Important Behaviors Tested
+==========================
+
+- **init stripping**: Between surgery iterations, T → S conversion via
+ ``strip_init_fields()`` ensures ``init: random`` applies only to the surgery
+ that introduces a component.
+
+- **Bias inheritance**: Per-layer bias settings propagate through surgery chains.
+
+- **Plan composition**: Composed plans produce identical weights to direct plans.
+
+Design Principles
+=================
+
+- Each law gets ONE parameterized test, not multiple similar tests
+- Fixtures provide diverse configs (with/without biases)
+- Corner cases are covered via parameterization, not test proliferation
+- Tests document the laws they verify in their docstrings
+"""
+
+from functools import reduce
+
+import pytest
+import torch
+
+from fast_llm_external_models.apriel2.conversion import (
+ Concat,
+ ExprPlan,
+ Init,
+ Ref,
+ Slice,
+ W,
+ compose,
+ compose_configs,
+ execute,
+ plan_surgery,
+)
+
+# Import shared helper from conftest
+from fast_llm_external_models.tests.test_apriel2.conftest import make_weights_for_config
+
+# =============================================================================
+# Fixtures: Use shared fixtures from conftest.py where possible
+# =============================================================================
+# - base_config_dict: Complete config without biases (Llama-style)
+# - base_config_with_bias_dict: Complete config with QKV biases
+# - additive_surgery_chain: [wrap_stochastic, add_sliding_window, add_gdn]
+# =============================================================================
+
+
+# =============================================================================
+# Test: Plan Composition Associativity
+# =============================================================================
+
+
+class TestPlanCompositionAssociativity:
+ """
+ LAW: Plan composition is associative.
+
+ (P₁ ∘ P₂) ∘ P₃ = P₁ ∘ (P₂ ∘ P₃)
+
+ where ∘ denotes compose(P1, P2).
+
+ This must hold for the AST structure, not just semantic equivalence.
+ """
+
+ @pytest.mark.parametrize("expr_type", ["ref_chain", "with_concat", "with_slice", "with_init"])
+ def test_associativity(self, expr_type):
+ """Plan composition is associative for various expression types."""
+ # Build three plans that can be composed
+ if expr_type == "ref_chain":
+ p1 = ExprPlan(mappings={W("b"): Ref(key=W("a"))})
+ p2 = ExprPlan(mappings={W("c"): Ref(key=W("b"))})
+ p3 = ExprPlan(mappings={W("d"): Ref(key=W("c"))})
+ elif expr_type == "with_concat":
+ p1 = ExprPlan(mappings={W("x"): Ref(key=W("a")), W("y"): Ref(key=W("b"))})
+ p2 = ExprPlan(mappings={W("xy"): Concat(exprs=(Ref(key=W("x")), Ref(key=W("y"))), dim=0)})
+ p3 = ExprPlan(mappings={W("final"): Ref(key=W("xy"))})
+ elif expr_type == "with_slice":
+ p1 = ExprPlan(mappings={W("full"): Ref(key=W("src"))})
+ p2 = ExprPlan(mappings={W("part"): Slice(expr=Ref(key=W("full")), slices=((0, 5, None),))})
+ p3 = ExprPlan(mappings={W("out"): Ref(key=W("part"))})
+ elif expr_type == "with_init":
+ p1 = ExprPlan(mappings={W("x"): Ref(key=W("a"))})
+ p2 = ExprPlan(
+ mappings={W("y"): Concat(exprs=(Ref(key=W("x")), Init(shape=(5,), init_type="zeros")), dim=0)}
+ )
+ p3 = ExprPlan(mappings={W("z"): Ref(key=W("y"))})
+
+ left = compose(compose(p1, p2), p3)
+ right = compose(p1, compose(p2, p3))
+
+ assert left.mappings == right.mappings, f"Associativity failed for {expr_type}"
+
+
+# =============================================================================
+# Test: Functoriality of plan_surgery (THE CRITICAL PROPERTY)
+# =============================================================================
+
+
+class TestPlanSurgeryFunctoriality:
+ """
+ LAW: plan_surgery is functorial with respect to config composition.
+
+ For a surgery chain P₁, P₂, ..., Pₙ applied to base state S₀::
+
+ T₁ = compose_configs(S₀, P₁) # S × P → T
+ T₂ = compose_configs(T₁, P₂) # T × P → T (no stripping!)
+ ...
+ Tₙ = compose_configs(Tₙ₋₁, Pₙ)
+
+ Plan functoriality says::
+
+ compose(plan(S₀,T₁), plan(T₁,T₂), ...) ≡ plan(S₀, Tₙ)
+
+ where ≡ denotes semantic equivalence (identical weights when executed).
+
+ NOTE: This tests T × P composition WITHOUT stripping between steps.
+ This differs from build_plan which strips (T → S) between iterations.
+ Both patterns are valid:
+
+ - Without stripping: init fields accumulate, testing plan composition purity
+ - With stripping: init consumed per-step, testing real usage (see
+ test_build_plan_strips_init_between_iterations)
+
+ The functoriality law holds in both cases because plan composition
+ correctly substitutes Ref expressions with their definitions.
+ """
+
+ @pytest.mark.parametrize("chain_length", [1, 2, 3])
+ @pytest.mark.parametrize("use_bias", [True, False])
+ def test_functoriality(
+ self,
+ chain_length,
+ use_bias,
+ base_config_dict,
+ base_config_with_bias_dict,
+ additive_surgery_chain,
+ ):
+ """
+ Composed incremental plans produce same weights as direct plan.
+
+ Parameterized over:
+ - chain_length: Number of surgeries (1, 2, or 3)
+ - use_bias: Whether base config has biases
+ """
+ base_config = base_config_with_bias_dict if use_bias else base_config_dict
+ surgeries = additive_surgery_chain[:chain_length]
+
+ # Build config chain: C₀ → C₁ → ... → Cₙ
+ configs = [base_config]
+ for s in surgeries:
+ configs.append(compose_configs(configs[-1], s))
+
+ # Build incremental plans: Pₖ = plan_surgery(Cₖ₋₁, Cₖ)
+ plans = [plan_surgery(configs[i], configs[i + 1]) for i in range(len(surgeries))]
+
+ # Compose all incremental plans
+ composed_plan = reduce(compose, plans)
+
+ # Build direct plan: plan_surgery(C₀, Cₙ)
+ direct_plan = plan_surgery(configs[0], configs[-1])
+
+ # Execute both on same weights
+ weights = make_weights_for_config(base_config)
+ composed_weights = execute(composed_plan, weights, seed=42)
+ direct_weights = execute(direct_plan, weights, seed=42)
+
+ # Verify semantic equivalence
+ assert set(composed_weights.keys()) == set(
+ direct_weights.keys()
+ ), f"Key sets differ for chain_length={chain_length}, use_bias={use_bias}"
+
+ for key in composed_weights:
+ assert torch.allclose(
+ composed_weights[key], direct_weights[key], atol=1e-6
+ ), f"Weight mismatch for {key} with chain_length={chain_length}, use_bias={use_bias}"
+
+ @pytest.mark.parametrize("split_point", [1, 2])
+ def test_arbitrary_grouping(
+ self,
+ split_point,
+ base_config_with_bias_dict,
+ additive_surgery_chain,
+ ):
+ """
+ Any grouping of surgery chain produces same result.
+
+ For surgeries [S₁, S₂, S₃], tests that:
+ - compose(P₁, compose(P₂, P₃))
+ - compose(compose(P₁, P₂), P₃)
+ - plan_surgery(C₀, C₃)
+
+ all produce identical weights.
+ """
+ surgeries = additive_surgery_chain
+
+ # Build config chain
+ configs = [base_config_with_bias_dict]
+ for s in surgeries:
+ configs.append(compose_configs(configs[-1], s))
+
+ # Build incremental plans
+ plans = [plan_surgery(configs[i], configs[i + 1]) for i in range(3)]
+
+ # Different groupings
+ left_grouped = compose(compose(plans[0], plans[1]), plans[2])
+ right_grouped = compose(plans[0], compose(plans[1], plans[2]))
+ direct = plan_surgery(configs[0], configs[-1])
+
+ # Execute all
+ weights = make_weights_for_config(base_config_with_bias_dict)
+ results = {
+ "left": execute(left_grouped, weights, seed=42),
+ "right": execute(right_grouped, weights, seed=42),
+ "direct": execute(direct, weights, seed=42),
+ }
+
+ # All must match
+ keys = set(results["left"].keys())
+ assert keys == set(results["right"].keys()) == set(results["direct"].keys())
+
+ for key in keys:
+ assert torch.allclose(results["left"][key], results["right"][key], atol=1e-6)
+ assert torch.allclose(results["left"][key], results["direct"][key], atol=1e-6)
+
+
+# =============================================================================
+# Test: Bias Inheritance Preservation (Regression for the specific bug)
+# =============================================================================
+
+
+class TestBiasInheritancePreservation:
+ """
+ PROPERTY: Per-layer bias settings must be preserved through surgery chains.
+
+ When a surgery spec does not mention bias settings, they must be inherited
+ from the source config. This is the specific failure mode of the build_plan
+ bug: passing partial surgery specs to plan_surgery lost inherited fields.
+
+ This test verifies the SYMPTOM (missing biases) rather than the LAW
+ (functoriality). It's kept as a focused regression test.
+ """
+
+ @pytest.mark.parametrize("num_surgeries", [1, 2, 3])
+ def test_qkv_biases_preserved_through_chain(
+ self,
+ num_surgeries,
+ base_config_with_bias_dict,
+ additive_surgery_chain,
+ ):
+ """QKV biases (enabled in source) appear in plan after N surgeries."""
+ surgeries = additive_surgery_chain[:num_surgeries]
+
+ # Build config and plan chain
+ configs = [base_config_with_bias_dict]
+ for s in surgeries:
+ configs.append(compose_configs(configs[-1], s))
+
+ plans = [plan_surgery(configs[i], configs[i + 1]) for i in range(num_surgeries)]
+ final_plan = reduce(compose, plans) if len(plans) > 1 else plans[0]
+
+ # Check bias keys present
+ target_keys = {str(k) for k in final_plan.target_keys()}
+
+ assert any("q_proj.bias" in k for k in target_keys), f"q_proj.bias missing after {num_surgeries} surgeries"
+ assert any("k_proj.bias" in k for k in target_keys), f"k_proj.bias missing after {num_surgeries} surgeries"
+ assert any("v_proj.bias" in k for k in target_keys), f"v_proj.bias missing after {num_surgeries} surgeries"
+ # O bias should NOT be present (disabled in source)
+ assert not any(
+ "o_proj.bias" in k for k in target_keys
+ ), f"o_proj.bias should not be present (disabled in source)"
+
+ def test_bias_values_preserved(
+ self,
+ base_config_with_bias_dict,
+ additive_surgery_chain,
+ ):
+ """Bias tensor values are correctly transferred, not just keys."""
+ surgery = additive_surgery_chain[0] # wrap_stochastic
+ c1 = compose_configs(base_config_with_bias_dict, surgery)
+ plan = plan_surgery(base_config_with_bias_dict, c1)
+
+ weights = make_weights_for_config(base_config_with_bias_dict)
+ result = execute(plan, weights, seed=42)
+
+ # Verify values match (not just that keys exist)
+ for i in range(base_config_with_bias_dict["decoder"]["num_blocks"]):
+ src_key = W(f"model.decoder.blocks.{i}.mixer.q_proj.bias")
+ dst_key = W(f"model.decoder.blocks.{i}.mixer.mixers.attention.q_proj.bias")
+
+ assert dst_key in result, f"Missing {dst_key}"
+ assert torch.allclose(weights[src_key], result[dst_key]), f"Bias values differ for block {i}"
+
+
+# =============================================================================
+# Test: build_plan Integration (Regression test for convert.py)
+# =============================================================================
+
+
+class TestBuildPlanIntegration:
+ """
+ REGRESSION: build_plan must compose configs before calling plan_surgery.
+
+ The bug was:
+ plan_surgery(current_config, surgery_config) # WRONG: partial
+
+ Should be:
+ target = compose_configs(current_config, surgery_config)
+ plan_surgery(current_config, target) # CORRECT: complete
+
+ This test verifies the fix in convert.py's build_plan function.
+ """
+
+ @pytest.mark.parametrize("num_surgeries", [1, 2])
+ def test_build_plan_preserves_inherited_fields(
+ self,
+ num_surgeries,
+ base_config_with_bias_dict,
+ additive_surgery_chain,
+ ):
+ """build_plan produces plans with inherited bias mappings."""
+ from fast_llm_external_models.apriel2.convert import build_plan
+
+ surgeries = additive_surgery_chain[:num_surgeries]
+
+ plan, final_config = build_plan(
+ base_config_with_bias_dict,
+ surgeries,
+ source_format="apriel2",
+ )
+
+ # Verify inherited biases in config
+ if num_surgeries >= 1:
+ attn = final_config["decoder"]["block"]["mixer"]["mixers"]["attention"]
+ assert attn.get("query_layer", {}).get("bias", {}).get("enabled") is True
+
+ # Verify bias mappings in plan
+ target_keys = {str(k) for k in plan.target_keys()}
+ assert any(
+ "q_proj.bias" in k for k in target_keys
+ ), f"build_plan with {num_surgeries} surgeries missing q_proj.bias"
+
+
+# =============================================================================
+# Test: init Field Preservation (Critical for random initialization)
+# =============================================================================
+
+
+class TestInitFieldPreservation:
+ """
+ PROPERTY: The `init` field must be visible to plan_surgery.
+
+ The `init` field controls weight initialization mode:
+ - `init: transfer` → use weight transfer/conversion
+ - `init: random` → use random initialization
+
+ compose_configs must preserve `init` so plan_surgery can see it.
+ Stripping happens only at final output (when saving to disk).
+ """
+
+ def test_init_random_produces_init_expression(self, base_config_with_bias_dict):
+ """Surgery with init: random produces Init expressions in plan."""
+ surgery = {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "type": "stochastic",
+ "main_mixer_name": "attention",
+ "mixers": {
+ "attention": {"init": "transfer"},
+ "gdn": {"type": "gdn", "init": "random", "convolution_layer": {"kernel_size": 4}},
+ },
+ },
+ },
+ },
+ }
+
+ target = compose_configs(base_config_with_bias_dict, surgery)
+ plan = plan_surgery(base_config_with_bias_dict, target)
+
+ # Check that GDN weights use Init expressions (random init)
+ target_keys = {str(k) for k in plan.target_keys()}
+ gdn_keys = [k for k in target_keys if "gdn" in k.lower()]
+
+ assert len(gdn_keys) > 0, "No GDN keys in plan"
+
+ # Verify at least one GDN weight uses Init (random initialization)
+ has_init_expr = False
+ for key in plan.target_keys():
+ if "gdn" in str(key).lower():
+ expr = plan.mappings[key]
+ if isinstance(expr, Init):
+ has_init_expr = True
+ break
+ # Also check inside Concat/other composite expressions
+ if hasattr(expr, "exprs"):
+ for sub in expr.exprs:
+ if isinstance(sub, Init):
+ has_init_expr = True
+ break
+
+ assert has_init_expr, "init: random should produce Init expressions for GDN weights"
+
+ def test_init_transfer_produces_ref_expression(self, base_config_with_bias_dict):
+ """Surgery with init: transfer produces Ref expressions (weight transfer)."""
+ surgery = {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "type": "stochastic",
+ "main_mixer_name": "attention",
+ "mixers": {
+ "attention": {"init": "transfer"},
+ },
+ },
+ },
+ },
+ }
+
+ target = compose_configs(base_config_with_bias_dict, surgery)
+ plan = plan_surgery(base_config_with_bias_dict, target)
+
+ # Check that attention weights use Ref expressions (transfer)
+ has_ref = False
+ for key in plan.target_keys():
+ if "attention" in str(key) and "q_proj.weight" in str(key):
+ expr = plan.mappings[key]
+ if isinstance(expr, Ref):
+ has_ref = True
+ break
+
+ assert has_ref, "init: transfer should produce Ref expressions for attention weights"
+
+ def test_build_plan_respects_init_random(self, base_config_with_bias_dict):
+ """build_plan correctly uses init: random for weight initialization."""
+ from fast_llm_external_models.apriel2.convert import build_plan
+
+ # Mamba requires many config fields for random init
+ surgery = {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "type": "stochastic",
+ "main_mixer_name": "attention",
+ "mixers": {
+ "attention": {"init": "transfer"},
+ "mamba": {
+ "type": "mamba",
+ "init": "random",
+ "d_inner": 512,
+ "d_state": 16,
+ "dt_rank": 16,
+ "d_xb": 64,
+ "d_conv": 4,
+ "repeat_kv_before_conv": False,
+ "conv_bias": True,
+ "dt_proj_bias": True,
+ "dt_min": 0.001,
+ "dt_max": 0.1,
+ "dt_init_floor": 1e-4,
+ },
+ },
+ },
+ },
+ },
+ }
+
+ plan, final_config = build_plan(
+ base_config_with_bias_dict,
+ [surgery],
+ source_format="apriel2",
+ )
+
+ # Verify mamba weights use Init (random init)
+ has_mamba_init = False
+ for key in plan.target_keys():
+ key_str = str(key)
+ if "mamba" in key_str:
+ expr = plan.mappings[key]
+ if isinstance(expr, Init):
+ has_mamba_init = True
+ break
+
+ assert has_mamba_init, "build_plan should use Init for init: random components"
+
+ def test_build_plan_strips_init_between_iterations(self, base_config_with_bias_dict):
+ """build_plan strips init between iterations (T → S conversion).
+
+ This tests that the intermediate state between surgeries has no init fields.
+ The composed plan will show Init expressions because plan composition
+ substitutes Ref → Init, but the semantics are correct: GDN is initialized
+ once (in surgery 1), not re-randomized in surgery 2.
+ """
+ from fast_llm_external_models.apriel2.conversion import compose_configs, plan_surgery, strip_init_fields
+
+ # Surgery 1: Add GDN with random init
+ surgery1 = {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "type": "stochastic",
+ "main_mixer_name": "attention",
+ "mixers": {
+ "attention": {"init": "transfer"},
+ "gdn": {
+ "type": "gdn",
+ "init": "random",
+ "convolution_layer": {"kernel_size": 4},
+ },
+ },
+ },
+ },
+ },
+ }
+
+ # Surgery 2: Add sliding window (doesn't mention GDN)
+ surgery2 = {
+ "decoder": {
+ "block": {
+ "mixer": {
+ "mixers": {
+ "sliding_window": {"init": "transfer", "window_size": 512},
+ },
+ },
+ },
+ },
+ }
+
+ # Simulate build_plan's iteration loop
+ s0 = base_config_with_bias_dict
+
+ # Iteration 1
+ t1 = compose_configs(s0, surgery1)
+ assert t1["decoder"]["block"]["mixer"]["mixers"]["gdn"].get("init") == "random"
+ s1 = strip_init_fields(t1)
+ assert s1["decoder"]["block"]["mixer"]["mixers"]["gdn"].get("init") is None
+
+ # Iteration 2: s1 has no init for GDN
+ t2 = compose_configs(s1, surgery2)
+ assert (
+ t2["decoder"]["block"]["mixer"]["mixers"]["gdn"].get("init") is None
+ ), "GDN should have no init in T2 (wasn't in surgery2, stripped from s1)"
+
+ # plan_surgery(s1, t2) should use Ref for GDN (transfer, not random)
+ plan2 = plan_surgery(s1, t2)
+ gdn_uses_ref = False
+ for key in plan2.target_keys():
+ if "gdn" in str(key):
+ expr = plan2.mappings[key]
+ if isinstance(expr, Ref):
+ gdn_uses_ref = True
+ break
+
+ assert gdn_uses_ref, "plan_surgery(s1, t2) should use Ref for GDN (transfer from s1)"
diff --git a/setup.cfg b/setup.cfg
index 34995ce96..f4ad02c43 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -61,8 +61,7 @@ GENERATION =
lm_eval>=0.4.9
STREAMING =
- redis>=-7.1.0
- orjson>=3.11.5
+ redis>=7.1.0
# Required for supporting vision inputs
VISION =
diff --git a/setup.py b/setup.py
index b273e077e..5c4d0def6 100644
--- a/setup.py
+++ b/setup.py
@@ -1,6 +1,6 @@
-import sys
-import re
import pathlib
+import re
+import sys
try:
import pybind11
@@ -18,6 +18,7 @@
print(f"Error: setuptools version {_SETUPTOOLS_MIN_VERSION} " "or greater is required")
sys.exit(1)
+
def get_version():
"""Read version from fast_llm/__init__.py"""
init_file = pathlib.Path(__file__).parent.joinpath("fast_llm", "__init__.py").read_text()
@@ -26,6 +27,7 @@ def get_version():
return version_match.group(1)
raise RuntimeError("Unable to find version string in fast_llm/__init__.py")
+
cpp_extension = setuptools.Extension(
"fast_llm.csrc.data",
sources=["fast_llm/csrc/data.cpp"],
diff --git a/tests/conftest.py b/tests/conftest.py
index df56c78ab..e3a2df9a3 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -27,14 +27,13 @@
from tests.utils.run_test_script import ( # isort: skip
compare_results_for_all_models,
run_distributed_script,
- run_distributed_script_lean,
run_test_script_base_path,
run_test_script_for_all_models,
)
-from tests.utils.redis import fake_redis_server, stream_config # isort: skip
from tests.utils.model_configs import model_testing_config, ModelTestingConfig, testing_group_enabled # isort: skip
-from tests.utils.utils import result_path, format_resource_report, report_subtest # isort: skip
+from tests.utils.utils import result_path # isort: skip
+from tests.utils.subtest import format_resource_report, report_subtest, run_parallel_script # isort: skip
# Import all dynamic classes.
import fast_llm.cli # isort: skip
@@ -49,7 +48,18 @@ def pytest_addoption(parser):
group = parser.getgroup("fast_llm")
group.addoption("--skip-slow", action="store_true")
group.addoption("--show-skipped", action="store_true")
- group.addoption("--show-gpu-memory", type=int, default=10)
+ group.addoption(
+ "--show-gpu-memory",
+ type=int,
+ default=10,
+ help="Show resource usage stats for the tests and distributed subtests with the highest GPU memory usage.",
+ )
+ group.addoption(
+ "--show-durations",
+ type=int,
+ default=None,
+ help="Show resource usage stats for the slowest tests and distributed subtests.",
+ )
group.addoption("--no-distributed-capture", dest="distributed_capture", action="store_false")
group.addoption("--models", nargs="*")
group.addoption(
@@ -231,6 +241,16 @@ def pytest_terminal_summary(terminalreporter):
for nodeid in sorted_nodeids[: terminalreporter.config.getoption("--show-gpu-memory")]:
terminalreporter.write_line(format_resource_report(nodeid, resource_reports[nodeid]))
+ if (show_durations := terminalreporter.config.getoption("--show-durations")) is not None:
+ terminalreporter.write_sep("=", "Highest durations", bold=True)
+ sorted_nodeids = sorted(
+ resource_reports.keys(),
+ key=lambda nodeid: (resource_reports[nodeid]["duration"] if "duration" in resource_reports[nodeid] else 0),
+ reverse=True,
+ )
+ for nodeid in sorted_nodeids[:show_durations]:
+ terminalreporter.write_line(format_resource_report(nodeid, resource_reports[nodeid]))
+
def pytest_runtest_call(item: pytest.Function):
if torch.cuda.is_available():
diff --git a/tests/data/gptdata_streaming_test.py b/tests/data/gptdata_streaming_test.py
deleted file mode 100644
index 3e388cc45..000000000
--- a/tests/data/gptdata_streaming_test.py
+++ /dev/null
@@ -1,115 +0,0 @@
-import argparse
-import pathlib
-import pickle
-
-from fast_llm.config import NoAutoValidate
-from fast_llm.data.data.gpt.config import GPTDataConfig
-from fast_llm.data.data.gpt.data import GPTData
-from fast_llm.data.dataset.config import IngestionType
-from fast_llm.engine.distributed.config import DistributedConfig
-from fast_llm.engine.distributed.distributed import Distributed
-from fast_llm.models.gpt.config import GPTBatchConfig
-from tests.utils.redis import get_stream_config, make_sampling
-
-
-def distributed_gptdata_streaming_test(
- sequence_length,
- micro_batch_size,
- batch_size,
- tensor_parallel,
- pipeline_parallel,
- sequence_data_parallel,
- total_gpus,
- redis_port,
- result_path,
- ingestion_type,
-):
- stream_config = get_stream_config()
- stream_config = stream_config.from_dict(
- stream_config.to_dict(), {("redis", "port"): redis_port, ("ingestion_type"): ingestion_type}
- )
-
- distributed = Distributed(
- DistributedConfig(
- tensor_parallel=tensor_parallel,
- pipeline_parallel=pipeline_parallel,
- sequence_data_parallel=sequence_data_parallel,
- ),
- use_cpu=total_gpus == 0,
- )
- sampling_data = make_sampling(sequence_length, 0, micro_batch_size, distributed)
-
- data_config = {"datasets": {"streaming1": stream_config.to_dict()}, "sampling": {"shuffle": "disabled"}}
- data_config = GPTDataConfig.from_dict(data_config)
-
- data = GPTData(data_config, distributed.config)
-
- data.setup(
- distributed=distributed,
- sampling_parameters={"streaming1": sampling_data.parameters},
- preprocessing={},
- cache_directory="/tmp",
- )
-
- with NoAutoValidate():
- batch_config = GPTBatchConfig(
- micro_batch_size=micro_batch_size, batch_size=batch_size, sequence_length=sequence_length
- )
- batch_config.setup(distributed_config=distributed.config)
- batch_config.validate()
-
- data_iter = data.get_iterator(batch_config, "streaming1", consumed_samples=0, num_workers=1, prefetch_factor=1)
-
- batch = next(data_iter)
- # TODO: save result per batch_data_group and rank
- assert batch.tokens.tokens.shape == (micro_batch_size, sequence_length)
-
- result_path = (
- pathlib.Path(result_path)
- / (
- f"{distributed.config.batch_data_rank}_"
- f"{distributed.model_and_sequence_data_group.rank() if distributed.model_and_sequence_data_group is not None else 0}"
- )
- / "batch.pkl"
- )
- result_path.parent.mkdir(exist_ok=True, parents=True)
- with result_path.open("wb") as f:
- pickle.dump(batch, f)
-
-
-def parse_args():
- parser = argparse.ArgumentParser(description="Run distributed GPT data streaming test.")
-
- parser.add_argument("--sequence-length", type=int, required=True, help="Sequence length of the model input.")
- parser.add_argument("--micro-batch-size", type=int, required=True, help="Micro batch size.")
- parser.add_argument("--batch-size", type=int, required=True, help="Global batch size.")
- parser.add_argument("--tensor-parallel", type=int, required=True, help="Tensor parallel degree.")
- parser.add_argument("--pipeline-parallel", type=int, required=True, help="Pipeline parallel degree.")
- parser.add_argument("--sequence-data-parallel", type=int, required=True, help="Sequence data parallel degree.")
- parser.add_argument("--total-gpus", type=int, required=True, help="Total number of GPUs available.")
- parser.add_argument("--redis-port", type=int, required=True, help="Redis port to connect to.")
- parser.add_argument("--result-path", type=str, required=True, help="Path to save test results.")
- parser.add_argument("--ingestion-type", type=str, required=True, help="Ingestion type used in streaming dataset.")
-
- return parser.parse_args()
-
-
-def main():
- args = parse_args()
-
- distributed_gptdata_streaming_test(
- sequence_length=args.sequence_length,
- micro_batch_size=args.micro_batch_size,
- batch_size=args.batch_size,
- tensor_parallel=args.tensor_parallel,
- pipeline_parallel=args.pipeline_parallel,
- sequence_data_parallel=args.sequence_data_parallel,
- total_gpus=args.total_gpus,
- redis_port=args.redis_port,
- result_path=args.result_path,
- ingestion_type=IngestionType(args.ingestion_type),
- )
-
-
-if __name__ == "__main__":
- main()
diff --git a/tests/data/test_streaming.py b/tests/data/test_streaming.py
index a0bfae316..d8953488f 100644
--- a/tests/data/test_streaming.py
+++ b/tests/data/test_streaming.py
@@ -1,399 +1,241 @@
+import contextlib
import logging
-import os
-import pickle
+import pathlib
+import typing
import fakeredis
import pytest
+import redis
import torch
-from fast_llm.data.dataset.config import IngestionType
-from fast_llm.data.dataset.streaming import StreamingDataset
+from fast_llm.config import NoAutoValidate
+from fast_llm.core.distributed import safe_barrier
+from fast_llm.data.data.gpt.config import GPTDataConfig
+from fast_llm.data.data.gpt.data import GPTData
+from fast_llm.data.dataset.config import RedisConfig, SamplingParameters, StreamingDatasetConfig
+from fast_llm.data.dataset.streaming import RedisStreamingDataset
+from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig
from fast_llm.data.sample.language_model import LanguageModelSample
-from fast_llm.engine.distributed.config import DistributedConfig
+from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames
from fast_llm.engine.distributed.distributed import Distributed
+from fast_llm.models.gpt.config import GPTBatchConfig
+from fast_llm.utils import Assert
+from tests.conftest import WorkerResources
from tests.utils.redis import make_sampling, push_msg, redis_batch_producer
+from tests.utils.subtest import DistributedTestContext
from tests.utils.utils import requires_cuda
logger = logging.getLogger(__name__)
-# ---------------------------------------------------------------------
-# Fixtures
-# ---------------------------------------------------------------------
-
-
-@pytest.fixture
-def fake_redis():
- """Return a FakeRedis instance."""
- return fakeredis.FakeRedis()
-
-
@pytest.fixture
-def monkeypatched_redis(monkeypatch, fake_redis):
- """Monkeypatch redis.Redis globally (works even for imports inside functions)."""
- import redis
-
+def fake_redis(monkeypatch):
+ """Monkeypatch redis.Redis globally."""
+ fake_redis = fakeredis.FakeRedis()
monkeypatch.setattr(redis, "Redis", lambda *args, **kwargs: fake_redis)
- return fake_redis
-
-
-# ---------------------------------------------------------------------
-# Helpers
-# ---------------------------------------------------------------------
-
-
-def generate_parallelism_variants(total_gpus: int):
- """
- Generate all valid variants of (data_groups, tensor_parallel, pipeline_parallel, sequence_parallel)
- for a number of GPUs up to the total_gpus.
- If total_gpus is odd and > 1, fallback to nearest lower even number for decomposable parallelism.
- """
- if total_gpus > 1 and total_gpus % 2 == 1:
- total_gpus = total_gpus - 1
-
- if total_gpus < 2:
- # No gpu and one gpu tests are the same,
- # so no need of creation of variant for a single gpu
- return []
-
- variants = []
-
- for gpus in range(2, total_gpus + 1, 2):
- # try all possible numbers of data groups (1..total_gpus)
- for data_groups in range(1, gpus + 1):
- if gpus % data_groups != 0:
- continue # cannot evenly split
-
- gpus_per_group = gpus // data_groups
-
- # now find all decompositions of gpus_per_group into tp*pp*sp
- for tp in range(1, gpus_per_group + 1):
- if gpus_per_group % tp != 0:
- continue
- rem_after_tp = gpus_per_group // tp
- # TODO: currently streaming dataset does not support pipeline parallel setup
- # for pp in range(1, rem_after_tp + 1):
- for pp in range(1, 2):
- if rem_after_tp % pp != 0:
- continue
- sp = rem_after_tp // pp
- try:
- # instead of repeating all safeguards here just try to
- # instantiate distributed config to check if combination is valid
- dist_config = DistributedConfig(
- tensor_parallel=tp,
- pipeline_parallel=pp,
- sequence_data_parallel=sp,
- world_size=gpus,
- # TODO: works only on one node
- local_world_size=gpus,
- rank=0,
- )
- except Exception:
- continue
-
- variants.append(
- {
- "data_groups": data_groups,
- "batch_data_parallel": dist_config.batch_data_parallel,
- "tensor_parallel": tp,
- "pipeline_parallel": pp,
- "sequence_data_parallel": sp,
- "total_gpus": gpus,
- }
- )
- return variants
-
-
-def run_distributed_gptdata_streaming_test(
- fake_redis_server,
- variant,
- run_distributed_script,
- result_path,
- request,
- ingestion_type: IngestionType,
-):
- import tests.data.gptdata_streaming_test
-
- stream_config, fake_redis, fake_redis_server_killer = fake_redis_server
- stream_config = stream_config.from_dict(stream_config.to_dict(), {("ingestion_type"): ingestion_type})
-
- sequence_length = 10
- micro_batch_size = 2
- batch_size = micro_batch_size * variant["batch_data_parallel"]
- tensor_parallel = variant["tensor_parallel"]
- pipeline_parallel = variant["pipeline_parallel"]
- sequence_data_parallel = variant["sequence_data_parallel"]
- total_gpus = variant["total_gpus"]
- redis_port = stream_config.redis.port
-
- result_path = result_path / "distributed_gptdata_streaming_test" / request.node.name
-
- with redis_batch_producer(
- redis_client=fake_redis,
- fake_redis_server_killer=fake_redis_server_killer,
- stream_config=stream_config,
- batch_size=batch_size,
- sequence_length=10,
- ):
- if total_gpus > 0:
- script = [
- "-m",
- tests.data.gptdata_streaming_test.__name__,
- "--sequence-length",
- str(sequence_length),
- "--micro-batch-size",
- str(micro_batch_size),
- "--batch-size",
- str(batch_size),
- "--tensor-parallel",
- str(tensor_parallel),
- "--pipeline-parallel",
- str(pipeline_parallel),
- "--sequence-data-parallel",
- str(sequence_data_parallel),
- "--total-gpus",
- str(total_gpus),
- "--result-path",
- str(result_path),
- "--redis-port",
- str(redis_port),
- "--ingestion-type",
- str(ingestion_type.value),
- ]
- # TODO: distributed_capture is ignored now inside the script
- if request.config.getoption("distributed_capture"):
- logger.warning(
- "Capturing output and forwarding to associated tests. Run with `--no-distributed-capture` to disable."
- )
- else:
- script.append("--no-distributed-capture")
-
- env = os.environ.copy()
- env["PYTHONHASHSEED"] = "42"
- run_distributed_script(script, num_gpus=total_gpus, env=env)
- else:
- tests.data.gptdata_streaming_test.distributed_gptdata_streaming_test(
- sequence_length=sequence_length,
- micro_batch_size=micro_batch_size,
- batch_size=batch_size,
- tensor_parallel=tensor_parallel,
- pipeline_parallel=pipeline_parallel,
- sequence_data_parallel=sequence_data_parallel,
- total_gpus=total_gpus,
- redis_port=redis_port,
- result_path=result_path,
- ingestion_type=ingestion_type,
- )
-
- check_distributed_gptdata_streaming_test_results(
- result_path=result_path,
- micro_batch_size=micro_batch_size,
- batch_data_parallel=variant["batch_data_parallel"],
- total_gpu=variant["total_gpus"],
- )
+ try:
+ yield fake_redis
+ finally:
+ fake_redis.close()
-def check_distributed_gptdata_streaming_test_results(
- result_path,
- micro_batch_size,
- batch_data_parallel,
- total_gpu,
+@pytest.mark.parametrize(
+ "messages",
+ [
+ (range(3),),
+ (range(3), range(3, 7)),
+ (range(3), range(5), [], [9, 4]),
+ ],
+)
+def test_streaming_dataset(
+ fake_redis: fakeredis.FakeRedis,
+ messages: tuple[list[int], ...],
+ worker_resources: WorkerResources,
):
- batch_data_parallel_size = total_gpu // batch_data_parallel if total_gpu > 0 else 1
- sample_idx = set()
- for i in range(batch_data_parallel):
- ref_batch = None
- for j in range(batch_data_parallel_size):
- with (result_path / f"{i}_{j}" / "batch.pkl").open("rb") as f:
- batch = pickle.load(f)
- if ref_batch is None:
- ref_batch = batch
- else:
- # batches for same batch_data_parallel_group must be equal on all ranks
- assert torch.equal(batch.tokens.tokens, ref_batch.tokens.tokens)
- for j in range(micro_batch_size):
- val = ref_batch.tokens.tokens[j, 0].item()
- # all samples in batches between groups and in the batch must be unique
- assert val not in sample_idx
- sample_idx.add(val)
- # unique sample count must be the same as global batch size
- assert len(sample_idx) == micro_batch_size * batch_data_parallel
-
-
-# ---------------------------------------------------------------------
-# Tests
-# ---------------------------------------------------------------------
-
-
-def test_streaming_dataset_reads_single_message(monkeypatched_redis, stream_config):
"""StreamingDataset should read a message and convert it into LanguageModelSample."""
- fake_redis = monkeypatched_redis
-
- distributed = Distributed(DistributedConfig(), use_cpu=True)
- dataset = StreamingDataset(stream_config, distributed)
-
- # Insert a message
- push_msg(fake_redis, stream_config, [1, 2, 3])
-
- it = iter(dataset)
- sample = next(it)
-
- assert isinstance(sample, LanguageModelSample)
- assert torch.equal(sample.tokens.tokens, torch.tensor([1, 2, 3], dtype=torch.int64))
- assert sample.tokens.lengths == [3]
- assert sample.loss_masking_spans is None
- assert sample.chosen_spans is None
- assert sample.rejected_spans is None
+ stream_config = StreamingDatasetConfig(port=worker_resources.torchrun_port)
+ dataset_iterator = iter(RedisStreamingDataset(stream_config, DistributedConfig()))
+ for message in messages:
+ push_msg(fake_redis, list(message))
+ for message in messages:
+ sample = next(dataset_iterator)
+ assert isinstance(sample, LanguageModelSample)
+ Assert.eq(sample.tokens.tokens.tolist(), list(message))
+ Assert.eq(sample.tokens.lengths, [len(message)])
+ assert sample.loss_masking_spans is None
+ assert sample.chosen_spans is None
+ assert sample.rejected_spans is None
-def test_streaming_dataset_reads_multiple_messages(monkeypatched_redis, stream_config):
+@pytest.mark.parametrize(
+ ("messages", "expected_samples", "expected_lengths"),
+ [
+ ((range(5),), (range(5),), ([5],)), # Single message, exact fit.
+ ((range(3), [3, 4]), (range(5),), ([3, 2],)), # Two messages, exact fit.
+ ((range(6), range(5)), (range(5),), ([5],)), # Two messages, one dropped.
+ (
+ (range(3), range(5)),
+ (
+ [0, 1, 2, -100, -100],
+ range(5),
+ ),
+ (
+ [3, 2],
+ [5],
+ ),
+ ), # Two messages, one padded.
+ ],
+)
+def test_streaming_sampled_dataset(
+ fake_redis: fakeredis.FakeRedis,
+ messages: tuple[list[int], ...],
+ expected_samples: tuple[list[int], ...],
+ expected_lengths: tuple[int, ...],
+ worker_resources: WorkerResources,
+):
"""StreamingDataset should read a message and convert it into LanguageModelSample."""
- fake_redis = monkeypatched_redis
-
+ stream_config = StreamingDatasetConfig(port=worker_resources.torchrun_port)
distributed = Distributed(DistributedConfig(), use_cpu=True)
- dataset = StreamingDataset(stream_config, distributed)
-
- # Insert a message
- push_msg(fake_redis, stream_config, [1, 2, 3])
- push_msg(fake_redis, stream_config, [1, 2, 3])
- push_msg(fake_redis, stream_config, [1, 2, 3])
-
- it = iter(dataset)
- for i in range(3):
- sample = next(it)
-
+ dataset_iterator = iter(
+ RedisStreamingDataset(stream_config, distributed.config).sample(make_sampling(5, 1, distributed))
+ )
+ for message in messages:
+ push_msg(fake_redis, list(message))
+ for expected_sample, expected_lengths_ in zip(expected_samples, expected_lengths, strict=True):
+ sample = next(dataset_iterator)
assert isinstance(sample, LanguageModelSample)
- assert torch.equal(sample.tokens.tokens, torch.tensor([1, 2, 3], dtype=torch.int64))
- assert sample.tokens.lengths == [3]
+ Assert.eq(sample.tokens.tokens.tolist(), list(expected_sample))
+ Assert.eq(sample.tokens.lengths, expected_lengths_)
assert sample.loss_masking_spans is None
assert sample.chosen_spans is None
assert sample.rejected_spans is None
-def test_sampling_1_doc_exact_fit(monkeypatched_redis, stream_config):
- """Docs exactly fill one sample."""
- fake_redis = monkeypatched_redis
-
- push_msg(fake_redis, stream_config, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
-
- distributed = Distributed(DistributedConfig(), use_cpu=True)
- sampler = StreamingDataset(stream_config, distributed).sample(make_sampling(10, 0, 1, distributed))
-
- out = next(iter(sampler))
-
- assert isinstance(out, LanguageModelSample)
- assert len(out) == 10
- assert out.tokens.tokens.tolist() == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
-
-
-def test_sampling_2_docs_exact_fit(monkeypatched_redis, stream_config):
- """Docs exactly fill one sample."""
- fake_redis = monkeypatched_redis
-
- # Two rollouts: lengths 4 and 6 -> exactly 10
- push_msg(fake_redis, stream_config, [1, 2, 3, 4])
- push_msg(fake_redis, stream_config, [5, 6, 7, 8, 9, 10])
-
- distributed = Distributed(DistributedConfig(), use_cpu=True)
- sampler = StreamingDataset(stream_config, distributed).sample(make_sampling(10, 0, 1, distributed))
-
- out = next(iter(sampler))
-
- assert isinstance(out, LanguageModelSample)
- assert len(out) == 10
- assert out.tokens.tokens.tolist() == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
-
-
-def test_sampling_skips_too_long_doc_and_padding_final(monkeypatched_redis, stream_config):
- """Rollout longer than sample_length must be dropped."""
- fake_redis = monkeypatched_redis
-
- push_msg(fake_redis, stream_config, list(range(20))) # skip: too long
- push_msg(fake_redis, stream_config, list(range(10))) # usable
-
- distributed = Distributed(DistributedConfig(), use_cpu=True)
- sampler = StreamingDataset(stream_config, distributed).sample(make_sampling(10, 0, 1, distributed))
+_NUM_BATCHES = 1
- out = next(iter(sampler))
- # too big message is skipped and next message is returned instead
- assert len(out) == 10
- assert out.tokens.tokens.tolist() == list(range(10))
+def _get_distributed_and_batch_config(
+ distributed_config_dict: dict[str, typing.Any], world_size: int = 1
+) -> tuple[DistributedConfig, GPTBatchConfig]:
+ distributed_config = DistributedConfig.from_dict(
+ distributed_config_dict, {"world_size": world_size, "local_world_size": world_size}
+ )
+ with NoAutoValidate():
+ batch_config = GPTBatchConfig(micro_batch_size=2, sequence_length=10)
+ batch_config.setup(distributed_config=distributed_config)
+ batch_config.validate()
+ return distributed_config, batch_config
-def test_sampling_overflow_creates_two(monkeypatched_redis, stream_config):
- """A document overflowing the boundary triggers padding + next sample."""
- fake_redis = monkeypatched_redis
+def _run_test_data_streaming(
+ path: pathlib.Path, distributed_config: DistributedConfig, batch_config: GPTBatchConfig, port: int
+):
+ redis_config = RedisConfig(port=port + 100)
- push_msg(fake_redis, stream_config, list(range(6)))
- push_msg(fake_redis, stream_config, list(range(10)))
+ data = GPTData(GPTDataConfig(datasets={"train": {"type": "streaming", "port": port + 100}}), distributed_config)
+ distributed = Distributed(distributed_config)
+ with (
+ redis_batch_producer(redis_config, batch_config) if distributed_config.rank == 0 else contextlib.nullcontext()
+ ):
+ data.setup(
+ distributed=distributed,
+ sampling_parameters={
+ "train": SamplingParameters(
+ sequence_length=batch_config.sequence_length,
+ extra_tokens=0,
+ num_samples=batch_config.batch_size * _NUM_BATCHES,
+ truncate_documents=False,
+ )
+ },
+ preprocessing=LanguageModelPreprocessingConfig(),
+ cache_directory=path / "cache",
+ timeout=5,
+ )
+
+ data_iter = data.get_iterator(batch_config, "train", consumed_samples=0, num_workers=0, prefetch_factor=None)
+ batches = [next(data_iter) for _ in range(_NUM_BATCHES)]
+ path.mkdir(parents=True, exist_ok=True)
+ torch.save(
+ torch.stack([batch.tokens.tokens[:, 0] for batch in batches]),
+ path / f"rank_{distributed_config.batch_data_rank}_"
+ f"{distributed_config.get_distributed_dim(DistributedDimNames.model_and_sequence_data).rank}.pt",
+ )
+ # Wait for other processes to finish before shutting down the server.
+ safe_barrier(distributed.world_group, "streaming test end")
+
+
+def check_data_streaming_results(
+ path: pathlib.Path,
+ distributed_config: DistributedConfig,
+ batch_config: GPTBatchConfig,
+):
+ sample_indexes = set()
+ for batch_data_rank in range(distributed_config.batch_data_parallel):
+ batches_tokens = torch.load(path / f"rank_{batch_data_rank}_0.pt")
+ Assert.eq(batches_tokens.shape, (_NUM_BATCHES, batch_config.micro_batch_size))
+ for model_and_sequence_data_rank in range(
+ 1, distributed_config.get_distributed_dim(DistributedDimNames.model_and_sequence_data).size
+ ):
+ Assert.all_equal(
+ torch.load(path / f"rank_{batch_data_rank}_{model_and_sequence_data_rank}.pt"), batches_tokens
+ )
+ sample_indexes.update(batches_tokens.flatten().tolist())
+ Assert.eq(len(sample_indexes), _NUM_BATCHES * batch_config.batch_size)
- distributed = Distributed(DistributedConfig(), use_cpu=True)
- sampler = StreamingDataset(stream_config, distributed).sample(make_sampling(10, 0, 2, distributed))
- sampler_iter = iter(sampler)
- out = [next(sampler_iter)]
- out.append(next(sampler_iter))
+def _run_test_data_streaming_distributed(
+ test_context: DistributedTestContext, base_path: pathlib.Path, port: int
+) -> None:
+ # Import all dynamic classes. TODO: needed?
+ import fast_llm.cli # noqa
- # sample 1: 0..5 + pad(4)
- assert out[0].tokens.tokens.tolist() == list(range(6)) + [-100, -100, -100, -100]
+ for name, num_gpus, distributed_config_dict in _DISTRIBUTED_TESTING_CONFIGS:
+ with test_context.subtest(base_path, name, num_gpus) as subtest:
+ if subtest.do_run:
+ distributed_config, batch_config = _get_distributed_and_batch_config(distributed_config_dict, num_gpus)
+ _run_test_data_streaming(base_path / name, distributed_config, batch_config, port)
- # sample 2: 0..5 + pad(4)
- assert out[1].tokens.tokens.tolist() == list(range(10))
+@requires_cuda
+def test_data_streaming(result_path, worker_resources):
+ distributed_config, batch_config = _get_distributed_and_batch_config({})
+ path = result_path / "data_streaming/single_gpu"
+ _run_test_data_streaming(path, distributed_config, batch_config, worker_resources.torchrun_port)
+ check_data_streaming_results(path, distributed_config, batch_config)
+
+
+_DISTRIBUTED_TESTING_CONFIGS = [
+ ("dp2", 2, {}),
+ ("sdp2", 2, {"sequence_data_parallel": 2}),
+ ("tp2", 2, {"tensor_parallel": 2}),
+ ("pp2", 2, {"pipeline_parallel": 2}),
+ ("dp2_sdp2", 4, {"sequence_data_parallel": 2}),
+ ("dp2_tp2", 4, {"tensor_parallel": 2}),
+ ("dp2_pp2", 4, {"pipeline_parallel": 2}),
+ ("sdp2_tp2", 4, {"sequence_data_parallel": 2, "tensor_parallel": 2}),
+ ("sdp2_pp2", 4, {"sequence_data_parallel": 2, "pipeline_parallel": 2}),
+ ("tp2_pp2", 4, {"tensor_parallel": 2, "pipeline_parallel": 2}),
+]
-@pytest.mark.parametrize(
- "ingestion_type",
- [
- IngestionType.CONSUMER_GROUP,
- # TODO: need to implement wait_until_stream_empty for those variants on test side to enable tests for them
- # IngestionType.ONE_STREAM,
- # IngestionType.N_STREAMS,
- ],
-)
-def test_gptdata_streaming_single_consumer(
- fake_redis_server, run_distributed_script_lean, ingestion_type, result_path, request
-):
- run_distributed_gptdata_streaming_test(
- fake_redis_server=fake_redis_server,
- variant={
- "data_groups": 1,
- "tensor_parallel": 1,
- "pipeline_parallel": 1,
- "sequence_data_parallel": 1,
- "total_gpus": 0,
- "batch_data_parallel": 1,
- },
- run_distributed_script=run_distributed_script_lean,
- result_path=result_path,
- request=request,
- ingestion_type=ingestion_type,
+@requires_cuda
+@pytest.mark.slow
+@pytest.mark.depends_on(on=["test_data_streaming"])
+def test_run_data_streaming_distributed(run_parallel_script, result_path, worker_resources):
+ if torch.cuda.device_count() < 2:
+ pytest.skip(f"Not enough GPUs")
+ run_parallel_script(
+ _run_test_data_streaming_distributed,
+ (result_path / "data_streaming", worker_resources.torchrun_port),
+ world_size=torch.cuda.device_count(),
)
-variants = generate_parallelism_variants(torch.cuda.device_count())
-
-
-@pytest.mark.slow
@requires_cuda
-@pytest.mark.parametrize(
- "variant",
- variants,
- ids=[
- f"dg{v['data_groups']}_tp{v['tensor_parallel']}_pp{v['pipeline_parallel']}_sp{v['sequence_data_parallel']}_gpu{v['total_gpus']}"
- for v in variants
- ],
-)
-def test_gptdata_streamin_gpus(fake_redis_server, variant, run_distributed_script_lean, result_path, request):
- # TODO: make tests on the same number of gpu as subtests
- # similar to how it is done in the test_model for speed
- run_distributed_gptdata_streaming_test(
- fake_redis_server=fake_redis_server,
- variant=variant,
- run_distributed_script=run_distributed_script_lean,
- result_path=result_path,
- request=request,
- ingestion_type=IngestionType.CONSUMER_GROUP,
- )
+@pytest.mark.slow
+@pytest.mark.depends_on(on=["test_data_streaming"])
+@pytest.mark.parametrize(("name", "num_gpus", "distributed_config_dict"), _DISTRIBUTED_TESTING_CONFIGS)
+def test_data_streaming_distributed(result_path, name, num_gpus, distributed_config_dict, report_subtest):
+ report_subtest(path := result_path / f"data_streaming/{name}", num_gpus)
+ distributed_config, batch_config = _get_distributed_and_batch_config(distributed_config_dict, num_gpus)
+ check_data_streaming_results(path, distributed_config, batch_config)
diff --git a/tests/data/test_tokenizer.py b/tests/data/test_tokenizer.py
index c7fdef9ca..4e9e2fdd5 100644
--- a/tests/data/test_tokenizer.py
+++ b/tests/data/test_tokenizer.py
@@ -40,3 +40,263 @@ def test_tokenize_with_spans(common_tokenizer, spans, expected_token_spans, expe
expected_token_spans = [(begin + 1, end + 1) for begin, end in expected_token_spans]
Assert.eq(tokens.tolist(), expected_tokens)
Assert.eq(token_spans, expected_token_spans)
+
+
+def test_validate_chat_template_no_template(common_tokenizer):
+ """Tokenizer without chat template raises."""
+ with pytest.raises(ValueError, match="does not have a chat template"):
+ common_tokenizer.validate_chat_template()
+
+
+def test_validate_chat_template_no_markers(common_tokenizer):
+ """Tokenizer with chat template but no markers raises."""
+ common_tokenizer.tokenizer.chat_template = "{{ messages }}"
+ with pytest.raises(ValueError, match="does not contain.*generation"):
+ common_tokenizer.validate_chat_template()
+
+
+def test_validate_chat_template_with_markers(common_tokenizer):
+ """Tokenizer with generation markers validates."""
+ common_tokenizer.tokenizer.chat_template = "{% generation %}{{ m }}{% endgeneration %}"
+ common_tokenizer.validate_chat_template()
+
+
+# Realistic chat template following HF conventions (e.g., SmolLM3):
+# The generation block includes the full assistant turn: opening tag, content, and closing tag.
+# This ensures the model learns to emit the closing tag.
+CHAT_TEMPLATE = (
+ "{% for message in messages %}"
+ "{% if message.role == 'assistant' %}"
+ "{% generation %}{{ message.content }}{% endgeneration %}"
+ "{% else %}"
+ "<{{ message.role }}>{{ message.content }}{{ message.role }}>"
+ "{% endif %}"
+ "{% endfor %}"
+)
+
+
+@pytest.mark.parametrize(
+ ("messages", "expected_tokens", "expected_loss_masking_spans"),
+ (
+ # Single turn: full assistant turn (Hello) is trainable
+ # 15 tokens, trainable indices 7-13, loss mask spans cover 0-6 and 14
+ (
+ [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello"}],
+ [49152, 27, 789, 29, 16946, 750, 789, 2293, 17822, 29, 7371, 750, 17822, 29, 49152],
+ [(0, 7), (14, 15)],
+ ),
+ # Multi-turn: both assistant turns are fully trainable
+ # 27 tokens, trainable indices 7-13 and 19-25
+ (
+ [
+ {"role": "user", "content": "A"},
+ {"role": "assistant", "content": "B"},
+ {"role": "user", "content": "C"},
+ {"role": "assistant", "content": "D"},
+ ],
+ [
+ 49152,
+ 27,
+ 789,
+ 29,
+ 32,
+ 750,
+ 789,
+ 2293,
+ 17822,
+ 29,
+ 33,
+ 750,
+ 17822,
+ 2293,
+ 789,
+ 29,
+ 34,
+ 750,
+ 789,
+ 2293,
+ 17822,
+ 29,
+ 35,
+ 750,
+ 17822,
+ 29,
+ 49152,
+ ],
+ [(0, 7), (14, 19), (26, 27)],
+ ),
+ # System + user + assistant: full assistant turn trainable
+ # 23 tokens, trainable indices 15-21
+ (
+ [
+ {"role": "system", "content": "You are helpful."},
+ {"role": "user", "content": "Hi"},
+ {"role": "assistant", "content": "Hello"},
+ ],
+ [
+ 49152,
+ 27,
+ 3144,
+ 29,
+ 5815,
+ 1139,
+ 44569,
+ 6928,
+ 3144,
+ 2293,
+ 789,
+ 29,
+ 16946,
+ 750,
+ 789,
+ 2293,
+ 17822,
+ 29,
+ 7371,
+ 750,
+ 17822,
+ 29,
+ 49152,
+ ],
+ [(0, 15), (22, 23)],
+ ),
+ # User only: no trainable tokens
+ # 9 tokens, no trainable indices
+ (
+ [{"role": "user", "content": "Hi"}],
+ [49152, 27, 789, 29, 16946, 750, 789, 29, 49152],
+ [(0, 9)],
+ ),
+ # Long multi-turn (85 tokens, 3 assistant responses with tags, tests span machinery)
+ # Trainable: indices 27-40, 49-62, 70-83
+ (
+ [
+ {"role": "system", "content": "You are a helpful assistant that answers questions."},
+ {"role": "user", "content": "What is the capital of France?"},
+ {"role": "assistant", "content": "The capital of France is Paris."},
+ {"role": "user", "content": "What about Germany?"},
+ {"role": "assistant", "content": "The capital of Germany is Berlin."},
+ {"role": "user", "content": "And Italy?"},
+ {"role": "assistant", "content": "The capital of Italy is Rome."},
+ ],
+ [
+ 49152,
+ 27,
+ 3144,
+ 29,
+ 5815,
+ 1139,
+ 373,
+ 44569,
+ 2424,
+ 11886,
+ 954,
+ 15737,
+ 14516,
+ 6928,
+ 3144,
+ 2293,
+ 789,
+ 29,
+ 13938,
+ 438,
+ 331,
+ 25016,
+ 457,
+ 12409,
+ 562,
+ 35838,
+ 789,
+ 2293,
+ 17822,
+ 29,
+ 2111,
+ 25016,
+ 457,
+ 12409,
+ 562,
+ 438,
+ 4235,
+ 280,
+ 6928,
+ 17822,
+ 2293,
+ 789,
+ 29,
+ 13938,
+ 5028,
+ 759,
+ 42226,
+ 35838,
+ 789,
+ 2293,
+ 17822,
+ 29,
+ 2111,
+ 25016,
+ 457,
+ 759,
+ 42226,
+ 438,
+ 29784,
+ 3556,
+ 6928,
+ 17822,
+ 2293,
+ 789,
+ 29,
+ 1996,
+ 4413,
+ 3326,
+ 35838,
+ 789,
+ 2293,
+ 17822,
+ 29,
+ 2111,
+ 25016,
+ 457,
+ 4413,
+ 3326,
+ 438,
+ 613,
+ 1361,
+ 6928,
+ 17822,
+ 29,
+ 49152,
+ ],
+ [(0, 27), (41, 49), (63, 70), (84, 85)],
+ ),
+ ),
+)
+def test_tokenize_chat(common_tokenizer, messages, expected_tokens, expected_loss_masking_spans):
+ common_tokenizer.tokenizer.chat_template = CHAT_TEMPLATE
+ tokens, loss_masking_spans = common_tokenizer.tokenize_chat(messages)
+ Assert.eq(tokens.tolist(), expected_tokens)
+ Assert.eq(loss_masking_spans, expected_loss_masking_spans)
+
+
+@pytest.mark.parametrize(
+ ("train_mask", "expected_loss_spans"),
+ (
+ # All masked (no trainable tokens)
+ ([False, False, False], [(0, 3)]),
+ # All trainable (no spans)
+ ([True, True, True], []),
+ # Single trainable at start
+ ([True, False, False], [(1, 3)]),
+ # Single trainable at end
+ ([False, False, True], [(0, 2)]),
+ # Single trainable in middle
+ ([False, True, False], [(0, 1), (2, 3)]),
+ # Multiple trainable regions (simulates multi-turn conversation)
+ ([False, False, True, True, False, False, True, True, True, False], [(0, 2), (4, 6), (9, 10)]),
+ # Alternating
+ ([False, True, False, True, False], [(0, 1), (2, 3), (4, 5)]),
+ ),
+)
+def test_train_mask_to_loss_spans(train_mask, expected_loss_spans):
+ from fast_llm.data.preprocessing.tokenizer import _train_mask_to_loss_spans
+
+ Assert.eq(_train_mask_to_loss_spans(train_mask), expected_loss_spans)
diff --git a/tests/functional/test_cross_entropy.py b/tests/functional/test_cross_entropy.py
index 72644d061..20d16bb96 100644
--- a/tests/functional/test_cross_entropy.py
+++ b/tests/functional/test_cross_entropy.py
@@ -104,7 +104,9 @@ def _reverse_kl_forward_backward_torch(logits: torch.Tensor, target: torch.Tenso
reduction="none",
log_target=True,
).sum(dim=-1)
- output = per_sample.mean() if loss_mask is None else (per_sample * loss_mask).sum() / loss_mask.sum()
+ if loss_mask is not None:
+ per_sample = per_sample * loss_mask
+ output = per_sample.mean()
output.backward()
return output, logits.grad
diff --git a/tests/models/distributed_test_checkpoint.py b/tests/models/distributed_test_checkpoint.py
deleted file mode 100644
index 001eb36da..000000000
--- a/tests/models/distributed_test_checkpoint.py
+++ /dev/null
@@ -1,90 +0,0 @@
-import gc
-import logging
-
-import torch
-
-from fast_llm.cli import fast_llm_main_wrapper
-from fast_llm.config import NoAutoValidate
-from fast_llm.core.distributed import safe_barrier
-from fast_llm.engine.checkpoint.config import (
- CheckpointLoadConfig,
- CheckpointSaveConfig,
- DistributedCheckpointFormat,
- FastLLMCheckpointFormat,
-)
-from fast_llm.engine.distributed.config import DistributedConfig
-from fast_llm.engine.distributed.distributed import ProcessGroupPool
-from fast_llm.engine.multi_stage.config import StageMode
-from fast_llm.utils import Assert, header
-from tests.utils.model_configs import ModelTestingConfig
-from tests.utils.run_test_script import parse_run_distributed_script
-from tests.utils.save_load_configs import DISTRIBUTED_SAVE_LOAD_CONFIGS, DistributedSaveLoadConfig
-from tests.utils.utils import DistributedSubtestContext
-
-logger = logging.getLogger(__name__)
-
-
-def _test_load_and_save_parallel(
- model_testing_config: ModelTestingConfig,
- config: DistributedSaveLoadConfig,
-):
- logger.info(header(config.name))
- logger.info(f"Loading {config.load_format} checkpoint from {config.load_path}")
- with NoAutoValidate():
- load_config = CheckpointLoadConfig(path=config.load_path, format=config.load_format)
- load_config.setup(model_testing_config.model_config_class)
- load_config.validate()
- model = model_testing_config.model_class.from_pretrained(
- load_config,
- # The world size and rank are already set through environment variable.
- {"distributed": config.distributed},
- mode=StageMode.inference,
- )
- for save_format in (DistributedCheckpointFormat, FastLLMCheckpointFormat):
- logger.info(f"Saving {save_format.name} checkpoint to {config.save_path / save_format.name}")
- model.save_checkpoint(CheckpointSaveConfig(path=config.save_path / save_format.name, format=save_format))
- del model
- gc.collect()
- torch.cuda.empty_cache()
-
-
-def main(args: list[str] | None = None) -> None:
- base_path, model_testing_config, do_capture = parse_run_distributed_script(args)
-
- if do_capture:
- logger.warning(
- "Capturing output and forwarding to associated tests. Run with `--no-distributed-capture` to disable."
- )
-
- with ProcessGroupPool(timeout=20) as pool:
- failures = []
- world_size = DistributedConfig.default_world_size
- rank = DistributedConfig.default_rank
- group = pool.get_process_group(range(world_size), rank)
-
- for config in DISTRIBUTED_SAVE_LOAD_CONFIGS.values():
- if config.load_format == "{checkpoint_format}" and model_testing_config.checkpoint_format is None:
- continue
- config = config.resolve(base_path, model_testing_config)
- Assert.eq(world_size, config.num_gpus)
- with DistributedSubtestContext(base_path, config.name, group, world_size, enabled=do_capture) as subtest:
- _test_load_and_save_parallel(
- model_testing_config=model_testing_config,
- config=config,
- )
- if not subtest.success:
- failures.append(config.name)
-
- # Final barrier to ensure everything is done before torchrun potentially kills workers.
- safe_barrier(group, "testing end")
- # Let pytest know how things went.
- # These should already be reported above, we repeat for convenience.
- if failures:
- raise RuntimeError(f"The following subtests failed: {", ".join(failures)}")
- else:
- logger.warning("All tests passed")
-
-
-if __name__ == "__main__":
- with fast_llm_main_wrapper():
- main()
diff --git a/tests/models/distributed_test_model.py b/tests/models/distributed_test_model.py
deleted file mode 100644
index 890a75077..000000000
--- a/tests/models/distributed_test_model.py
+++ /dev/null
@@ -1,54 +0,0 @@
-import logging
-
-from fast_llm.cli import fast_llm_main_wrapper
-from fast_llm.core.distributed import safe_barrier
-from fast_llm.engine.distributed.config import DistributedConfig
-from fast_llm.engine.distributed.distributed import ProcessGroupPool
-from tests.utils.distributed_configs import DISTRIBUTED_TESTING_CONFIGS
-from tests.utils.run_test_script import do_run_test_script_for_all_models, parse_run_distributed_script
-from tests.utils.utils import DistributedSubtestContext
-
-logger = logging.getLogger(__name__)
-
-
-def main(args: list[str] | None = None) -> None:
- base_path, model_testing_config, do_capture = parse_run_distributed_script(args)
-
- if do_capture:
- logger.warning(
- "Capturing output and forwarding to associated tests. Run with `--no-distributed-capture` to disable."
- )
-
- # TODO: Why are barriers needed?
- with ProcessGroupPool(timeout=60) as pool:
- failures = []
- world_size = DistributedConfig.default_world_size
- rank = DistributedConfig.default_rank
- group = pool.get_process_group(range(world_size), rank)
- safe_barrier(group, "start")
-
- for name, config in DISTRIBUTED_TESTING_CONFIGS.items():
- if model_testing_config.should_skip(config):
- continue
- if world_size < config.num_gpus:
- logger.warning(f"{name} {f"SKIPPED (not enough GPUs: {world_size} < {config.num_gpus})"})")
- continue
- with DistributedSubtestContext(base_path, name, group, config.num_gpus, enabled=do_capture) as subtest:
- if rank < config.num_gpus:
- do_run_test_script_for_all_models(config, model_testing_config, base_path)
- if not subtest.success:
- failures.append(name)
-
- # Final barrier to ensure everything is done before torchrun potentially kills workers.
- safe_barrier(group, "testing end")
- # Let pytest know how things went.
- # These should already be reported above, we repeat for convenience.
- if failures:
- raise RuntimeError(f"The following subtests failed: {", ".join(failures)}")
- else:
- logger.warning("All tests passed")
-
-
-if __name__ == "__main__":
- with fast_llm_main_wrapper():
- main()
diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py
index 53804d878..9a3bc4345 100644
--- a/tests/models/test_checkpoint.py
+++ b/tests/models/test_checkpoint.py
@@ -1,3 +1,4 @@
+import gc
import logging
import pathlib
import shutil
@@ -7,6 +8,7 @@
import torch
import yaml
+from fast_llm.config import NoAutoValidate
from fast_llm.engine.checkpoint.config import (
CheckpointFormat,
CheckpointLoadConfig,
@@ -16,12 +18,13 @@
ModelConfigType,
)
from fast_llm.engine.checkpoint.convert import ConvertConfig
-from fast_llm.engine.multi_stage.config import FastLLMModelConfig, ShardName
-from fast_llm.utils import Assert
+from fast_llm.engine.multi_stage.config import FastLLMModelConfig, ShardName, StageMode
+from fast_llm.utils import Assert, header
from tests.utils.compare_tensor_logs import CompareConfig
from tests.utils.distributed_configs import DistributedTestingConfig
from tests.utils.model_configs import ModelTestingConfig, ModelTestingGroup
from tests.utils.save_load_configs import DISTRIBUTED_SAVE_LOAD_CONFIGS, DistributedSaveLoadConfig
+from tests.utils.subtest import DistributedTestContext
from tests.utils.utils import requires_cuda
logger = logging.getLogger(__name__)
@@ -152,7 +155,7 @@ def test_conversion(model_testing_config, run_conversion, get_convert_path):
)
-def _compare_safetensor_files(
+def compare_safetensor_files(
reference: pathlib.Path | dict[str, torch.Tensor],
*other_paths: pathlib.Path,
expected_keys: set[str] | None = None,
@@ -166,9 +169,10 @@ def _compare_safetensor_files(
for other_path in other_paths:
other = safetensors.torch.load_file(other_path)
- Assert.eq(other.keys(), expected_keys)
+ if other.keys() != expected_keys:
+ raise ValueError(f"Expected keys {expected_keys} but got {other.keys()} in {other_path}")
for key in expected_keys:
- Assert.all_equal(reference[key], other[key])
+ Assert.all_equal(reference[key], other[key], msg=f"tensor = {key}, path = {other_path}")
@requires_cuda
@@ -177,24 +181,24 @@ def _compare_safetensor_files(
def test_converted_round_trip(model_testing_config, get_convert_path):
# Test that the various possible conversion paths yield identical results.
if model_testing_config.checkpoint_format is None:
- _compare_safetensor_files(
+ compare_safetensor_files(
get_convert_path() / "rank_0.safetensors",
get_convert_path(DistributedCheckpointFormat, FastLLMCheckpointFormat) / "rank_0.safetensors",
expected_keys={_WEIGHT_SHARD_SAVE_NAME},
)
else:
- _compare_safetensor_files(
+ compare_safetensor_files(
get_convert_path() / "rank_0.safetensors",
get_convert_path(DistributedCheckpointFormat, FastLLMCheckpointFormat) / "rank_0.safetensors",
get_convert_path(DistributedCheckpointFormat, model_testing_config.checkpoint_format)
/ "rank_0.safetensors",
expected_keys={_WEIGHT_SHARD_SAVE_NAME},
)
- _compare_safetensor_files(
+ compare_safetensor_files(
get_convert_path(FastLLMCheckpointFormat, DistributedCheckpointFormat) / "model_0.safetensors",
get_convert_path(FastLLMCheckpointFormat, model_testing_config.checkpoint_format) / "model_0.safetensors",
)
- _compare_safetensor_files(
+ compare_safetensor_files(
get_convert_path(model_testing_config.checkpoint_format, DistributedCheckpointFormat)
/ "model_0.safetensors",
get_convert_path(model_testing_config.checkpoint_format, FastLLMCheckpointFormat) / "model_0.safetensors",
@@ -391,31 +395,55 @@ def test_huggingface_model(model_testing_config, get_convert_path):
raise ValueError(f"Comparison failed ({len(errors)} errors)")
+def _save_and_load_in_parallel(
+ test_context: DistributedTestContext, base_path: pathlib.Path, model_testing_config: ModelTestingConfig
+) -> None:
+ # Import all dynamic classes.
+ import fast_llm.cli # noqa
+
+ for config in DISTRIBUTED_SAVE_LOAD_CONFIGS.values():
+ if config.load_format == "{checkpoint_format}" and model_testing_config.checkpoint_format is None:
+ continue
+ config = config.resolve(base_path, model_testing_config)
+ with test_context.subtest(base_path, config.name, config.num_gpus) as subtest:
+ if subtest.do_run:
+ logger.info(header(config.name))
+ logger.info(f"Loading {config.load_format} checkpoint from {config.load_path}")
+ with NoAutoValidate():
+ load_config = CheckpointLoadConfig(path=config.load_path, format=config.load_format)
+ load_config.setup(model_testing_config.model_config_class)
+ load_config.validate()
+ model = model_testing_config.model_class.from_pretrained(
+ load_config,
+ # The world size and rank are already set through environment variable.
+ {"distributed": config.distributed},
+ mode=StageMode.inference,
+ )
+ for save_format in (DistributedCheckpointFormat, FastLLMCheckpointFormat):
+ logger.info(f"Saving {save_format.name} checkpoint to {config.save_path / save_format.name}")
+ model.save_checkpoint(
+ CheckpointSaveConfig(path=config.save_path / save_format.name, format=save_format)
+ )
+ del model
+ gc.collect()
+ torch.cuda.empty_cache()
+
+
@requires_cuda
@pytest.mark.depends_on(on=["test_load_pretrained[{model_testing_config}]"])
@pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed)
-def test_save_and_load_in_parallel(run_distributed_script, run_test_script_base_path, model_testing_config, request):
+def test_save_and_load_in_parallel(run_parallel_script, run_test_script_base_path, model_testing_config):
# Save and load checkpoints to and from various distributed configurations.
# Combined in a single test to mitigate process creation overhead.
# TODO: Test beyond 2 gpu configs?
- import tests.models.distributed_test_checkpoint
-
if torch.cuda.device_count() < 2:
- pytest.skip(f"Not enough GPUs: {torch.cuda.device_count()} < 2")
-
- script = [
- "-m",
- tests.models.distributed_test_checkpoint.__name__,
- str(run_test_script_base_path),
- model_testing_config.name,
- ]
- if request.config.getoption("distributed_capture"):
- logger.warning(
- "Capturing output and forwarding to associated tests. Run with `--no-distributed-capture` to disable."
- )
- else:
- script.append("--no-distributed-capture")
- run_distributed_script(script, num_gpus=2)
+ pytest.skip(f"Not enough GPUs2")
+ run_parallel_script(
+ _save_and_load_in_parallel,
+ (run_test_script_base_path, model_testing_config),
+ world_size=2,
+ backend=model_testing_config.distributed_backend,
+ )
@pytest.fixture(scope="module")
@@ -431,7 +459,6 @@ def reference_distributed_shard(get_convert_path) -> torch.Tensor | None:
# We don't want to depend on `test_save_and_load_in_parallel` because we still want to run this in cas of failure.
# This should still run after `test_save_and_load_in_parallel`
@requires_cuda
-# NOTE: Should it depend on test_model_distributed instead?
@pytest.mark.depends_on(on=["test_load_pretrained[{model_testing_config}]"])
@pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed)
def test_load_parallel_checkpoint_in_single_gpu(
@@ -473,7 +500,7 @@ def test_parallel_checkpoint_consistency(model_testing_config, run_test_script_b
# Compare Distributed checkpoints
for config in ("dp2", "tp2", "stp2", "pp2"):
for rank in range(2):
- _compare_safetensor_files(
+ compare_safetensor_files(
*[
DISTRIBUTED_SAVE_LOAD_CONFIGS[f"load_{format_}_in_{config}"]
.resolve(base_path=run_test_script_base_path, model_testing_config=model_testing_config)
@@ -511,7 +538,7 @@ def test_multi_gpu_fast_llm_checkpoint(
base_path=run_test_script_base_path, model_testing_config=model_testing_config
)
- _compare_safetensor_files(
+ compare_safetensor_files(
reference_fast_llm_shard,
distributed_save_load_config_non_pp.save_path / f"{FastLLMCheckpointFormat.name}/model_0.safetensors",
)
diff --git a/tests/models/test_model.py b/tests/models/test_model.py
index d14721142..58768bc52 100644
--- a/tests/models/test_model.py
+++ b/tests/models/test_model.py
@@ -1,4 +1,5 @@
import logging
+import pathlib
import pytest
import torch
@@ -8,8 +9,10 @@
SIMPLE_TESTING_CONFIG,
SINGLE_GPU_TESTING_CONFIGS,
)
-from tests.utils.model_configs import ModelTestingGroup
-from tests.utils.utils import check_subtest_success, requires_cuda, set_subtest_success
+from tests.utils.model_configs import ModelTestingConfig, ModelTestingGroup
+from tests.utils.run_test_script import do_run_test_script_for_all_models
+from tests.utils.subtest import DistributedTestContext, check_subtest_success, set_subtest_success
+from tests.utils.utils import requires_cuda
logger = logging.getLogger(__name__)
@@ -49,27 +52,34 @@ def test_and_compare_model(
compare_results_for_all_models(config)
+def _run_model_distributed(
+ test_context: DistributedTestContext, base_path: pathlib.Path, model_testing_config: ModelTestingConfig
+) -> None:
+ # Import all dynamic classes.
+ import fast_llm.cli # noqa
+
+ for name, config in DISTRIBUTED_TESTING_CONFIGS.items():
+ if model_testing_config.should_skip(config):
+ continue
+ with test_context.subtest(base_path, name, config.num_gpus) as subtest:
+ if subtest.do_run:
+ do_run_test_script_for_all_models(config, model_testing_config, base_path)
+
+
@requires_cuda
@pytest.mark.depends_on(on=["test_model_simple[{model_testing_config}]"])
@pytest.mark.model_testing_group(
ModelTestingGroup.distributed,
)
-def test_run_model_distributed(run_distributed_script, model_testing_config, run_test_script_base_path, request):
- import tests.models.distributed_test_model
-
- script = [
- "-m",
- tests.models.distributed_test_model.__name__,
- str(run_test_script_base_path),
- model_testing_config.name,
- ]
- if request.config.getoption("distributed_capture"):
- logger.warning(
- "Capturing output and forwarding to associated tests. Run with `--no-distributed-capture` to disable."
- )
- else:
- script.append("--no-distributed-capture")
- run_distributed_script(script, num_gpus=torch.cuda.device_count())
+def test_run_model_distributed(run_parallel_script, model_testing_config, run_test_script_base_path):
+ if torch.cuda.device_count() < 2:
+ pytest.skip(f"Not enough GPUs")
+ run_parallel_script(
+ _run_model_distributed,
+ (run_test_script_base_path, model_testing_config),
+ world_size=torch.cuda.device_count(),
+ backend=model_testing_config.distributed_backend,
+ )
# We don't want to depend on `test_model_distributed` because we still want to run this in cas of failure.
diff --git a/tests/models/test_streaming.py b/tests/models/test_streaming.py
new file mode 100644
index 000000000..f132b465d
--- /dev/null
+++ b/tests/models/test_streaming.py
@@ -0,0 +1,213 @@
+import contextlib
+import dataclasses
+import functools
+import json
+import logging
+import pathlib
+
+import pytest
+import safetensors
+import torch
+
+from fast_llm.engine.training.config import StreamingTrainerCallbackConfig
+from fast_llm.engine.training.streaming import REDIS_TRAINING_FIELD, REDIS_TRAINING_STREAM
+from fast_llm.utils import Assert
+from tests.conftest import WorkerResources
+from tests.models.test_checkpoint import compare_safetensor_files
+from tests.utils.distributed_configs import DistributedTestingConfig
+from tests.utils.model_configs import ModelTestingConfig, ModelTestingGroup, update_and_add_testing_config
+from tests.utils.redis import redis_batch_producer
+from tests.utils.run_test_script import do_run_test_script_for_all_models
+from tests.utils.subtest import DistributedTestContext
+from tests.utils.utils import requires_cuda
+
+
+@dataclasses.dataclass(kw_only=True)
+class StreamingDistributedTestingConfig(DistributedTestingConfig):
+ consumer_count: int = (1,)
+
+ @functools.cached_property
+ def total_gpus(self) -> int:
+ return self.num_gpus + self.consumer_count
+
+
+_DISTRIBUTED_STREAMING_CONFIGS = [
+ StreamingDistributedTestingConfig(name="streaming_simple", config_args=[], num_gpus=1, consumer_count=1),
+ StreamingDistributedTestingConfig(name="streaming_dp2", config_args=[], num_gpus=2, consumer_count=1),
+ StreamingDistributedTestingConfig(
+ name="streaming_sdp2_c2",
+ config_args=["model.distributed.sequence_data_parallel=2"],
+ num_gpus=2,
+ consumer_count=2,
+ ),
+ StreamingDistributedTestingConfig(
+ name="streaming_tp2", config_args=["model.distributed.tensor_parallel=2"], num_gpus=2, consumer_count=2
+ ),
+ StreamingDistributedTestingConfig(
+ name="streaming_stp2_c2",
+ config_args=[
+ "model.distributed.tensor_parallel=2",
+ "model.distributed.sequence_tensor_parallel=true",
+ "callbacks.streaming.broadcast.external_world_size=2",
+ ],
+ num_gpus=2,
+ consumer_count=2,
+ ),
+]
+
+
+def _run_event_consumer(
+ streaming_config: StreamingTrainerCallbackConfig, consumer_index: int, base_path: pathlib.Path
+) -> None:
+ client = streaming_config.get_client()
+ init_method = f"tcp://{streaming_config.broadcast.host}:{streaming_config.broadcast.port}"
+ logging.info(f"Waiting for weights broadcast rendezvous at {init_method} ...")
+ path = base_path / "streaming"
+ path.mkdir(parents=True, exist_ok=True)
+ field = REDIS_TRAINING_FIELD.encode()
+ # TODO: Create a custom process group instead.
+ try:
+ process_group = torch.distributed.init_process_group(
+ backend="nccl",
+ init_method=init_method,
+ world_size=streaming_config.broadcast.external_world_size + 1,
+ rank=consumer_index + 1,
+ )
+ last_id = "0-0"
+ while True:
+ result = client.xread(
+ streams={REDIS_TRAINING_STREAM: last_id},
+ count=1,
+ block=10000,
+ )
+ if not result:
+ raise TimeoutError("No message received after 10000 ms...")
+
+ ((stream, events),) = result
+ Assert.eq(stream.decode(), REDIS_TRAINING_STREAM)
+ Assert.eq(len(events), 1)
+ for last_id, message in events:
+ Assert.eq(message.keys(), {field})
+ message = json.loads(message[field].decode())
+ logging.info(f"Received: {message}")
+ if message["type"] == "training_finished":
+ return
+ elif message["type"] == "weights_ready":
+ weights = {}
+ while True:
+ meta = [None]
+ torch.distributed.broadcast_object_list(meta, group=process_group, group_src=0)
+ if meta[0] is None:
+ print(f"Weight broadcast finished")
+ break
+ logging.info(f"receiving {meta[0]}")
+ shard_name, layer_name, tensor_size, tensor_type = meta[0]
+ tensor = torch.zeros(tuple(tensor_size), dtype=tensor_type, device="cuda")
+ torch.distributed.broadcast(tensor, group=process_group, group_src=0)
+ if shard_name == "weights":
+ weights[layer_name] = tensor
+ safetensors.torch.save_file(
+ weights, path / f"rank_{consumer_index}_step_{message["step"]}.safetensors"
+ )
+
+ finally:
+ torch.distributed.destroy_process_group()
+
+
+def _run_model_streaming_configs(
+ test_context: DistributedTestContext, base_path: pathlib.Path, model_testing_config: ModelTestingConfig, port: int
+) -> None:
+ # Import all dynamic classes.
+ import fast_llm.cli # noqa
+
+ for config in _DISTRIBUTED_STREAMING_CONFIGS:
+ model_testing_config = update_and_add_testing_config(
+ model_testing_config,
+ None,
+ updates={
+ ("data", "datasets"): {"training": {"port": port}},
+ ("training", "export"): {"format": model_testing_config.checkpoint_format.name, "interval": 1},
+ "callbacks": {
+ "streaming": {
+ "type": "streaming",
+ "port": port,
+ "broadcast": {
+ "port": port + 1000,
+ "external_world_size": config.consumer_count,
+ },
+ "export": {"format": model_testing_config.checkpoint_format.name},
+ }
+ },
+ # Disable tensor logging.
+ ("run", "tensor_logs"): {},
+ ("model", "multi_stage"): {},
+ },
+ groups={},
+ )
+ with test_context.subtest(base_path, config.name, config.total_gpus) as subtest:
+ if subtest.do_run:
+ if test_context.rank < config.num_gpus:
+ do_run_test_script_for_all_models(config, model_testing_config, base_path)
+ elif test_context.rank < config.total_gpus:
+ training_config = model_testing_config.trainer_config_class.from_dict(
+ model_testing_config.config_dict
+ )
+ with (
+ redis_batch_producer(training_config.callbacks["streaming"], training_config.batch)
+ if test_context.rank == config.num_gpus
+ else contextlib.nullcontext()
+ ):
+ _run_event_consumer(
+ training_config.callbacks["streaming"],
+ test_context.rank - config.num_gpus,
+ base_path / config.name,
+ )
+
+
+@requires_cuda
+@pytest.mark.slow
+@pytest.mark.model_testing_group(ModelTestingGroup.streaming, ModelTestingGroup.distributed)
+def test_model_streaming(run_parallel_script, model_testing_config, run_test_script_base_path, worker_resources):
+ # `test_run_model_distributed_streaming` and `test_model_distributed_streaming need a common dependency
+ # so they are placed in the same testing group and run in the same distributed process.
+ pass
+
+
+@requires_cuda
+@pytest.mark.slow
+@pytest.mark.depends_on(on=["test_model_streaming[{model_testing_config}]"])
+@pytest.mark.model_testing_group(ModelTestingGroup.streaming, ModelTestingGroup.distributed)
+def test_run_model_distributed_streaming(
+ run_parallel_script, model_testing_config, run_test_script_base_path, worker_resources
+):
+ if torch.cuda.device_count() < 2:
+ pytest.skip(f"Not enough GPUs")
+ run_parallel_script(
+ _run_model_streaming_configs,
+ (run_test_script_base_path, model_testing_config, worker_resources.torchrun_port),
+ world_size=torch.cuda.device_count(),
+ backend=model_testing_config.distributed_backend,
+ )
+
+
+@pytest.mark.slow
+@requires_cuda
+@pytest.mark.depends_on(on=["test_model_streaming[{model_testing_config}]"])
+@pytest.mark.model_testing_group(ModelTestingGroup.streaming, ModelTestingGroup.distributed)
+@pytest.mark.parametrize("config", _DISTRIBUTED_STREAMING_CONFIGS)
+def test_model_distributed_streaming(
+ config: StreamingDistributedTestingConfig,
+ run_distributed_script,
+ model_testing_config,
+ run_test_script_base_path,
+ worker_resources: WorkerResources,
+ report_subtest,
+):
+ report_subtest(path := run_test_script_base_path / config.name, config.total_gpus)
+ compare_safetensor_files(
+ path / "export" / model_testing_config.checkpoint_format.name / f"1/model_0.safetensors",
+ *(
+ path / "streaming" / f"rank_{consumer_index}_step_1.safetensors"
+ for consumer_index in range(config.consumer_count)
+ ),
+ )
diff --git a/tests/trainer/events_fake_consumer.py b/tests/trainer/events_fake_consumer.py
deleted file mode 100644
index 4c2d30891..000000000
--- a/tests/trainer/events_fake_consumer.py
+++ /dev/null
@@ -1,105 +0,0 @@
-import sys
-from pathlib import Path
-
-import orjson
-import redis
-import safetensors.torch
-import torch.distributed
-import yaml
-
-
-def main():
- if len(sys.argv) != 2:
- print("Usage: python -m tests.trainer.events_fake_consumer ")
- sys.exit(1)
-
- config_path = Path(sys.argv[1])
- if not config_path.exists():
- print(f"Config file {config_path} does not exist")
- sys.exit(1)
-
- with config_path.open("rt") as f:
- config = yaml.safe_load(f)
-
- consumer_cfg = config["consumer"]
- world_size = consumer_cfg["world_size"]
- rank = consumer_cfg["rank"]
- results_path = Path(consumer_cfg["results_path"])
- results_path.mkdir(parents=True, exist_ok=True)
-
- consumer_id = f"[Consumer {rank}/{world_size}]"
-
- print(f"{consumer_id} Started with config:")
- print(yaml.safe_dump(config))
-
- assert config["events"]["weights_broadcast"]["enabled"]
- assert config["events"]["training_finished"]["enabled"]
-
- redis_client = redis.Redis(host=config["events"]["redis"]["host"], port=config["events"]["redis"]["port"])
-
- print(f"{consumer_id} waiting for pg rendezvous...")
- weights_pg = torch.distributed.init_process_group(
- backend="nccl",
- init_method=f'tcp://{config["events"]["weights_broadcast"]["rdvz_master_address"]}:'
- f'{config["events"]["weights_broadcast"]["rdvz_master_port"]}',
- world_size=world_size,
- rank=rank,
- )
- broadcast_source_rank = config["events"]["weights_broadcast"]["rank"]
-
- last_id = "0-0"
- msg_key = config["events"]["redis"]["payload_key"].encode()
- stream_key = config["events"]["redis"]["stream_key"]
-
- print(f"{consumer_id} waiting for messages...")
- while True:
- result = redis_client.xread(
- streams={stream_key: last_id},
- count=1,
- block=200,
- )
-
- if not result:
- continue
-
- _, events = result[0]
-
- for event_id, msg in events:
- last_id = event_id
- assert msg_key in msg
- msg = orjson.loads(msg[msg_key].decode())
- print(f"{consumer_id} msg received: {msg}")
- if msg["type"] == config["events"]["weights_broadcast"]["weights_ready_message_type"] or (
- msg["type"] == config["events"]["weights_broadcast"]["initial_weights_step_message_type"]
- and config["events"]["weights_broadcast"]["initial_weights_step_message_includes_weights"]
- ):
- weights = {}
- while True:
- meta = [None]
- torch.distributed.broadcast_object_list(meta, group=weights_pg, group_src=broadcast_source_rank)
- meta = meta[0]
- if meta is None:
- print(f"{consumer_id} weight broadcast finished")
- break
- shard_name, layer_name, tensor_size, tensor_type = meta
- tensor = torch.zeros(
- tuple(tensor_size), dtype=tensor_type, device="cuda"
- ) # so far consumer is single gpu only
- torch.distributed.broadcast(tensor, group=weights_pg, group_src=broadcast_source_rank)
- print(f"{consumer_id} {shard_name} layer {layer_name} {tensor_size} {tensor_type} received")
- if shard_name == "weights":
- weights[layer_name] = tensor
- safetensors.torch.save_file(weights, results_path / f"{msg["step"]}.safetensors")
-
- elif msg["type"] == config["events"]["training_finished"]["training_finished_message_type"]:
- torch.distributed.destroy_process_group()
- (results_path / "training_finished").touch()
- return
- else:
- raise RuntimeError(f"{consumer_id} Received unknown message type {msg}")
- if msg["type"] == config["events"]["weights_broadcast"]["initial_weights_step_message_type"]:
- (results_path / "initial_weights_step").touch()
-
-
-if __name__ == "__main__":
- main()
diff --git a/tests/trainer/test_events.py b/tests/trainer/test_events.py
deleted file mode 100644
index 4894a022f..000000000
--- a/tests/trainer/test_events.py
+++ /dev/null
@@ -1,407 +0,0 @@
-import contextlib
-import copy
-import os
-import pathlib
-import subprocess
-import time
-import typing
-
-import pytest
-import safetensors
-import torch
-import yaml
-
-from tests.utils.model_configs import MODEL_CONFIGS
-from tests.utils.redis import redis_batch_producer
-from tests.utils.utils import requires_cuda
-
-
-@contextlib.contextmanager
-def run_fake_events_consumers(
- model_config: dict,
- test_result_path: pathlib.Path,
- broadcast_world_size: int,
- fake_consumers_broadcast_ranks: list[int],
- assigned_gpus: list[str],
- timeout: float = 30.0, # seconds
-):
- """
- Context manager to run fake event consumer subprocesses for testing.
-
- Each subprocess gets a separate config and CUDA_VISIBLE_DEVICES.
-
- After exiting the context, all subprocesses are ensured to terminate.
- Raises RuntimeError if any subprocess exits with non-zero code.
- """
- import tests.trainer.events_fake_consumer
-
- assert len(assigned_gpus) > 0
- assert len(assigned_gpus) == len(fake_consumers_broadcast_ranks)
-
- processes = []
-
- try:
- for i, gpu in enumerate(assigned_gpus):
- consumer_path = test_result_path / str(i)
- consumer_path.mkdir(parents=True, exist_ok=True)
-
- # Deep copy config and update per consumer
- this_config = copy.deepcopy(model_config)
- this_config["consumer"] = {
- "idx": i,
- "results_path": consumer_path / "results",
- "world_size": broadcast_world_size,
- "rank": fake_consumers_broadcast_ranks[i],
- }
- this_config_path = consumer_path / "config.yaml"
-
- # Save config as YAML
- with open(this_config_path, "w") as f:
- yaml.safe_dump(convert_paths(this_config), f)
-
- # Build subprocess command
- script = [
- "python",
- "-m",
- tests.trainer.events_fake_consumer.__name__,
- str(this_config_path),
- ]
- env = os.environ.copy()
- env["CUDA_VISIBLE_DEVICES"] = str(gpu)
-
- # Start subprocess
- proc = subprocess.Popen(script, env=env)
- processes.append(proc)
-
- # Yield control to the caller while subprocesses run
- yield
-
- finally:
- # Wait for processes to exit or kill after timeout
- start_time = time.time()
- for proc in processes:
- try:
- remaining = max(0, timeout - (time.time() - start_time))
- proc.wait(timeout=remaining)
- except subprocess.TimeoutExpired:
- proc.kill()
-
- # Check exit codes
- errors = [(i, p.returncode) for i, p in enumerate(processes) if p.returncode != 0]
- if errors:
- raise RuntimeError(f"Some fake consumer subprocesses failed: {errors}")
-
-
-def run_fast_llm_training(model_config, run_distributed_script, assigned_gpus):
- import fast_llm.cli
-
- config_path = model_config["run"]["experiment_dir"] / "load_config.yaml"
- config_path.parent.mkdir(parents=True, exist_ok=True)
- with config_path.open("wt") as f:
- yaml.safe_dump(convert_paths(model_config), f)
-
- script = [
- "-m",
- fast_llm.cli.__name__,
- "train",
- "gpt",
- "--config",
- str(config_path),
- ]
-
- env = os.environ.copy()
- env["PYTHONHASHSEED"] = "42"
- env["CUDA_VISIBLE_DEVICES"] = ",".join(str(gpu) for gpu in assigned_gpus)
- run_distributed_script(script, num_gpus=len(assigned_gpus), env=env)
-
-
-def compare_test_tensors_to_checkpoint(test_safetensor_path: str, checkpoint_dir: str):
- """
- Compare a test-saved safetensor file (a dict of tensors)
- to all safetensors in a checkpoint directory.
-
- Checks:
- - tensor names must match
- - shapes must match
- - dtypes must match
- - values must match (exact)
- """
-
- # -------------------------
- # Load test tensor file
- # -------------------------
- test_tensors = {}
- with safetensors.safe_open(test_safetensor_path, framework="pt", device="cpu") as f:
- for key in f.keys():
- test_tensors[key] = f.get_tensor(key)
-
- assert len(test_tensors) > 0, f"No tensors found in {test_safetensor_path}."
-
- # -------------------------
- # Load checkpoint tensors
- # -------------------------
- checkpoint_tensors = {}
-
- for file in os.listdir(checkpoint_dir):
- if file.endswith(".safetensors"):
- path = os.path.join(checkpoint_dir, file)
- with safetensors.safe_open(path, framework="pt", device="cpu") as f:
- for key in f.keys():
- if key in checkpoint_tensors:
- raise AssertionError(
- f"Duplicate tensor name '{key}' across checkpoint {checkpoint_dir} files."
- )
- checkpoint_tensors[key] = f.get_tensor(key)
-
- assert len(checkpoint_tensors) > 0, f"No safetensors found in checkpoint directory: {checkpoint_dir}"
-
- # -------------------------
- # Compare tensor sets
- # -------------------------
- test_names = set(test_tensors.keys())
- ckpt_names = set(checkpoint_tensors.keys())
-
- unexpected_in_test = test_names - ckpt_names
- missing_in_test = ckpt_names - test_names
-
- assert not missing_in_test, "Tensors missing in {test_safetensor_path}:\n" + "\n".join(sorted(missing_in_test))
- assert not unexpected_in_test, "Unexpected tensors in {test_safetensor_path}:\n" + "\n".join(
- sorted(unexpected_in_test)
- )
-
- # -------------------------
- # Compare individual tensors
- # -------------------------
- for name in sorted(test_names):
- t_test = test_tensors[name]
- t_ckpt = checkpoint_tensors[name]
-
- # dtype
- assert t_test.dtype == t_ckpt.dtype, f"Mismatch in dtype for '{name}': " f"{t_test.dtype} != {t_ckpt.dtype}"
-
- # shape
- assert t_test.shape == t_ckpt.shape, (
- f"Mismatch in shape for '{name}': " f"{tuple(t_test.shape)} != {tuple(t_ckpt.shape)}"
- )
-
- # values
- if not torch.equal(t_test, t_ckpt):
- diff = (t_test - t_ckpt).abs()
- max_diff = diff.max().item()
- idx = (diff > 0).nonzero(as_tuple=False)
- example = idx[0].tolist() if idx.numel() > 0 else "unknown"
-
- raise AssertionError(
- f"Tensor content mismatch for '{name}'.\n"
- f"Max difference: {max_diff}\n"
- f"Example differing index: {example}"
- )
-
- # If we reached here → all is good
- return True
-
-
-def check_events_results(
- test_results_path_fast_llm,
- test_results_path_consumers,
- consumer_count,
- training_steps,
- model_checkpoint_format,
-):
- for consumer_idx in range(consumer_count):
- consumer_test_results_path = test_results_path_consumers / str(consumer_idx) / "results"
- assert (consumer_test_results_path / "training_finished").is_file()
- assert (consumer_test_results_path / "initial_weights_step").is_file()
- # NOTE: We do not test the initial weights broadcast result when enabled,
- # because it is identical to subsequent broadcasts.
- for training_step in range(1, training_steps + 1):
- compare_test_tensors_to_checkpoint(
- consumer_test_results_path / f"{training_step}.safetensors",
- test_results_path_fast_llm / "export" / model_checkpoint_format / str(training_step),
- )
-
-
-def convert_paths(obj):
- if isinstance(obj, dict):
- return {k: convert_paths(v) for k, v in obj.items()}
- elif isinstance(obj, list):
- return [convert_paths(v) for v in obj]
- elif isinstance(obj, tuple):
- return tuple(convert_paths(v) for v in obj)
- elif isinstance(obj, pathlib.Path):
- return str(obj)
- else:
- return obj
-
-
-def parallelism_variants(num_gpus: int) -> list[dict[str, int]]:
- if num_gpus == 1:
- return [{"tp": 1, "pp": 1, "sp": 1}]
-
- if num_gpus == 2:
- return [
- # NOTE: Streaming dataset is currently not compatible with pipeline parallelism.
- {"tp": 2, "pp": 1, "sp": 1},
- # {"tp": 1, "pp": 2, "sp": 1},
- {"tp": 1, "pp": 1, "sp": 2},
- ]
-
- if num_gpus == 4:
- return [
- # NOTE: Streaming dataset is currently not compatible with pipeline parallelism.
- {"tp": 4, "pp": 1, "sp": 1},
- # {"tp": 1, "pp": 4, "sp": 1},
- {"tp": 1, "pp": 1, "sp": 4},
- # {"tp": 2, "pp": 2, "sp": 1},
- # {"tp": 1, "pp": 2, "sp": 2},
- {"tp": 2, "pp": 1, "sp": 2},
- ]
-
- raise ValueError(f"Invalid gpu count for fast_llm parallelism {num_gpus}")
-
-
-def consumer_counts(num_gpus: int) -> int:
- if num_gpus == 2:
- return 1
- if num_gpus == 3:
- return 1
- if num_gpus == 4:
- return 2
- if num_gpus == 5:
- return 1
- if num_gpus == 6:
- return 2
- if num_gpus == 7:
- return 3
- if num_gpus >= 8:
- return 4
-
-
-def generate_variants(num_gpus: int) -> list[dict[str, typing.Any]]:
- """
- Generate all (consumer_count, tp/pp/sp) variants for given GPU count.
- """
- results = []
-
- if num_gpus < 2:
- return results
- if num_gpus == 2:
- num_gpus = [2]
- elif num_gpus <= 4:
- num_gpus = [2, num_gpus]
- else:
- num_gpus = [2, 4, min(num_gpus, 8)]
-
- for gpus in num_gpus:
- consumers = consumer_counts(gpus)
- remaining = gpus - consumers
- par_vars = parallelism_variants(remaining)
- for pv in par_vars:
- results.append(
- {
- "total_gpus": gpus,
- "consumers_gpu_count": consumers,
- "fast_llm_gpus_count": remaining,
- "consumers_gpus": list(range(consumers)),
- "fast_llm_gpus": list(range(consumers, gpus)),
- "tensor_parallel": pv["tp"],
- "pipeline_parallel": pv["pp"],
- "sequence_data_parallel": pv["sp"],
- }
- )
-
- return results
-
-
-variants = generate_variants(torch.cuda.device_count())
-
-
-@pytest.mark.slow
-@requires_cuda
-@pytest.mark.parametrize(
- "variant",
- variants,
- ids=[
- f"gpu{v['total_gpus']}_cgpus{v['consumers_gpu_count']}_fgpus{v['fast_llm_gpus_count']}"
- f"_tp{v['tensor_parallel']}_pp{v['pipeline_parallel']}_sp{v['sequence_data_parallel']}"
- for v in variants
- ],
-)
-def test_trainer_events_with_streaming(fake_redis_server, variant, run_distributed_script_lean, result_path, request):
- stream_config, fake_redis_client, fake_redis_server_killer = fake_redis_server
- test_result_path = result_path / request.node.name
- test_result_path_fast_llm = test_result_path / "fast_llm"
- test_result_path_consumers = test_result_path / "consumers"
-
- broadcast_world_size = variant["consumers_gpu_count"] + 1
- fake_consumers_broadcast_ranks = list(range(variant["consumers_gpu_count"]))
- fake_consumers_assigned_gpus = variant["consumers_gpus"]
- fast_llm_broadcast_rank = variant["consumers_gpu_count"]
- fast_llm_assigned_gpus = variant["fast_llm_gpus"]
- train_iters = 2
-
- model_config = copy.deepcopy(MODEL_CONFIGS["mistral"].config_dict)
- model_config["data"]["datasets"] = {"training": stream_config.to_dict()}
- model_config["data"]["sampling"] = {"shuffle": "disabled"}
- model_config["training"]["train_iters"] = train_iters
- model_config["training"]["export"] = {"interval": 1, "format": MODEL_CONFIGS["mistral"].checkpoint_format.name}
- model_config["batch"]["micro_batch_size"] = 1
- model_config["batch"]["truncate_documents"] = False
- model_config["run"]["experiment_dir"] = test_result_path_fast_llm
- model_config["model"]["distributed"]["tensor_parallel"] = variant["tensor_parallel"]
- model_config["model"]["distributed"]["pipeline_parallel"] = variant["pipeline_parallel"]
- model_config["model"]["distributed"]["sequence_data_parallel"] = variant["sequence_data_parallel"]
-
- # We use same stream for messages in the test. Also make all fields explicit,
- # so fake consumers can read them as well from this dict config
- model_config["events"] = {
- "redis": {
- "host": stream_config.redis.host,
- "port": stream_config.redis.port,
- "stream_key": "fast_llm_events",
- "payload_key": "event",
- },
- "weights_broadcast": {
- "enabled": True,
- "initial_weights_step_message_type": "initial_weights_step",
- "initial_weights_step_message_includes_weights": True,
- "weights_ready_message_type": "weights_ready",
- "rdvz_master_address": "127.0.0.1",
- "rdvz_master_port": 19999,
- "world_size": broadcast_world_size,
- "rank": fast_llm_broadcast_rank,
- },
- "training_finished": {
- "enabled": True,
- "training_finished_message_type": "training_finished",
- },
- }
-
- batch_size = model_config["batch"]["batch_size"]
- sequence_length = model_config["batch"]["sequence_length"]
- with redis_batch_producer(
- redis_client=fake_redis_client,
- fake_redis_server_killer=fake_redis_server_killer,
- stream_config=stream_config,
- batch_size=batch_size,
- sequence_length=sequence_length,
- ):
- with run_fake_events_consumers(
- model_config=model_config,
- test_result_path=test_result_path_consumers,
- broadcast_world_size=broadcast_world_size,
- fake_consumers_broadcast_ranks=fake_consumers_broadcast_ranks,
- assigned_gpus=fake_consumers_assigned_gpus,
- ):
- run_fast_llm_training(
- model_config=model_config,
- run_distributed_script=run_distributed_script_lean,
- assigned_gpus=fast_llm_assigned_gpus,
- )
- check_events_results(
- test_results_path_fast_llm=test_result_path_fast_llm,
- test_results_path_consumers=test_result_path_consumers,
- consumer_count=len(fake_consumers_assigned_gpus),
- training_steps=train_iters,
- model_checkpoint_format=MODEL_CONFIGS["mistral"].checkpoint_format.name,
- )
diff --git a/tests/utils/distributed_configs.py b/tests/utils/distributed_configs.py
index 83ed6836a..9c1cc9369 100644
--- a/tests/utils/distributed_configs.py
+++ b/tests/utils/distributed_configs.py
@@ -222,6 +222,17 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon
num_gpus=2,
compare_config=_compare_layer_match,
),
+ # Depth-first micro-batches, tensor-parallel
+ DistributedTestingConfig(
+ name="tp2_df4",
+ compare="df4",
+ config_args=[
+ "model.distributed.tensor_parallel=2",
+ "batch.depth_first_micro_batches=4",
+ ],
+ num_gpus=2,
+ compare_config=_compare_layer_match,
+ ),
# Cross-entropy splits
DistributedTestingConfig(
name="stp2_ce4",
@@ -247,17 +258,6 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon
num_gpus=4,
compare_config=_compare_layer_match,
),
- # Depth-first micro-batches, tensor-parallel
- DistributedTestingConfig(
- name="tp2_df4",
- compare="df4",
- config_args=[
- "model.distributed.tensor_parallel=2",
- "batch.depth_first_micro_batches=4",
- ],
- num_gpus=4,
- compare_config=_compare_layer_match,
- ),
# Breadth-first micro-batches
DistributedTestingConfig(
name="sdp2_stp2_bf4",
diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py
index 6156cb709..2834b0728 100644
--- a/tests/utils/model_configs.py
+++ b/tests/utils/model_configs.py
@@ -12,6 +12,7 @@
from fast_llm.config import set_nested_dict_value
from fast_llm.engine.checkpoint.config import CheckpointFormat
+from fast_llm.engine.distributed.config import DistributedBackend
from fast_llm.engine.multi_stage.config import FastLLMModelConfig
from fast_llm.engine.training.config import TrainerConfig
from fast_llm.models.gpt.conversion.config import (
@@ -51,6 +52,7 @@ class ModelTestingGroup(enum.StrEnum):
generate = "generate"
megatron = "megatron"
distributed = "distributed"
+ streaming = "streaming"
class ModelTestingGroupAction(enum.StrEnum):
@@ -147,13 +149,17 @@ def model_class(self):
def base_model_config_class(self):
return self.model_config_class.get_base_model_config_class()
+ @functools.cached_property
+ def distributed_backend(self):
+ return DistributedBackend(self.config_dict["model"]["distributed"]["backend"])
+
def should_skip(self, distributed_config: DistributedTestingConfig) -> bool:
return any(re.search(pattern, distributed_config.name) for pattern in self.skip_tests)
-def _update_and_add_testing_config(
- old_name: str,
- new_name: str,
+def update_and_add_testing_config(
+ old_name: str | ModelTestingConfig,
+ new_name: str | None,
*,
model_type: str | None = None,
updates: dict[str | tuple[str, ...], typing.Any] | None = None,
@@ -162,7 +168,7 @@ def _update_and_add_testing_config(
**kwargs,
) -> ModelTestingConfig:
- config = MODEL_CONFIGS[old_name]
+ config = old_name if isinstance(old_name, ModelTestingConfig) else MODEL_CONFIGS[old_name]
config_dict = copy.deepcopy(config.config_dict)
if updates is not None:
for keys, update in updates.items():
@@ -174,14 +180,15 @@ def _update_and_add_testing_config(
megatron_args = config.megatron_args + megatron_args
new_config = dataclasses.replace(
config,
- name=new_name,
+ name=config.name if new_name is None else new_name,
model_type=config.model_type if model_type is None else model_type,
groups=groups,
config_dict=config_dict,
megatron_args=megatron_args,
**kwargs,
)
- MODEL_CONFIGS[new_name] = new_config
+ if new_name is not None:
+ MODEL_CONFIGS[new_name] = new_config
return new_config
@@ -254,6 +261,7 @@ def _update_and_add_testing_config(
"distributed": {
"reproducible_init": True,
"timeout": 20,
+ "backend": "nccl",
},
},
"batch": {"batch_size": 8, "sequence_length": 512},
@@ -304,7 +312,7 @@ def _update_and_add_testing_config(
},
)
-_update_and_add_testing_config(
+update_and_add_testing_config(
# Tests MQA.
"gpt_2",
"starcoder",
@@ -323,7 +331,7 @@ def _update_and_add_testing_config(
},
)
-_update_and_add_testing_config(
+update_and_add_testing_config(
# Tests intermediate between gpt2 and llama, closest converter to gpt2.
"gpt_2",
"starcoder_2",
@@ -352,7 +360,7 @@ def _update_and_add_testing_config(
del MODEL_CONFIGS["starcoder_2"].config_dict["model"]["base_model"]["embeddings"]["num_position_embeddings"]
-_update_and_add_testing_config(
+update_and_add_testing_config(
# Main tested model.
"starcoder_2",
"llama",
@@ -381,10 +389,11 @@ def _update_and_add_testing_config(
ModelTestingGroup.generate: ModelTestingGroupAction.broken,
ModelTestingGroup.megatron: ModelTestingGroupAction.normal,
ModelTestingGroup.distributed: ModelTestingGroupAction.normal,
+ ModelTestingGroup.streaming: ModelTestingGroupAction.normal,
},
)
-_update_and_add_testing_config(
+update_and_add_testing_config(
# Tests llama3-style rotary embeddings.
"llama",
"llama_3",
@@ -404,7 +413,7 @@ def _update_and_add_testing_config(
},
)
-_update_and_add_testing_config(
+update_and_add_testing_config(
# Tests yarn-style rotary embeddings.
"llama",
"llama_yarn",
@@ -424,7 +433,7 @@ def _update_and_add_testing_config(
},
)
-_update_and_add_testing_config(
+update_and_add_testing_config(
# Tests diffusion llama converter.
"llama_yarn",
"diffusion_llama",
@@ -449,7 +458,7 @@ def _update_and_add_testing_config(
_llama_block = MODEL_CONFIGS["llama"].config_dict["model"]["base_model"]["decoder"]["block"]
-_update_and_add_testing_config(
+update_and_add_testing_config(
# Tests multi-token prediction, custom HF model and converter.
"llama",
"mtp_llama",
@@ -479,7 +488,7 @@ def _update_and_add_testing_config(
skip_tests=(r"ce4", r"ms"),
)
-_update_and_add_testing_config(
+update_and_add_testing_config(
# Tests partial linear biases, Qwen2 converter.
"llama",
"qwen_2",
@@ -501,7 +510,7 @@ def _update_and_add_testing_config(
},
)
-_update_and_add_testing_config(
+update_and_add_testing_config(
# Tests diffusion dream converter.
"qwen_2",
"dream",
@@ -523,7 +532,7 @@ def _update_and_add_testing_config(
auto_model_class=transformers.AutoModel,
)
-_update_and_add_testing_config(
+update_and_add_testing_config(
# Tests sliding window attention, mistral converter.
"llama",
"mistral",
@@ -546,7 +555,7 @@ def _update_and_add_testing_config(
_mistral_base_model = MODEL_CONFIGS["mistral"].config_dict["model"]["base_model"]
-_update_and_add_testing_config(
+update_and_add_testing_config(
# Tests logit distillation.
"mistral",
"mistral_distill_logits",
@@ -574,7 +583,7 @@ def _update_and_add_testing_config(
skip_tests=("ms", "pp2s1_bf4", "pp2s2_bf4", "sdp2"),
)
-_update_and_add_testing_config(
+update_and_add_testing_config(
"mistral_distill_logits",
"mistral_reverse_kl",
updates={
@@ -595,7 +604,7 @@ def _update_and_add_testing_config(
skip_tests=("sdp", "ms", "pp"),
)
-_update_and_add_testing_config(
+update_and_add_testing_config(
"mistral_distill_logits",
"mistral_distill_activations",
updates={
@@ -626,7 +635,7 @@ def _update_and_add_testing_config(
skip_tests=("sdp", "ms", "pp", "tp", GRAD_ACC, "fp16"),
)
-_update_and_add_testing_config(
+update_and_add_testing_config(
# Tests mixture of experts, mixtral converter.
"llama",
"mixtral",
@@ -654,7 +663,7 @@ def _update_and_add_testing_config(
)
-_update_and_add_testing_config(
+update_and_add_testing_config(
# Tests hybrid Mamba 2.
"llama",
"hybrid_mamba",
@@ -695,7 +704,7 @@ def _update_and_add_testing_config(
skip_tests=("sdp", "ms"),
)
-_update_and_add_testing_config(
+update_and_add_testing_config(
# Tests vision multimodal.
"llama",
"llava",
@@ -738,7 +747,7 @@ def _update_and_add_testing_config(
)
-_update_and_add_testing_config(
+update_and_add_testing_config(
# Tests hybrid with attention + gated delta net mixer.
"llama",
"apriel2_text_gdn_hybrid",
@@ -789,7 +798,7 @@ def _update_and_add_testing_config(
skip_tests=("sdp", "ms", TP_NO_STP),
)
-_update_and_add_testing_config(
+update_and_add_testing_config(
# Tests apriel2 format with pattern decoder mixing all mixer types.
# This comprehensive test exercises: attention, mamba, stochastic mixer, sliding window attention, gdn.
"llama",
@@ -912,7 +921,7 @@ def _update_and_add_testing_config(
)
-_update_and_add_testing_config(
+update_and_add_testing_config(
# Tests apriel2 multimodal format combining pattern decoder with vision encoder.
# Uses the same decoder as apriel2_text_all_hybrid but adds vision capabilities.
"apriel2_text_all_hybrid",
@@ -955,7 +964,7 @@ def _update_and_add_testing_config(
)
-_update_and_add_testing_config(
+update_and_add_testing_config(
# Tests hybrid with KDA mixer.
"llama",
"hybrid_kda",
@@ -1009,7 +1018,7 @@ def testing_group_enabled(item: pytest.Function, skip_slow: bool, skip_extra_slo
model_testing_config = item.callspec.params["model_testing_config"]
model_config: ModelTestingConfig = MODEL_CONFIGS[model_testing_config]
for group in groups:
- action = model_config.groups[group]
+ action = model_config.groups.get(group, ModelTestingGroupAction.unimportant)
if action == ModelTestingGroupAction.main:
pass
elif action == ModelTestingGroupAction.normal and not skip_slow:
diff --git a/tests/utils/redis.py b/tests/utils/redis.py
index 7e9d3b689..591ee74e6 100644
--- a/tests/utils/redis.py
+++ b/tests/utils/redis.py
@@ -1,36 +1,25 @@
import contextlib
+import itertools
+import json
import pathlib
import socket
import threading
import time
import fakeredis
-import orjson
-import pytest
from fast_llm.data.dataset.config import (
- IngestionType,
+ REDIS_DATA_STREAM,
+ REDIS_FIELD,
+ REDIS_GROUP_NAME,
+ RedisConfig,
SamplingConfig,
SamplingData,
SamplingParameters,
- ShufflingType,
StreamingDatasetConfig,
- StreamingDatasetRedisConfig,
)
from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig
-
-
-def get_stream_config():
- return StreamingDatasetConfig(
- redis=StreamingDatasetRedisConfig(
- host="localhost",
- port=6379,
- stream_key="test_stream",
- payload_key="data",
- ),
- group_name="test_group",
- consumer_name_prefix="consumer",
- )
+from fast_llm.models.gpt.config import GPTBatchConfig
def find_free_port():
@@ -40,32 +29,9 @@ def find_free_port():
return s.getsockname()[1]
-def push_msg(redis_client, config, tokens=None, stream_key_suffix=None):
+def push_msg(redis_client, tokens):
"""Push a message into FakeRedis stream."""
- msg = {
- "tokens": tokens,
- "tokens_dtype": "int64",
- }
- stream_key = config.redis.stream_key
- if stream_key_suffix is not None:
- stream_key += stream_key_suffix
- redis_client.xadd(stream_key, {config.redis.payload_key: orjson.dumps(msg)})
-
-
-class FakeRedisServerKiller:
- def __init__(self, server):
- self._server = server
- self._is_killed = False
- self._lock = threading.Lock()
-
- def kill(self):
- with self._lock:
- if not self._is_killed:
- try:
- self._server.shutdown()
- self._server.server_close()
- finally:
- self._is_killed = True
+ redis_client.xadd(REDIS_DATA_STREAM, {REDIS_FIELD: json.dumps({"tokens": tokens, "tokens_dtype": "int64"})})
def wait_until_stream_empty(
@@ -73,20 +39,7 @@ def wait_until_stream_empty(
stream_key,
consumer_group,
stop_event,
- consumer_count,
- ingestion_type: IngestionType,
):
- if ingestion_type == IngestionType.CONSUMER_GROUP:
- return wait_until_stream_empty_consumer_group(redis_client, stream_key, consumer_group, stop_event)
- elif ingestion_type == IngestionType.ONE_STREAM:
- raise NotImplementedError()
- elif ingestion_type == IngestionType.N_STREAMS:
- raise NotImplementedError()
- else:
- raise ValueError(f"Unknown ingestion type {ingestion_type.value}")
-
-
-def wait_until_stream_empty_consumer_group(redis_client, stream_key, consumer_group, stop_event):
"""
Wait until lag == 0, meaning all messages have been delivered AND acknowledged.
Absence of group mean test has not started yet, so we wait
@@ -106,7 +59,7 @@ def wait_until_stream_empty_consumer_group(redis_client, stream_key, consumer_gr
def get_consumer_count(redis_client, stop_event, config: StreamingDatasetConfig):
while not stop_event.is_set():
- res = redis_client.hget(f"{config.redis.stream_key}:consumer_count", "0")
+ res = redis_client.hget(f"{REDIS_DATA_STREAM}:consumer_count", "0")
if res is None:
time.sleep(0.05)
continue
@@ -114,67 +67,39 @@ def get_consumer_count(redis_client, stop_event, config: StreamingDatasetConfig)
@contextlib.contextmanager
-def redis_batch_producer(
- redis_client, fake_redis_server_killer, stream_config, batch_size, sequence_length, num_batches=None
-):
- stop_event = threading.Event()
- thread_exc = []
-
- def producer_loop():
- is_n_streams = stream_config.ingestion_type == IngestionType.N_STREAMS
- try:
- consumer_count = get_consumer_count(redis_client, stop_event, stream_config)
- stream = stream_config.redis.stream_key
- group = stream_config.group_name
- batch_idx = 0
- while not stop_event.is_set():
- if num_batches is not None and batch_idx >= num_batches:
+def redis_batch_producer(config: RedisConfig, batch_config: GPTBatchConfig):
+ with fake_redis_server(config):
+ stop_event = threading.Event()
+ client = config.get_client()
+
+ def producer_loop():
+ for sample_index in itertools.count():
+ if stop_event.is_set():
break
- for i in range(batch_size):
- if stop_event.is_set():
- return
- push_msg(
- redis_client,
- stream_config,
- [batch_idx * batch_size + i] * sequence_length,
- stream_key_suffix=f"_{i % consumer_count}" if is_n_streams else None,
- )
- wait_until_stream_empty(
- redis_client,
- stream,
- group,
- stop_event,
- consumer_count=consumer_count,
- ingestion_type=stream_config.ingestion_type,
- )
- batch_idx += 1
- except Exception as e:
- # if failed to push messages kill redis server so waiting side in the test would unlock
- fake_redis_server_killer.kill()
- thread_exc.append(e)
- raise
+ push_msg(client, [sample_index] * batch_config.sequence_length)
+ if sample_index % 5 == 0:
+ wait_until_stream_empty(client, REDIS_DATA_STREAM, REDIS_GROUP_NAME, stop_event)
- thread = threading.Thread(target=producer_loop, daemon=True)
- thread.start()
+ thread = threading.Thread(target=producer_loop, daemon=True)
+ thread.start()
- try:
- yield
- finally:
- stop_event.set()
- thread.join(timeout=10)
- if thread_exc:
- raise thread_exc[0]
+ try:
+ yield
+ finally:
+ stop_event.set()
+ thread.join(timeout=1)
+ client.close()
-def make_sampling(sequence_length, extra_tokens, num_samples, distributed):
+def make_sampling(sequence_length, num_samples, distributed):
return SamplingData(
parameters=SamplingParameters(
sequence_length=sequence_length,
- extra_tokens=extra_tokens,
+ extra_tokens=0,
num_samples=num_samples,
truncate_documents=False,
),
- config=SamplingConfig(shuffle=ShufflingType.disabled),
+ config=SamplingConfig(),
distributed=distributed,
dataset_name="test",
cache_directory=pathlib.Path("/tmp"),
@@ -182,17 +107,9 @@ def make_sampling(sequence_length, extra_tokens, num_samples, distributed):
)
-@pytest.fixture
-def stream_config():
- return get_stream_config()
-
-
-@pytest.fixture
-def fake_redis_server(stream_config):
+@contextlib.contextmanager
+def fake_redis_server(config: RedisConfig):
# We search for free port as port from previous test can still be not free even after server shutdown
- stream_config = stream_config.from_dict(stream_config.to_dict(), {("redis", "port"): find_free_port()})
-
- server_address = (stream_config.redis.host, stream_config.redis.port)
# ----- Monkey-patch handler to suppress broken pipes -----
orig_handle = fakeredis._tcp_server.TCPFakeRequestHandler.handle
@@ -209,27 +126,14 @@ def safe_handle(self):
fakeredis._tcp_server.TCPFakeRequestHandler.handle = safe_handle
- server = fakeredis.TcpFakeServer(server_address, server_type="redis")
- server_killer = FakeRedisServerKiller(server)
-
- # ----- Start server thread -----
- def serve():
- try:
- server.serve_forever()
- except Exception:
- # Extra safety: catch anything from serve_forever
- pass
-
- thread = threading.Thread(target=serve, daemon=True)
+ server = fakeredis.TcpFakeServer((config.host, config.port), server_type="redis")
+ thread = threading.Thread(target=server.serve_forever, daemon=True)
thread.start()
- # ----- reate a redis-py client pointing at the fake serve -----
- import redis
-
- client = redis.Redis(host=server_address[0], port=server_address[1])
-
- yield stream_config, client, server_killer
-
- # ----- Teardown -----
- server_killer.kill()
- thread.join()
+ try:
+ yield
+ finally:
+ # ----- Teardown -----
+ server.shutdown()
+ server.server_close()
+ thread.join()
diff --git a/tests/utils/run_test_script.py b/tests/utils/run_test_script.py
index e880e67ef..0b8232cf7 100644
--- a/tests/utils/run_test_script.py
+++ b/tests/utils/run_test_script.py
@@ -1,4 +1,3 @@
-import argparse
import functools
import os
import pathlib
@@ -12,7 +11,7 @@
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.utils import Assert
from tests.utils.distributed_configs import DistributedTestingConfig
-from tests.utils.model_configs import MODEL_CONFIGS, ModelTestingConfig
+from tests.utils.model_configs import ModelTestingConfig
if typing.TYPE_CHECKING:
from tests.conftest import WorkerResources
@@ -47,22 +46,7 @@ def do_run_distributed_script(
@pytest.fixture(scope="session")
-def run_distributed_script(
- worker_resources: "WorkerResources",
- run_test_script_base_path: pathlib.Path,
- model_testing_config: ModelTestingConfig,
-):
- return functools.partial(
- do_run_distributed_script,
- rendezvous_port=worker_resources.rendezvous_port,
- torchrun_port=worker_resources.torchrun_port,
- )
-
-
-@pytest.fixture(scope="session")
-def run_distributed_script_lean(
- worker_resources: "WorkerResources",
-):
+def run_distributed_script(worker_resources: "WorkerResources"):
return functools.partial(
do_run_distributed_script,
rendezvous_port=worker_resources.rendezvous_port,
@@ -155,16 +139,6 @@ def run_test_script_for_all_models(
)
-def parse_run_distributed_script(args: list[str] | None = None):
- parser = argparse.ArgumentParser()
- parser.add_argument("base_path", type=pathlib.Path)
- parser.add_argument("model_testing_config", type=str)
- parser.add_argument("--no-distributed-capture", dest="distributed_capture", action="store_false")
-
- parsed = parser.parse_args(args)
- return parsed.base_path, MODEL_CONFIGS[parsed.model_testing_config], parsed.distributed_capture
-
-
@pytest.fixture(scope="session")
def compare_results_for_all_models(
worker_resources: "WorkerResources",
diff --git a/tests/utils/subtest.py b/tests/utils/subtest.py
new file mode 100644
index 000000000..4fea1fbba
--- /dev/null
+++ b/tests/utils/subtest.py
@@ -0,0 +1,273 @@
+import functools
+import json
+import logging
+import math
+import pathlib
+import sys
+import time
+import traceback
+import typing
+
+import pytest
+import torch
+
+from fast_llm.core.distributed import allreduce_scalar, safe_barrier
+from fast_llm.engine.config_utils.logging import configure_logging
+from fast_llm.engine.distributed.config import DistributedBackend, DistributedConfig
+from fast_llm.engine.distributed.distributed import ProcessGroupPool
+from fast_llm.utils import Assert, get_and_reset_memory_usage_mib, header
+
+logger = logging.getLogger(__name__)
+
+
+class DistributedTestContext:
+ def __init__(
+ self,
+ do_capture: bool,
+ timeout: float = 20.0,
+ init_method: str = "env://",
+ backend: DistributedBackend = DistributedBackend.nccl,
+ ) -> None:
+ self._do_capture = do_capture
+ self._timeout = timeout
+ self._init_method = init_method
+ self._backend = backend
+
+ def __enter__(self):
+ if self._do_capture:
+ logger.warning(
+ "Capturing output and forwarding to associated tests. Run with `--no-distributed-capture` to disable."
+ )
+
+ self._pool = ProcessGroupPool(
+ timeout=self._timeout, init_method=self._init_method, backend=self._backend
+ ).__enter__()
+ self._rank = self._pool.rank
+ self._world_size = self._pool.world_size
+ self._failures = []
+ self._configure_logging()
+ self._group = self._pool.get_process_group(range(self._world_size), self._rank)
+ # TODO: Barriers needed?
+ safe_barrier(self._group, "start")
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ # Final barrier to ensure everything is done before torchrun potentially kills workers.
+ safe_barrier(self._group, "testing end")
+ # Let pytest know how things went.
+ # These should already be reported above, we repeat for convenience.
+ if self._failures:
+ raise RuntimeError(f"The following subtests failed: {", ".join(self._failures)}")
+ else:
+ logger.warning("All tests passed")
+
+ def subtest(self, base_path: pathlib.Path, name: str, num_gpus: int):
+ return self.DistributedSubtestContext(self, base_path, name, num_gpus)
+
+ def _configure_logging(self):
+ configure_logging(rank=self._rank, world_size=self._world_size)
+
+ @property
+ def rank(self) -> int:
+ return self._rank
+
+ @property
+ def world_size(self) -> int:
+ return self._world_size
+
+ class DistributedSubtestContext:
+ def __init__(
+ self, test_context: "DistributedTestContext", base_path: pathlib.Path, name: str, num_gpus: int
+ ) -> None:
+ self._test_context = test_context
+ self._path = base_path / name
+ self._name = name
+ self._num_gpus = num_gpus
+ self._skip = self._test_context._world_size < self._num_gpus
+ self._do_run = self._test_context._rank < num_gpus and not self._skip
+ self._do_capture = self._test_context._do_capture and self._do_run
+ self._success = False
+
+ def __enter__(self) -> typing.Self:
+ if self._do_capture:
+ self._sys_stdout = sys.stdout
+ self._sys_stderr = sys.stderr
+ self._path.mkdir(parents=True, exist_ok=True)
+ sys.stdout = self._path.joinpath(f"pytest_stdout_{self._test_context._rank}").open("w")
+ sys.stderr = self._path.joinpath(f"pytest_stderr_{self._test_context._rank}").open("w")
+ self._test_context._configure_logging()
+ # Logging is set to log to the old stdout, so we need to reconfigure.
+ self._start = time.perf_counter()
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ if self._skip:
+ # Skipped tests should exit right away.
+ Assert.none(exc_val)
+ logger.warning(
+ f"{self._name} {f"SKIPPED (not enough GPUs: {self._test_context._world_size} < {self._num_gpus})"})"
+ )
+ return
+
+ if self._do_capture:
+ try:
+ stdout_handle = sys.stdout
+ stderr_handle = sys.stderr
+ sys.stdout = self._sys_stdout
+ sys.stderr = self._sys_stderr
+ stdout_handle.close()
+ stderr_handle.close()
+ finally:
+ assert DistributedConfig.default_world_size > 1
+ self._test_context._configure_logging()
+
+ if exc_type is None:
+ self._success = True
+ else:
+ self._path.mkdir(parents=True, exist_ok=True)
+ self._path.joinpath(f"pytest_traceback_{self._test_context._rank}").write_text(traceback.format_exc())
+
+ logger.warning(f"{self._name} done, waiting for other ranks ({"PASSED" if self._success else "FAILED"})")
+
+ if (group := self._test_context._group) is not None:
+ # Barrier so `allreduce_scalar` doesn't go crazy in case of desync.
+ safe_barrier(group, self._name)
+ self._success = allreduce_scalar(self._success, dtype=torch.int64, group=group) == group.size()
+
+ if self._do_capture:
+ # Free resources to limit memory usage.
+ report = get_and_reset_memory_usage_mib(clear_cache=True, global_stats=True, reset_global_stats=True)
+ report["duration"] = time.perf_counter() - self._start
+
+ json.dump(report, self._path.joinpath(f"pytest_report_{self._test_context._rank}").open("w"))
+
+ if self._test_context._rank == 0:
+ set_subtest_success(self._path, self._success)
+ logger.warning(f"{self._name} {"PASSED" if self._success else "FAILED"}")
+ if not self._success:
+ self._test_context._failures.append(self._name)
+
+ return True
+
+ @property
+ def do_run(self) -> bool:
+ return self._do_run and not self._skip
+
+
+def set_subtest_success(path: pathlib.Path, success: bool = True):
+ path.joinpath("pytest_success").write_text(str(int(success)))
+
+
+def check_subtest_success(path: pathlib, fail: bool = True) -> bool:
+ if not path.is_dir():
+ if fail:
+ pytest.fail(f"Test {path.name} did not run", pytrace=False)
+ else:
+ return False
+ try:
+ return bool(int(path.joinpath("pytest_success").read_text()))
+ except OSError:
+ return False
+
+
+def format_resource_report(title: str, report: dict[str, float]) -> str:
+ return "".join(
+ [
+ f"{title}:\n ",
+ f"Max Reserved: {report.get("max_reserved", math.nan):.0f} MiB",
+ f"| Max Allocated: {report.get("max_allocated", math.nan):.0f} MiB".ljust(26),
+ f"| End Reserved: {report.get("reserved", math.nan):.0f} MiB".ljust(25),
+ f"| End Allocated: {report.get("allocated", math.nan):.0f} MiB".ljust(26),
+ f"| Duration: {report.get("duration", math.nan):.2f}".ljust(18),
+ f"| GPUs: {report["gpus"]:.0f}" if "gpus" in report else "",
+ ]
+ )
+
+
+@pytest.fixture(scope="function")
+def report_subtest(request: pytest.FixtureRequest):
+ verbose = request.config.getoption("verbose")
+ do_capture = request.config.getoption("distributed_capture")
+
+ def do_report_subtest(path: pathlib.Path, world_size: int) -> None:
+ success = check_subtest_success(path)
+ if not do_capture:
+ logger.warning("Distributed capture is disabled. See distributed test for run output.")
+ elif verbose > 1 or not success:
+ for rank in range(world_size):
+ for fd, file_ in (("stdout", sys.stdout), ("stderr", sys.stdout), ("traceback", sys.stderr)):
+ print(header(f"{fd} rank {rank}", 80), file=file_)
+ file_path = path / f"pytest_{fd}_{rank}"
+ try:
+ print(file_path.read_text(), file=file_)
+ except OSError:
+ print(f"<<< not found {file_path}>>>", file=file_)
+ else:
+ print("Set verbose > 1 to show run output.")
+
+ reports = {}
+ for rank in range(world_size):
+ try:
+ reports[f"rank_{rank}"] = json.load(path.joinpath(f"pytest_report_{rank}").open("r"))
+ except OSError:
+ reports[rank] = {}
+ keys = {key for report in reports.values() for key in report}
+ report = {key: max(report[key] for report in reports.values() if key in report) for key in keys}
+ report["gpus"] = world_size
+ reports["global"] = report
+
+ print(header(f"Resource usage", 80), file=sys.stderr)
+ for name, report in reports.items():
+ print(format_resource_report(name, report), file=sys.stderr)
+ setattr(request.node, "fast_llm_resource_report", report)
+
+ if not success:
+ raise RuntimeError(f"test {path.name} failed")
+
+ return do_report_subtest
+
+
+def parallel_worker(
+ rank: int,
+ world_size: int,
+ init_method: str,
+ backend: DistributedBackend,
+ do_capture: bool,
+ fn: typing.Callable,
+ fn_args: typing.Sequence[typing.Any],
+):
+ DistributedConfig.default_rank = rank
+ DistributedConfig.default_world_size = world_size
+ DistributedConfig.default_local_world_size = world_size
+ with DistributedTestContext(do_capture, 60, init_method, backend) as test_context:
+ fn(test_context, *fn_args)
+
+
+def do_run_parallel_script(
+ fn: typing.Callable,
+ fn_args: typing.Sequence[typing.Any],
+ port: int,
+ do_capture: bool,
+ world_size: int,
+ timeout: float = 240,
+ backend: DistributedBackend = DistributedBackend.nccl,
+):
+ if do_capture:
+ logger.warning(
+ "Capturing output and forwarding to associated tests. Run with `--no-distributed-capture` to disable."
+ )
+ torch.multiprocessing.spawn(
+ parallel_worker,
+ args=(world_size, f"tcp://localhost:{port}", backend, do_capture, fn, fn_args),
+ nprocs=world_size,
+ join=False,
+ ).join(timeout, grace_period=5)
+
+
+@pytest.fixture(scope="session")
+def run_parallel_script(worker_resources: "WorkerResources", request: pytest.FixtureRequest):
+ return functools.partial(
+ do_run_parallel_script,
+ port=worker_resources.rendezvous_port,
+ do_capture=request.config.getoption("distributed_capture"),
+ )
diff --git a/tests/utils/utils.py b/tests/utils/utils.py
index 3b79f7607..f0ca20db8 100644
--- a/tests/utils/utils.py
+++ b/tests/utils/utils.py
@@ -1,23 +1,14 @@
-import json
import logging
-import math
-import pathlib
-import sys
-import time
-import traceback
import typing
import pytest
import torch
-from fast_llm.core.distributed import ProcessGroup, allreduce_scalar, safe_barrier
from fast_llm.engine.base_model.base_model import Layer
from fast_llm.engine.base_model.config import set_model_names
-from fast_llm.engine.config_utils.logging import configure_logging
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageConfig
from fast_llm.engine.multi_stage.stage import Stage
-from fast_llm.utils import get_and_reset_memory_usage_mib, header
from tests.utils.global_variables import TEST_RESULTS_PATH
logger = logging.getLogger(__name__)
@@ -65,137 +56,3 @@ def get_stage(
stage.restore_parameters()
stage.reset_gradients()
return stage
-
-
-class DistributedSubtestContext:
- def __init__(
- self, base_path: pathlib.Path, name: str, group: ProcessGroup | None, num_gpus: int, enabled: bool = True
- ) -> None:
- self._path = base_path / name
- self._name = name
- self._group = group
- self._rank = 0 if group is None else group.rank()
- self._rank_enabled = self._rank < num_gpus
- self._enabled = enabled and self._rank_enabled
- self.success = False
-
- def __enter__(self) -> typing.Self:
- if self._enabled:
- self._sys_stdout = sys.stdout
- self._sys_stderr = sys.stderr
- self._path.mkdir(parents=True, exist_ok=True)
- sys.stdout = self._path.joinpath(f"pytest_stdout_{self._rank}").open("w")
- sys.stderr = self._path.joinpath(f"pytest_stderr_{self._rank}").open("w")
- # Logging is set to log to the old stdout, so we need to reconfigure.
- configure_logging()
- self._start = time.perf_counter()
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- if self._enabled:
- try:
- stdout_handle = sys.stdout
- stderr_handle = sys.stderr
- sys.stdout = self._sys_stdout
- sys.stderr = self._sys_stderr
- stdout_handle.close()
- stderr_handle.close()
- finally:
- configure_logging()
-
- if exc_type is None:
- self.success = True
- else:
- self._path.mkdir(parents=True, exist_ok=True)
- self._path.joinpath(f"pytest_traceback_{self._rank}").write_text(traceback.format_exc())
-
- if self._group is not None:
- # Barrier so `allreduce_scalar` doesn't go crazy in case of desync.
- safe_barrier(self._group, self._name)
- self.success = allreduce_scalar(self.success, dtype=torch.int64, group=self._group) == self._group.size()
-
- if self._rank_enabled:
- # Free resources to limit memory usage.
- report = get_and_reset_memory_usage_mib(clear_cache=True, global_stats=True, reset_global_stats=True)
- report["duration"] = time.perf_counter() - self._start
-
- json.dump(report, self._path.joinpath(f"pytest_report_{self._rank}").open("w"))
-
- logger.warning(f"{self._name} {"PASSED" if self.success else "FAILED"})")
- if self._rank == 0:
- set_subtest_success(self._path, self.success)
-
- return True
-
-
-def set_subtest_success(path: pathlib.Path, success: bool = True):
- path.joinpath("pytest_success").write_text(str(int(success)))
-
-
-def check_subtest_success(path: pathlib, fail: bool = True) -> bool:
- if not path.is_dir():
- if fail:
- pytest.fail(f"Test {path.name} did not run", pytrace=False)
- else:
- return False
- try:
- return bool(int(path.joinpath("pytest_success").read_text()))
- except OSError:
- return False
-
-
-def format_resource_report(title: str, report: dict[str, float]) -> str:
- return "".join(
- [
- f"{title}:\n ",
- f"Max Reserved: {report.get("max_reserved", math.nan):.0f} MiB",
- f"| Max Allocated: {report.get("max_allocated", math.nan):.0f} MiB".ljust(26),
- f"| End Reserved: {report.get("reserved", math.nan):.0f} MiB".ljust(25),
- f"| End Allocated: {report.get("allocated", math.nan):.0f} MiB".ljust(26),
- f"| Duration: {report.get("duration", math.nan):.2f}".ljust(18),
- f"| GPUs: {report["gpus"]:.0f}" if "gpus" in report else "",
- ]
- )
-
-
-@pytest.fixture(scope="function")
-def report_subtest(request: pytest.FixtureRequest):
- verbose = request.config.getoption("verbose")
- do_capture = request.config.getoption("distributed_capture")
-
- def do_report_subtest(path: pathlib.Path, world_size: int) -> None:
- success = check_subtest_success(path)
- if not do_capture:
- logger.warning("Distributed capture is disabled. See distributed test for run output.")
- elif verbose > 1 or not success:
- for rank in range(world_size):
- for fd, file_ in (("stdout", sys.stdout), ("stderr", sys.stdout), ("traceback", sys.stderr)):
- print(header(f"{fd} rank {rank}", 80), file=file_)
- file_path = path / f"pytest_{fd}_{rank}"
- try:
- print(file_path.read_text(), file=file_)
- except OSError:
- print(f"<<< not found {file_path}>>>", file=file_)
- else:
- print("Set verbose > 1 to show run output.")
-
- reports = {}
- for rank in range(world_size):
- try:
- reports[f"rank_{rank}"] = json.load(path.joinpath(f"pytest_report_{rank}").open("r"))
- except OSError:
- reports[rank] = {}
- keys = {key for report in reports.values() for key in report}
- report = {key: max(report[key] for report in reports.values() if key in report) for key in keys}
- report["gpus"] = world_size
- reports["global"] = report
-
- print(header(f"Resource usage", 80), file=sys.stderr)
- for name, report in reports.items():
- print(format_resource_report(name, report), file=sys.stderr)
- setattr(request.node, "fast_llm_resource_report", report)
-
- if not success:
- raise RuntimeError(f"test {path.name} failed")
-
- return do_report_subtest