From 9f0704c18293b3c6247a02ec88cc440b7ebcc26f Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 17 Dec 2025 12:59:02 -0500 Subject: [PATCH 01/12] Simplified pipeline RL --- fast_llm/data/data/data_loader_wrapper.py | 16 +- fast_llm/data/data/gpt/data.py | 67 ++---- fast_llm/data/dataset/abstract_iterable.py | 30 --- fast_llm/data/dataset/config.py | 89 +------- fast_llm/data/dataset/sampled.py | 69 +++--- fast_llm/data/dataset/streaming.py | 235 ++++++--------------- fast_llm/engine/distributed/config.py | 7 +- fast_llm/engine/training/config.py | 8 - fast_llm/engine/training/trainer_events.py | 10 +- fast_llm/redis/config.py | 12 -- setup.cfg | 1 - tests/data/gptdata_streaming_test.py | 12 +- tests/data/test_streaming.py | 58 ++--- tests/models/test_checkpoint.py | 1 - tests/trainer/test_events.py | 1 - tests/utils/redis.py | 57 +---- 16 files changed, 172 insertions(+), 501 deletions(-) delete mode 100644 fast_llm/data/dataset/abstract_iterable.py diff --git a/fast_llm/data/data/data_loader_wrapper.py b/fast_llm/data/data/data_loader_wrapper.py index f9e517248..a44aa191d 100644 --- a/fast_llm/data/data/data_loader_wrapper.py +++ b/fast_llm/data/data/data_loader_wrapper.py @@ -1,4 +1,3 @@ -import torch.distributed import torch.utils.data.dataloader from fast_llm.core.distributed import broadcast_object @@ -12,19 +11,16 @@ class DistributedDataLoaderWrapper: def __init__( self, - dataloader: torch.utils.data.dataloader.DataLoader | None, - rank: int, + data_loader: torch.utils.data.dataloader.DataLoader, process_group: torch.distributed.ProcessGroup | None, ): - self.dataloader = dataloader - self.rank = rank + self._data_loader = data_loader + self._rank = 0 if process_group is None else process_group.rank() self.process_group = process_group - assert (self.rank == 0 and self.dataloader is not None) or (self.rank > 0 and self.dataloader is None) - def __iter__(self): - if self.rank == 0: - self.iterator = iter(self.dataloader) + if self._rank == 0: + self.iterator = iter(self._data_loader) if self.process_group is None: return self.iterator return self @@ -37,7 +33,7 @@ def __next__(self): # entire Batch object, which is inefficient for tensors because it serializes # (pickles) them before sending. 
- if self.rank == 0: + if self._rank == 0: try: data = next(self.iterator) # may raise StopIteration except Exception as e: diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 70966a051..9e1574437 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -11,7 +11,6 @@ from fast_llm.data.data.data_loader_wrapper import DistributedDataLoaderWrapper from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.data.dataset.abstract import SampledDataset -from fast_llm.data.dataset.abstract_iterable import SampledIterableDataset from fast_llm.data.dataset.config import SamplingParameters from fast_llm.data.dataset.gpt.config import GPTSamplingData from fast_llm.data.dataset.monitor import DatasetMonitor @@ -92,12 +91,7 @@ def setup( dataset_name=dataset_name, ) dataset = self._config.datasets[dataset_name].build_and_sample(sampling) - if isinstance(dataset, SampledDataset): - self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms) - else: - # Do not set monitor for iterable dataset as monitor only works with map style datasets - assert isinstance(dataset, SampledIterableDataset) - self._datasets[dataset_name] = dataset + self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms) safe_barrier(self._distributed.world_group, "data_preparation", timeout) self._is_setup = True @@ -123,45 +117,24 @@ def get_iterator( Assert.in_range_incl(batch_config.sequence_length, 1, sampling_parameters.sequence_length) log_main_rank(f"Initializing {dataset_name} dataset iterator from sample {consumed_samples}...") - dataset = self._datasets[dataset_name] - - if isinstance(dataset, SampledDataset): - data_loader = torch.utils.data.DataLoader( - dataset, # noqa - batch_sampler=SampledDatasetIterator( - total_samples=len(self._datasets[dataset_name]), - begin_index=consumed_samples, - micro_batch_size=batch_config.micro_batch_size, - data_rank=self._distributed.config.batch_data_rank, - data_parallel=self._distributed.config.batch_data_parallel, - ), - num_workers=num_workers, - prefetch_factor=prefetch_factor, - pin_memory=True, - collate_fn=LanguageModelBatch.from_samples, - multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None, - ) - - elif isinstance(dataset, SampledIterableDataset): - if ( - self.distributed.model_and_sequence_data_group is None - or self.distributed.model_and_sequence_data_group.rank() == 0 - ): - rank = 0 - data_loader = torch.utils.data.DataLoader( - dataset, # noqa - batch_size=batch_config.micro_batch_size, - num_workers=0 if num_workers == 0 else 1, - prefetch_factor=prefetch_factor, - pin_memory=True, - collate_fn=LanguageModelBatch.from_samples, - multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None, - ) - else: - rank = self.distributed.model_and_sequence_data_group.rank() - data_loader = None - data_loader = DistributedDataLoaderWrapper( - data_loader, rank, self.distributed.model_and_sequence_data_group - ) + data_loader = torch.utils.data.DataLoader( + self._datasets[dataset_name], # noqa + batch_sampler=SampledDatasetIterator( + total_samples=len(self._datasets[dataset_name]), + begin_index=consumed_samples, + micro_batch_size=batch_config.micro_batch_size, + data_rank=self._distributed.config.batch_data_rank, + data_parallel=self._distributed.config.batch_data_parallel, + ), + num_workers=num_workers, + prefetch_factor=prefetch_factor, + pin_memory=True, + 
collate_fn=LanguageModelBatch.from_samples, + multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None, + ) + + if False: + # TODO: ====== do ====== + data_loader = DistributedDataLoaderWrapper(data_loader, self.distributed.model_and_sequence_data_group) return iter(data_loader) diff --git a/fast_llm/data/dataset/abstract_iterable.py b/fast_llm/data/dataset/abstract_iterable.py deleted file mode 100644 index 770f4f97c..000000000 --- a/fast_llm/data/dataset/abstract_iterable.py +++ /dev/null @@ -1,30 +0,0 @@ -import abc -import typing - -import torch.utils.data - -from fast_llm.data.sample.abstract import Sample - -if typing.TYPE_CHECKING: - from fast_llm.data.dataset.config import SamplingData - - -# NOTE: We need to inherit from IterableDataset otherwise torch data loader can not detect it properly -class SampledIterableDataset[SampleType: Sample](torch.utils.data.IterableDataset[SampleType]): - """ - A sampled dataset class that provides an iterator over samples. - """ - - # NOTE: We add name here so it is compatible with Fast-LLM Dataset - @property - @abc.abstractmethod - def name(self) -> str: - """ - A name for the dataset to facilitate identification and debugging. - """ - - -class SamplableIterableDataset[SampleType: Sample](SampledIterableDataset[SampleType]): - @abc.abstractmethod - def sample(self, config: "SamplingData") -> SampledIterableDataset[SampleType]: - pass diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index f2a24c48b..1da79e214 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -7,7 +7,7 @@ import pathlib import typing -from fast_llm.config import Config, Field, FieldHint, FieldUpdate, UpdateType, check_field, config_class +from fast_llm.config import Config, Field, FieldHint, UpdateType, check_field, config_class from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset from fast_llm.data.preprocessing.abstract import PreprocessingConfig from fast_llm.data.sample.abstract import Sample @@ -15,8 +15,8 @@ from fast_llm.utils import Assert, normalize_probabilities if typing.TYPE_CHECKING: - from fast_llm.data.dataset.abstract_iterable import SamplableIterableDataset, SampledIterableDataset from fast_llm.data.dataset.indexed import ConcatenatedDataset, DatasetSlice, IndexedDataset + from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.engine.distributed.distributed import Distributed logger = logging.getLogger(__name__) @@ -111,22 +111,16 @@ class SampledDatasetConfig[SampleType: Sample](DatasetConfig[SampleType]): A sampled dataset containing a prepared list or iterable of samples to be indexed sequentially (as-is) during training. 
""" - def build_and_sample( - self, sampling: SamplingData - ) -> "SampledDataset[SampleType] | SampledIterableDataset[SampleType]": + def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]: raise NotImplementedError() @config_class() class SamplableDatasetConfig[SampleType: Sample](SampledDatasetConfig[SampleType]): - def build( - self, preprocessing: PreprocessingConfig - ) -> "SamplableDataset[SampleType] | SamplableIterableDataset[SampleType]": + def build(self, preprocessing: PreprocessingConfig) -> SamplableDataset[SampleType]: raise NotImplementedError() - def build_and_sample( - self, sampling: SamplingData - ) -> "SampledDataset[SampleType] | SampledIterableDataset[SampleType]": + def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]: return self.build(sampling.preprocessing).sample(sampling) @@ -308,39 +302,6 @@ def build(self, preprocessing: PreprocessingConfig) -> "IndexedDataset[SampleTyp raise FileNotFoundError(self.path) -@config_class() -class StreamingDatasetRedisConfig(RedisConfig): - stream_key: str = FieldUpdate(default="fast_llm_streaming") - - payload_key: str = FieldUpdate( - default="data", - ) - - -class IngestionType(str, enum.Enum): - CONSUMER_GROUP = "consumer_group" - ONE_STREAM = "one_stream" - N_STREAMS = "n_streams" - - -class HashType(str, enum.Enum): - MESSAGE_INDEX = "message_index" - """Use the index of the received message for hashing. Provides precise distribution but may not be well shuffled.""" - - MESSAGE_ID = "message_id" - """Hash messages based on their unique message ID. Good for probabilistic distribution. - Redis message IDs are regenerated each time, so this is not reproducible. - """ - - MESSAGE_BODY = "message_body" - """Hash messages based on their payload content (bytes). Distributes messages roughly evenly. - Deterministic based on message content, but not perfectly balanced across ranks. - """ - - PRODUCER_PROVIDED = "producer_provided" - """Use the hash or index provided by the producer. 
Allows deterministic splitting and perfect balance.""" - - @config_class(dynamic_type={SampledDatasetConfig: "streaming"}) class StreamingDatasetConfig[SampleType: LanguageModelSample](SamplableDatasetConfig[SampleType]): """ @@ -349,48 +310,18 @@ class StreamingDatasetConfig[SampleType: LanguageModelSample](SamplableDatasetCo _abstract = False - redis: StreamingDatasetRedisConfig = Field( + redis: RedisConfig = Field( desc="Redis connection and stream settings used to fetch incoming training data.", hint=FieldHint.core, ) - group_name: str = Field( - default="fast_llm_dp_group", - desc="Name of the Redis consumer group used for data-parallel reading.", - hint=FieldHint.core, - ) - - consumer_name_prefix: str = Field( - default="fast_llm_dp_group_consumer", - desc="Prefix used to generate unique consumer names for each rank in Redis consumer group.", - hint=FieldHint.core, - ) - - ingestion_type: IngestionType = Field( - default=IngestionType.CONSUMER_GROUP, - desc="Strategy used to ingest data from Redis streams (consumer group, single stream, or multiple streams).", - hint=FieldHint.core, - ) - - hash_type: HashType = Field( - default=HashType.MESSAGE_ID, - desc="How to compute hash for assigning messages to ranks.", - hint=FieldHint.core, - ) - - hash_key: str = Field( - default="hash", - desc="Key in the message dict containing the hash or index provided by the producer.", - hint=FieldHint.core, - ) - - ack_period_per_consumer: int = Field( + acknowledge_interval: int = Field( default=10, desc="Number of messages after which the consumer acknowledges received IDs back to the Redis hash.", hint=FieldHint.core, ) - def build_and_sample(self, sampling: SamplingData) -> "SampledIterableDataset[SampleType]": - from fast_llm.data.dataset.streaming import StreamingDataset + def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]: + from fast_llm.data.dataset.streaming import RedisStreamingDataset - return StreamingDataset[SampleType](self, sampling.distributed).sample(sampling) + return RedisStreamingDataset[SampleType](self, sampling.distributed.config).sample(sampling) diff --git a/fast_llm/data/dataset/sampled.py b/fast_llm/data/dataset/sampled.py index 979fd7a60..8ce780558 100644 --- a/fast_llm/data/dataset/sampled.py +++ b/fast_llm/data/dataset/sampled.py @@ -9,7 +9,6 @@ import yaml from fast_llm.data.dataset.abstract import SampledDataset -from fast_llm.data.dataset.abstract_iterable import SamplableIterableDataset, SampledIterableDataset from fast_llm.data.dataset.config import SamplingData, ShufflingType from fast_llm.data.dataset.indexed import IndexedDataset from fast_llm.data.sample.abstract import Sample @@ -432,53 +431,49 @@ def _load_yaml_data(self, data: dict[str, typing.Any]) -> None: self._unshuffled_documents = data["unshuffled_epochs"] * self._documents_per_epoch -class NaiveSampledIterableDataset[SampleType: Sample](SampledIterableDataset[SampleType]): +class SampledIterableDataset[SampleType: Sample](SampledDataset[SampleType]): def __init__( self, - iterable_dataset: SamplableIterableDataset[SampleType], + iterable_dataset: typing.Iterator[SampleType], sampling: SamplingData, ): - self._dataset = iterable_dataset + self._iterator = iterable_dataset self._config = sampling.config self._parameters = sampling.parameters + self._documents: list[SampleType] = [] + self._current_length = 0 + self._sample_length = self._parameters.sequence_length + self._parameters.extra_tokens - assert self._parameters.truncate_documents == False - assert 
self._config.shuffle == ShufflingType.disabled - - def __iter__(self) -> typing.Iterator[SampleType]: - sample_length = self._parameters.sequence_length + self._parameters.extra_tokens - current_sample_length = 0 - documents: list[SampleType] = [] - for doc in self._dataset: - if len(doc) > sample_length: - logging.warning(f"Dropping doc with length {len(doc)} higher then sample_length {sample_length}") + def __getitem__(self, index: int) -> SampleType: + while self._current_length < self._sample_length: + document = next(self._iterator) + if len(document) > self._sample_length: + logging.warning(f"Dropping document with length {len(document)} > {self._sample_length}.") continue - if current_sample_length + len(doc) > sample_length: - padding_length = sample_length - current_sample_length - assert padding_length > 0 - documents.append(documents[-1].get_padding(padding_length)) + self._documents.append(document) + self._current_length += len(document) - yield documents[0].from_documents(documents) - - documents = [doc] - current_sample_length = len(doc) + if self._current_length == self._sample_length: + documents = self._documents + self._documents = [] + self._current_length = 0 + else: + last_length = len(self._documents[-1]) + remaining_length = last_length - (self._current_length - self._sample_length) + if self._parameters.truncate_documents: + documents = self._documents[:-1] + [self._documents[-1].crop(0, remaining_length)] + self._documents = [self._documents[-1].crop(remaining_length, last_length)] else: - documents.append(doc) - current_sample_length += len(doc) + documents = self._documents[:-1] + [self._documents[0].get_padding(remaining_length)] + self._documents = [self._documents[-1]] + self._current_length = len(self._documents[0]) + sample = documents[0].from_documents(documents) + Assert.eq(len(sample), self._sample_length) + return sample - if current_sample_length == sample_length: - yield documents[0].from_documents(documents) - - documents = [] - current_sample_length = 0 - - if current_sample_length > 0: - padding_length = sample_length - current_sample_length - assert padding_length > 0 - documents.append(documents[-1].get_padding(padding_length)) - - yield documents[0].from_documents(documents) + def __len__(self) -> int: + return self._parameters.num_samples @property def name(self) -> str: - return self._dataset.name + return self._iterator.name diff --git a/fast_llm/data/dataset/streaming.py b/fast_llm/data/dataset/streaming.py index 1aabf60cc..0898125ad 100644 --- a/fast_llm/data/dataset/streaming.py +++ b/fast_llm/data/dataset/streaming.py @@ -1,19 +1,16 @@ +import json import typing +import redis import torch.utils.data -import xxhash -from fast_llm.data.dataset.abstract_iterable import SamplableIterableDataset, SampledIterableDataset -from fast_llm.data.dataset.config import HashType, IngestionType, SamplingData, StreamingDatasetConfig -from fast_llm.data.dataset.sampled import NaiveSampledIterableDataset +from fast_llm.data.dataset.abstract import SamplableDataset +from fast_llm.data.dataset.config import SamplingData, StreamingDatasetConfig +from fast_llm.data.dataset.sampled import SampledIterableDataset from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.data.sample.range import RangeSample from fast_llm.data.sample.token import TokenSample -from fast_llm.engine.config_utils.run import is_main_rank -from fast_llm.engine.distributed.distributed import Distributed - -if typing.TYPE_CHECKING: - import redis +from 
fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames def dtype_from_string(name: str) -> torch.dtype: @@ -23,70 +20,64 @@ def dtype_from_string(name: str) -> torch.dtype: raise ValueError(f"Unknown torch dtype: {name}") -class StreamingDataset[SampleType: LanguageModelSample](SamplableIterableDataset[SampleType]): - def __init__(self, config: StreamingDatasetConfig, distributed: Distributed): +REDIS_DATA_KEY = "fast_llm_streaming" +REDIS_GROUP_NAME = "fast_llm_group" + + +class RedisStreamingDataset[SampleType: LanguageModelSample](SamplableDataset[SampleType]): + def __init__(self, config: StreamingDatasetConfig, distributed_config: DistributedConfig): super().__init__() - if distributed.config.pipeline_parallel > 1: + if distributed_config.pipeline_parallel > 1: # NOTE: It is not yet clear whether the issue comes from the streaming dataset # itself or from the distributed data-loader wrappers, but currently it # interferes with pipeline-parallel training and causes a timeout during # the training step. raise NotImplementedError("Streaming dataset support is not implemented for pipeline-parallel training.") - self._name = f"redis[{config.redis.host}:{config.redis.port}]({config.redis.stream_key}|{config.group_name})[{config.redis.payload_key}]" + self._name = f"redis[{config.redis.host}:{config.redis.port}]({REDIS_DATA_KEY}|{REDIS_GROUP_NAME})[data]" self._config = config - self.batch_data_rank = distributed.config.batch_data_rank - self.batch_data_parallel = distributed.config.batch_data_parallel + self.batch_data_rank = distributed_config.batch_data_rank + self.batch_data_parallel = distributed_config.batch_data_parallel self.is_batch_data_group_leader = ( - distributed.model_and_sequence_data_group is None or distributed.model_and_sequence_data_group.rank() == 0 + distributed_config.get_distributed_dim(DistributedDimNames.model_and_sequence_data).rank == 0 ) - self.payload_key_b = self._config.redis.payload_key.encode() - self.hash_key_b = self._config.hash_key.encode() - self._set_consumer_count() + if distributed_config.rank == 0: + redis_client = redis.Redis(host=self._config.redis.host, port=self._config.redis.port) + # TODO: Not needed? 
+ redis_client.hset(f"{REDIS_DATA_KEY}:consumer_count", "0", str(self.batch_data_parallel)) @property def name(self) -> str: return self._name def sample(self, config: SamplingData) -> SampledIterableDataset[LanguageModelSample]: - # TODO: actually sample the dataset and not return docs - return NaiveSampledIterableDataset(self, config) - - def _set_consumer_count(self): - import redis - - if is_main_rank(): - redis_client = redis.Redis(host=self._config.redis.host, port=self._config.redis.port) - redis_client.hset(f"{self._config.redis.stream_key}:consumer_count", "0", self.batch_data_parallel) - - def __getstate__(self) -> tuple[str, StreamingDatasetConfig, int, int, bool, bytes, bytes]: - return ( - self._name, - self._config, - self.batch_data_parallel, - self.batch_data_rank, - self.is_batch_data_group_leader, - self.payload_key_b, - self.hash_key_b, - ) - - def __setstate__(self, state: tuple[str, StreamingDatasetConfig, int, bool, bytes, bytes]): - name, config, batch_data_parallel, batch_data_rank, is_batch_data_group_leader, payload_key_b, hash_key_b = ( - state - ) - self._name = name - self._config = config - self.batch_data_parallel = batch_data_parallel - self.batch_data_rank = batch_data_rank - self.is_batch_data_group_leader = is_batch_data_group_leader - self.payload_key_b = payload_key_b - self.hash_key_b = hash_key_b + return SampledIterableDataset(iter(self), config) + + # def __getstate__(self) -> tuple[str, StreamingDatasetConfig, int, int, bool, bytes, bytes]: + # return ( + # self._name, + # self._config, + # self.batch_data_parallel, + # self.batch_data_rank, + # self.is_batch_data_group_leader, + # self.payload_key_b, + # self.hash_key_b, + # ) + + # def __setstate__(self, state: tuple[str, StreamingDatasetConfig, int, bool, bytes, bytes]): + # name, config, batch_data_parallel, batch_data_rank, is_batch_data_group_leader, payload_key_b, hash_key_b = ( + # state + # ) + # self._name = name + # self._config = config + # self.batch_data_parallel = batch_data_parallel + # self.batch_data_rank = batch_data_rank + # self.is_batch_data_group_leader = is_batch_data_group_leader + # self.payload_key_b = payload_key_b + # self.hash_key_b = hash_key_b def __iter__(self) -> typing.Iterator[LanguageModelSample]: - import orjson - import redis - worker_info = torch.utils.data.get_worker_info() if worker_info is not None and worker_info.num_workers > 1: raise RuntimeError("StreamingDataset can work only with one instance per rank") @@ -94,41 +85,12 @@ def __iter__(self) -> typing.Iterator[LanguageModelSample]: if not self.is_batch_data_group_leader: raise RuntimeError("Must be only called on the batch data group leader") - redis_client = redis.Redis(host=self._config.redis.host, port=self._config.redis.port) - - match self._config.ingestion_type: - case IngestionType.CONSUMER_GROUP: - messages_iter = self._iter_consumer_group - case IngestionType.ONE_STREAM: - messages_iter = self._iter_stream - case IngestionType.N_STREAMS: - messages_iter = self._iter_stream - case _: - raise ValueError(f"Unknown ingestion type {self._config.ingestion_type}") - - ack_hash = f"{self._config.redis.stream_key}:ack" - consumer_id = str(self.batch_data_rank) - # If one stream each groups receives all data otherwise each groups receives just its data - if self._config.ingestion_type == IngestionType.ONE_STREAM: - ack_period = self._config.ack_period_per_consumer * self.batch_data_parallel - else: - ack_period = self._config.ack_period_per_consumer - - for msg_data in messages_iter(redis_client, 
ack_hash=ack_hash, consumer_id=consumer_id, ack_period=ack_period): - data = orjson.loads(msg_data[self.payload_key_b]) - yield self._sample_from_msg_data(data) - - def _iter_consumer_group( - self, redis_client: "redis.Redis", ack_hash: str, consumer_id: str, ack_period: int - ) -> typing.Iterator[LanguageModelSample]: - import redis.exceptions + client = redis.Redis(host=self._config.redis.host, port=self._config.redis.port) # Create the consumer group at the start of the stream ("0") # If the stream already exists, XGROUP CREATE will fail unless we add mkstream=True try: - redis_client.xgroup_create( - name=self._config.redis.stream_key, groupname=self._config.group_name, id="0", mkstream=True - ) + client.xgroup_create(name=REDIS_DATA_KEY, groupname=REDIS_GROUP_NAME, id="0", mkstream=True) except redis.exceptions.ResponseError as e: if "BUSYGROUP" in str(e): # Consumer group already exists @@ -141,11 +103,11 @@ def _iter_consumer_group( # XREADGROUP reads from the consumer group # COUNT: max number of messages to fetch at once # BLOCK: wait for new messages (milliseconds) - messages = redis_client.xreadgroup( - groupname=self._config.group_name, - consumername=f"{self._config.consumer_name_prefix}_{self.batch_data_rank}", + messages = client.xreadgroup( + groupname=REDIS_GROUP_NAME, + consumername=f"fast_llm_consumer_{self.batch_data_rank}", # ">" reads only new messages that have not been delivered to any consumer - streams={self._config.redis.stream_key: ">"}, + streams={REDIS_DATA_KEY: ">"}, count=1, block=1000, # No explicit ACK: messages are processed immediately; on rank failure the job restarts, @@ -154,106 +116,33 @@ def _iter_consumer_group( ) if messages: for stream_key, msgs in messages: - assert stream_key == self._config.redis.stream_key.encode() + assert stream_key == REDIS_DATA_KEY.encode() for msg_id, msg_data in msgs: processed += 1 # TODO: or do it after processing all received messaged then count > 1? - if processed % ack_period == 0: - redis_client.hset(ack_hash, consumer_id, msg_id) - - yield msg_data - - def _iter_stream( - self, redis_client: "redis.Redis", ack_hash: str, consumer_id: str, ack_period: int - ) -> typing.Iterator[LanguageModelSample]: - last_id = "0-0" - stream_key = self._config.redis.stream_key - if self._config.ingestion_type == IngestionType.N_STREAMS: - stream_key += f"_{self.batch_data_rank}" - stream_key_b = stream_key.encode() - processed = 0 - while True: - messages = redis_client.xread( - streams={stream_key: last_id}, - count=1, - block=1000, - ) - if not messages: - continue - for this_stream_key_b, msgs in messages: - assert this_stream_key_b == stream_key_b - for msg_id, msg_data in msgs: - last_id = msg_id - - processed += 1 - # TODO: or do it after processing all received messaged then count > 1? 
- if processed % ack_period == 0: - redis_client.hset(ack_hash, consumer_id, last_id) + if processed % self._config.acknowledge_interval == 0: + client.hset(f"{REDIS_DATA_KEY}:ack", str(self.batch_data_rank), msg_id) - if self._config.ingestion_type == IngestionType.N_STREAMS or self._is_for_this_rank( - msg_id, msg_data, processed - 1 - ): - yield msg_data - - def _is_for_this_rank(self, msg_id: bytes, msg_data: dict, msg_index: int) -> bool: - hash_type = self._config.hash_type - - if hash_type is HashType.MESSAGE_INDEX: - h = msg_index - elif hash_type is HashType.MESSAGE_ID: - h = xxhash.xxh64(msg_id).intdigest() - elif hash_type is HashType.MESSAGE_BODY: - h = xxhash.xxh64(msg_data[self.payload_key_b]).intdigest() - elif hash_type is HashType.PRODUCER_PROVIDED: - h = self._get_hash_key_value(msg_data[self.hash_key_b]) - else: - raise ValueError(f"Unknown hash_type: {hash_type}") - - return (h % self.batch_data_parallel) == self.batch_data_rank - - def _get_hash_key_value(self, value: bytes | int | str): - if isinstance(value, int): - # already an integer - msg_hash = value - elif isinstance(value, bytes): - try: - # try decoding as UTF-8 string and converting to int - msg_hash = int(value.decode("utf-8")) - except ValueError: - # not an integer, treat as a hash string - import xxhash - - msg_hash = xxhash.xxh64(value).intdigest() - elif isinstance(value, str): - try: - msg_hash = int(value) - except ValueError: - # not an integer, treat as a hash string - import xxhash - - msg_hash = xxhash.xxh64(value.encode("utf-8")).intdigest() - else: - raise TypeError(f"Unexpected type for hash key: {type(value)}") - return msg_hash + yield self._read_document(json.loads(msg_data[b"data"])) - def _sample_from_msg_data(self, data: dict) -> LanguageModelSample: + def _read_document(self, data: dict) -> LanguageModelSample: tokens = torch.tensor(data["tokens"], dtype=dtype_from_string(data["tokens_dtype"])) sample_size = len(tokens) if "loss_masking_spans" in data: - loss_masking_spans = [tuple(el) for el in data["loss_masking_spans"]] + loss_masking_spans = RangeSample([(begin, end) for begin, end in data["loss_masking_spans"]], sample_size) else: loss_masking_spans = None if "chosen_spans" in data: - chosen_spans = [tuple(el) for el in data["chosen_spans"]] + chosen_spans = RangeSample([(begin, end) for begin, end in data["chosen_spans"]], sample_size) else: chosen_spans = None if "rejected_spans" in data: - rejected_spans = [tuple(el) for el in data["rejected_spans"]] + rejected_spans = RangeSample([(begin, end) for begin, end in data["rejected_spans"]], sample_size) else: rejected_spans = None return LanguageModelSample( TokenSample(tokens, [sample_size]), - RangeSample(loss_masking_spans, sample_size) if loss_masking_spans is not None else None, - RangeSample(chosen_spans, sample_size) if chosen_spans is not None else None, - RangeSample(rejected_spans, sample_size) if rejected_spans is not None else None, + loss_masking_spans, + chosen_spans, + rejected_spans, ) diff --git a/fast_llm/engine/distributed/config.py b/fast_llm/engine/distributed/config.py index 8b9c0c13f..532dfda25 100644 --- a/fast_llm/engine/distributed/config.py +++ b/fast_llm/engine/distributed/config.py @@ -87,7 +87,9 @@ def from_sizes_and_strides(cls, name: str, global_rank: int, *sizes_and_strides: start = global_rank rank = 0 world_size = 1 - for size, stride in sizes_and_strides: + for i, (size, stride) in enumerate(sizes_and_strides): + if i > 0: + Assert.multiple(stride, sizes_and_strides[i - 1][1]) rank_ = global_rank 
// stride % size start -= rank_ * stride rank += world_size * rank_ @@ -267,8 +269,6 @@ class DistributedConfig(Config): ) def _validate(self) -> None: - super()._validate() - if self.world_size is None: self.world_size = self.default_world_size if self.rank is None: @@ -352,6 +352,7 @@ def _validate(self) -> None: (self.pipeline_rank, pipeline_stride), ) + super()._validate() if self.reference_config is not None: self.compare(self.reference_config, ValueError) Assert.in_range(self.rank, 0, self.world_size) diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 7624c72c4..4795d80dc 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -29,7 +29,6 @@ from fast_llm.engine.optimizer.config import OptimizerConfig from fast_llm.engine.schedule.config import BatchConfig, ScheduleConfig from fast_llm.profile import ProfilingConfig -from fast_llm.redis.config import RedisConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -322,13 +321,6 @@ def _validate(self) -> None: self.wandb.alert.assert_sub_interval(self.logs) -@config_class() -class TrainerEventsRedisConfig(RedisConfig): - stream_key: str = FieldUpdate(default="fast_llm_events") - - payload_key: str = FieldUpdate(default="event") - - @config_class() class TrainerEvent(Config): enabled: bool = Field( diff --git a/fast_llm/engine/training/trainer_events.py b/fast_llm/engine/training/trainer_events.py index 8bce3e6de..93719615a 100644 --- a/fast_llm/engine/training/trainer_events.py +++ b/fast_llm/engine/training/trainer_events.py @@ -6,13 +6,17 @@ from fast_llm.engine.config_utils.run import is_main_rank from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel -from fast_llm.engine.training.config import TrainerEventsConfig, TrainerEventsRedisConfig, TrainingExportConfig +from fast_llm.engine.training.config import TrainerEventsConfig, TrainingExportConfig +from fast_llm.redis.config import RedisConfig logger = logging.getLogger(__name__) +REDIS_TRAINING_KEY = "fast_llm_events" + + class RedisEventSender: - def __init__(self, config: TrainerEventsRedisConfig): + def __init__(self, config: RedisConfig): self.config = config self.client = None @@ -30,7 +34,7 @@ def send(self, msg_type: str, payload: dict | None = None): payload = {} payload.update({"type": msg_type}) - self.client.xadd(self.config.stream_key, {self.config.payload_key: orjson.dumps(payload)}) + self.client.xadd(REDIS_TRAINING_KEY, {"event": orjson.dumps(payload)}) class TrainerEvents: diff --git a/fast_llm/redis/config.py b/fast_llm/redis/config.py index 5b6bfbddd..c36853787 100644 --- a/fast_llm/redis/config.py +++ b/fast_llm/redis/config.py @@ -14,15 +14,3 @@ class RedisConfig(Config): desc="Port number on which the Redis server is running.", hint=FieldHint.core, ) - - stream_key: str = Field( - default=None, - desc="Name of the Redis stream to read data from.", - hint=FieldHint.core, - ) - - payload_key: str = Field( - default=None, - desc="Key under which the message data is stored inside the Redis payload dict.", - hint=FieldHint.core, - ) diff --git a/setup.cfg b/setup.cfg index 34995ce96..495a9cf20 100644 --- a/setup.cfg +++ b/setup.cfg @@ -62,7 +62,6 @@ GENERATION = STREAMING = redis>=-7.1.0 - orjson>=3.11.5 # Required for supporting vision inputs VISION = diff --git a/tests/data/gptdata_streaming_test.py b/tests/data/gptdata_streaming_test.py index 3e388cc45..461c9756d 100644 --- a/tests/data/gptdata_streaming_test.py +++ b/tests/data/gptdata_streaming_test.py @@ -5,11 
+5,11 @@ from fast_llm.config import NoAutoValidate from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.data.data.gpt.data import GPTData -from fast_llm.data.dataset.config import IngestionType +from fast_llm.data.dataset.config import StreamingDatasetConfig from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed from fast_llm.models.gpt.config import GPTBatchConfig -from tests.utils.redis import get_stream_config, make_sampling +from tests.utils.redis import make_sampling def distributed_gptdata_streaming_test( @@ -22,12 +22,8 @@ def distributed_gptdata_streaming_test( total_gpus, redis_port, result_path, - ingestion_type, ): - stream_config = get_stream_config() - stream_config = stream_config.from_dict( - stream_config.to_dict(), {("redis", "port"): redis_port, ("ingestion_type"): ingestion_type} - ) + stream_config = StreamingDatasetConfig.from_dict({"redis": {"port": redis_port}}) distributed = Distributed( DistributedConfig( @@ -89,7 +85,6 @@ def parse_args(): parser.add_argument("--total-gpus", type=int, required=True, help="Total number of GPUs available.") parser.add_argument("--redis-port", type=int, required=True, help="Redis port to connect to.") parser.add_argument("--result-path", type=str, required=True, help="Path to save test results.") - parser.add_argument("--ingestion-type", type=str, required=True, help="Ingestion type used in streaming dataset.") return parser.parse_args() @@ -107,7 +102,6 @@ def main(): total_gpus=args.total_gpus, redis_port=args.redis_port, result_path=args.result_path, - ingestion_type=IngestionType(args.ingestion_type), ) diff --git a/tests/data/test_streaming.py b/tests/data/test_streaming.py index a0bfae316..0bd2154ce 100644 --- a/tests/data/test_streaming.py +++ b/tests/data/test_streaming.py @@ -6,8 +6,7 @@ import pytest import torch -from fast_llm.data.dataset.config import IngestionType -from fast_llm.data.dataset.streaming import StreamingDataset +from fast_llm.data.dataset.streaming import RedisStreamingDataset from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed @@ -111,12 +110,10 @@ def run_distributed_gptdata_streaming_test( run_distributed_script, result_path, request, - ingestion_type: IngestionType, ): import tests.data.gptdata_streaming_test stream_config, fake_redis, fake_redis_server_killer = fake_redis_server - stream_config = stream_config.from_dict(stream_config.to_dict(), {("ingestion_type"): ingestion_type}) sequence_length = 10 micro_batch_size = 2 @@ -132,7 +129,6 @@ def run_distributed_gptdata_streaming_test( with redis_batch_producer( redis_client=fake_redis, fake_redis_server_killer=fake_redis_server_killer, - stream_config=stream_config, batch_size=batch_size, sequence_length=10, ): @@ -158,8 +154,6 @@ def run_distributed_gptdata_streaming_test( str(result_path), "--redis-port", str(redis_port), - "--ingestion-type", - str(ingestion_type.value), ] # TODO: distributed_capture is ignored now inside the script if request.config.getoption("distributed_capture"): @@ -183,7 +177,6 @@ def run_distributed_gptdata_streaming_test( total_gpus=total_gpus, redis_port=redis_port, result_path=result_path, - ingestion_type=ingestion_type, ) check_distributed_gptdata_streaming_test_results( @@ -231,10 +224,10 @@ def test_streaming_dataset_reads_single_message(monkeypatched_redis, stream_conf fake_redis = 
monkeypatched_redis distributed = Distributed(DistributedConfig(), use_cpu=True) - dataset = StreamingDataset(stream_config, distributed) + dataset = RedisStreamingDataset(stream_config, distributed) # Insert a message - push_msg(fake_redis, stream_config, [1, 2, 3]) + push_msg(fake_redis, [1, 2, 3]) it = iter(dataset) sample = next(it) @@ -252,12 +245,12 @@ def test_streaming_dataset_reads_multiple_messages(monkeypatched_redis, stream_c fake_redis = monkeypatched_redis distributed = Distributed(DistributedConfig(), use_cpu=True) - dataset = StreamingDataset(stream_config, distributed) + dataset = RedisStreamingDataset(stream_config, distributed) # Insert a message - push_msg(fake_redis, stream_config, [1, 2, 3]) - push_msg(fake_redis, stream_config, [1, 2, 3]) - push_msg(fake_redis, stream_config, [1, 2, 3]) + push_msg(fake_redis, [1, 2, 3]) + push_msg(fake_redis, [1, 2, 3]) + push_msg(fake_redis, [1, 2, 3]) it = iter(dataset) for i in range(3): @@ -275,10 +268,10 @@ def test_sampling_1_doc_exact_fit(monkeypatched_redis, stream_config): """Docs exactly fill one sample.""" fake_redis = monkeypatched_redis - push_msg(fake_redis, stream_config, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + push_msg(fake_redis, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) distributed = Distributed(DistributedConfig(), use_cpu=True) - sampler = StreamingDataset(stream_config, distributed).sample(make_sampling(10, 0, 1, distributed)) + sampler = RedisStreamingDataset(stream_config, distributed).sample(make_sampling(10, 0, 1, distributed)) out = next(iter(sampler)) @@ -292,11 +285,11 @@ def test_sampling_2_docs_exact_fit(monkeypatched_redis, stream_config): fake_redis = monkeypatched_redis # Two rollouts: lengths 4 and 6 -> exactly 10 - push_msg(fake_redis, stream_config, [1, 2, 3, 4]) - push_msg(fake_redis, stream_config, [5, 6, 7, 8, 9, 10]) + push_msg(fake_redis, [1, 2, 3, 4]) + push_msg(fake_redis, [5, 6, 7, 8, 9, 10]) distributed = Distributed(DistributedConfig(), use_cpu=True) - sampler = StreamingDataset(stream_config, distributed).sample(make_sampling(10, 0, 1, distributed)) + sampler = RedisStreamingDataset(stream_config, distributed).sample(make_sampling(10, 0, 1, distributed)) out = next(iter(sampler)) @@ -309,11 +302,11 @@ def test_sampling_skips_too_long_doc_and_padding_final(monkeypatched_redis, stre """Rollout longer than sample_length must be dropped.""" fake_redis = monkeypatched_redis - push_msg(fake_redis, stream_config, list(range(20))) # skip: too long - push_msg(fake_redis, stream_config, list(range(10))) # usable + push_msg(fake_redis, list(range(20))) # skip: too long + push_msg(fake_redis, list(range(10))) # usable distributed = Distributed(DistributedConfig(), use_cpu=True) - sampler = StreamingDataset(stream_config, distributed).sample(make_sampling(10, 0, 1, distributed)) + sampler = RedisStreamingDataset(stream_config, distributed).sample(make_sampling(10, 0, 1, distributed)) out = next(iter(sampler)) @@ -326,11 +319,11 @@ def test_sampling_overflow_creates_two(monkeypatched_redis, stream_config): """A document overflowing the boundary triggers padding + next sample.""" fake_redis = monkeypatched_redis - push_msg(fake_redis, stream_config, list(range(6))) - push_msg(fake_redis, stream_config, list(range(10))) + push_msg(fake_redis, list(range(6))) + push_msg(fake_redis, list(range(10))) distributed = Distributed(DistributedConfig(), use_cpu=True) - sampler = StreamingDataset(stream_config, distributed).sample(make_sampling(10, 0, 2, distributed)) + sampler = RedisStreamingDataset(stream_config, 
distributed).sample(make_sampling(10, 0, 2, distributed)) sampler_iter = iter(sampler) out = [next(sampler_iter)] @@ -343,18 +336,7 @@ def test_sampling_overflow_creates_two(monkeypatched_redis, stream_config): assert out[1].tokens.tokens.tolist() == list(range(10)) -@pytest.mark.parametrize( - "ingestion_type", - [ - IngestionType.CONSUMER_GROUP, - # TODO: need to implement wait_until_stream_empty for those variants on test side to enable tests for them - # IngestionType.ONE_STREAM, - # IngestionType.N_STREAMS, - ], -) -def test_gptdata_streaming_single_consumer( - fake_redis_server, run_distributed_script_lean, ingestion_type, result_path, request -): +def test_gptdata_streaming_single_consumer(fake_redis_server, run_distributed_script_lean, result_path, request): run_distributed_gptdata_streaming_test( fake_redis_server=fake_redis_server, @@ -369,7 +351,6 @@ def test_gptdata_streaming_single_consumer( run_distributed_script=run_distributed_script_lean, result_path=result_path, request=request, - ingestion_type=ingestion_type, ) @@ -395,5 +376,4 @@ def test_gptdata_streamin_gpus(fake_redis_server, variant, run_distributed_scrip run_distributed_script=run_distributed_script_lean, result_path=result_path, request=request, - ingestion_type=IngestionType.CONSUMER_GROUP, ) diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 53804d878..bb53de29e 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -431,7 +431,6 @@ def reference_distributed_shard(get_convert_path) -> torch.Tensor | None: # We don't want to depend on `test_save_and_load_in_parallel` because we still want to run this in cas of failure. # This should still run after `test_save_and_load_in_parallel` @requires_cuda -# NOTE: Should it depend on test_model_distributed instead? 
@pytest.mark.depends_on(on=["test_load_pretrained[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) def test_load_parallel_checkpoint_in_single_gpu( diff --git a/tests/trainer/test_events.py b/tests/trainer/test_events.py index 4894a022f..14e559c31 100644 --- a/tests/trainer/test_events.py +++ b/tests/trainer/test_events.py @@ -382,7 +382,6 @@ def test_trainer_events_with_streaming(fake_redis_server, variant, run_distribut with redis_batch_producer( redis_client=fake_redis_client, fake_redis_server_killer=fake_redis_server_killer, - stream_config=stream_config, batch_size=batch_size, sequence_length=sequence_length, ): diff --git a/tests/utils/redis.py b/tests/utils/redis.py index 7e9d3b689..e34cb8173 100644 --- a/tests/utils/redis.py +++ b/tests/utils/redis.py @@ -9,30 +9,16 @@ import pytest from fast_llm.data.dataset.config import ( - IngestionType, SamplingConfig, SamplingData, SamplingParameters, ShufflingType, StreamingDatasetConfig, - StreamingDatasetRedisConfig, ) +from fast_llm.data.dataset.streaming import REDIS_DATA_KEY, REDIS_GROUP_NAME from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig -def get_stream_config(): - return StreamingDatasetConfig( - redis=StreamingDatasetRedisConfig( - host="localhost", - port=6379, - stream_key="test_stream", - payload_key="data", - ), - group_name="test_group", - consumer_name_prefix="consumer", - ) - - def find_free_port(): """Find a free TCP port and return it.""" with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: @@ -40,16 +26,15 @@ def find_free_port(): return s.getsockname()[1] -def push_msg(redis_client, config, tokens=None, stream_key_suffix=None): +def push_msg(redis_client, tokens=None, stream_key_suffix=None, payload_key="data", stream_key=REDIS_DATA_KEY): """Push a message into FakeRedis stream.""" msg = { "tokens": tokens, "tokens_dtype": "int64", } - stream_key = config.redis.stream_key if stream_key_suffix is not None: stream_key += stream_key_suffix - redis_client.xadd(stream_key, {config.redis.payload_key: orjson.dumps(msg)}) + redis_client.xadd(stream_key, {payload_key: orjson.dumps(msg)}) class FakeRedisServerKiller: @@ -73,20 +58,7 @@ def wait_until_stream_empty( stream_key, consumer_group, stop_event, - consumer_count, - ingestion_type: IngestionType, ): - if ingestion_type == IngestionType.CONSUMER_GROUP: - return wait_until_stream_empty_consumer_group(redis_client, stream_key, consumer_group, stop_event) - elif ingestion_type == IngestionType.ONE_STREAM: - raise NotImplementedError() - elif ingestion_type == IngestionType.N_STREAMS: - raise NotImplementedError() - else: - raise ValueError(f"Unknown ingestion type {ingestion_type.value}") - - -def wait_until_stream_empty_consumer_group(redis_client, stream_key, consumer_group, stop_event): """ Wait until lag == 0, meaning all messages have been delivered AND acknowledged. 
Absence of group mean test has not started yet, so we wait @@ -106,7 +78,7 @@ def wait_until_stream_empty_consumer_group(redis_client, stream_key, consumer_gr def get_consumer_count(redis_client, stop_event, config: StreamingDatasetConfig): while not stop_event.is_set(): - res = redis_client.hget(f"{config.redis.stream_key}:consumer_count", "0") + res = redis_client.hget(f"{REDIS_DATA_KEY}:consumer_count", "0") if res is None: time.sleep(0.05) continue @@ -114,18 +86,12 @@ def get_consumer_count(redis_client, stop_event, config: StreamingDatasetConfig) @contextlib.contextmanager -def redis_batch_producer( - redis_client, fake_redis_server_killer, stream_config, batch_size, sequence_length, num_batches=None -): +def redis_batch_producer(redis_client, fake_redis_server_killer, batch_size, sequence_length, num_batches=None): stop_event = threading.Event() thread_exc = [] def producer_loop(): - is_n_streams = stream_config.ingestion_type == IngestionType.N_STREAMS try: - consumer_count = get_consumer_count(redis_client, stop_event, stream_config) - stream = stream_config.redis.stream_key - group = stream_config.group_name batch_idx = 0 while not stop_event.is_set(): if num_batches is not None and batch_idx >= num_batches: @@ -135,17 +101,13 @@ def producer_loop(): return push_msg( redis_client, - stream_config, [batch_idx * batch_size + i] * sequence_length, - stream_key_suffix=f"_{i % consumer_count}" if is_n_streams else None, ) wait_until_stream_empty( redis_client, - stream, - group, + REDIS_DATA_KEY, + REDIS_GROUP_NAME, stop_event, - consumer_count=consumer_count, - ingestion_type=stream_config.ingestion_type, ) batch_idx += 1 except Exception as e: @@ -184,14 +146,13 @@ def make_sampling(sequence_length, extra_tokens, num_samples, distributed): @pytest.fixture def stream_config(): - return get_stream_config() + # TODO: ======= Not safe for parallel tests? 
======= + return StreamingDatasetConfig.from_dict({"redis": {"port": find_free_port()}}) @pytest.fixture def fake_redis_server(stream_config): # We search for free port as port from previous test can still be not free even after server shutdown - stream_config = stream_config.from_dict(stream_config.to_dict(), {("redis", "port"): find_free_port()}) - server_address = (stream_config.redis.host, stream_config.redis.port) # ----- Monkey-patch handler to suppress broken pipes ----- From 992f447a28b05c239e82d2c895af3dbfda75186b Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 17 Dec 2025 15:36:12 -0500 Subject: [PATCH 02/12] stuff --- fast_llm/data/data/data_loader.py | 72 ++++++++++++++++++++++ fast_llm/data/data/data_loader_wrapper.py | 48 --------------- fast_llm/data/data/gpt/data.py | 6 +- fast_llm/data/dataset/abstract.py | 21 +++++++ fast_llm/data/dataset/config.py | 30 ++++++--- fast_llm/data/dataset/sampled.py | 22 +++++-- fast_llm/data/dataset/streaming.py | 63 +++++++------------ fast_llm/data/iterator.py | 25 -------- fast_llm/engine/distributed/config.py | 5 +- fast_llm/engine/distributed/distributed.py | 14 ++--- fast_llm/engine/training/config.py | 8 +-- fast_llm/engine/training/trainer_events.py | 2 +- fast_llm/redis/config.py | 16 ----- tests/data/gptdata_streaming_test.py | 2 +- tests/data/test_streaming.py | 16 +++-- tests/trainer/test_events.py | 4 +- tests/utils/redis.py | 4 +- 17 files changed, 177 insertions(+), 181 deletions(-) create mode 100644 fast_llm/data/data/data_loader.py delete mode 100644 fast_llm/data/data/data_loader_wrapper.py delete mode 100644 fast_llm/data/iterator.py diff --git a/fast_llm/data/data/data_loader.py b/fast_llm/data/data/data_loader.py new file mode 100644 index 000000000..ad8ad3cf6 --- /dev/null +++ b/fast_llm/data/data/data_loader.py @@ -0,0 +1,72 @@ +import itertools +import typing + +import torch.utils.data + +from fast_llm.core.distributed import broadcast_object + + +class SampledDatasetIterator(torch.utils.data.Sampler): + """ + A distributed sampler generating indices for a `SampledDataset` (i.e., the natural numbers). + To be used as the `batch_sampler` of a `torch.utils.data.DataLoader`. + """ + + def __init__(self, total_samples, begin_index, micro_batch_size, data_rank, data_parallel): + super().__init__() + self._total_samples = total_samples + self._begin_index = begin_index + self._batch_size = micro_batch_size * data_parallel + self._start_idx = data_rank * micro_batch_size + self._end_idx = (data_rank + 1) * micro_batch_size + + def __len__(self) -> int: + return self._total_samples + + def __iter__(self) -> typing.Iterator[list[int]]: + for idx in range(self._begin_index, self._total_samples - self._batch_size + 1, self._batch_size): + yield list(range(idx + self._start_idx, idx + self._end_idx)) + + +class DistributedDataLoaderWrapper: + """ + Wraps a regular dataloader so that only the process group leader + loads data, and then broadcasts the batch to other ranks in the group. 
+ """ + + def __init__( + self, + data_loader: torch.utils.data.dataloader.DataLoader, + process_group: torch.distributed.ProcessGroup | None, + ): + self._data_loader = data_loader + self._rank = 0 if process_group is None else process_group.rank() + self._process_group = process_group + + def __iter__(self): + if self._rank == 0: + self._iterator = iter(self._data_loader) + else: + self._iterator = itertools.repeat(None) + if self._process_group is None: + return self._iterator + return self + + def __next__(self): + # TODO: + # Instead of broadcasting a general object, make this iterator yield an actual Batch class. + # Implement `get_state_dict` and `from_state_dict` in the Batch class so that we can + # efficiently broadcast tensors directly. This avoids using `broadcast_object` on the + # entire Batch object, which is inefficient for tensors because it serializes + # (pickles) them before sending. + + try: + data = next(self.iterator) # may raise StopIteration + except Exception as e: + data = e + data = broadcast_object(data, self._process_group, 0) + + if isinstance(data, Exception): + raise data + + return data diff --git a/fast_llm/data/data/data_loader_wrapper.py b/fast_llm/data/data/data_loader_wrapper.py deleted file mode 100644 index a44aa191d..000000000 --- a/fast_llm/data/data/data_loader_wrapper.py +++ /dev/null @@ -1,48 +0,0 @@ -import torch.utils.data.dataloader - -from fast_llm.core.distributed import broadcast_object - - -class DistributedDataLoaderWrapper: - """ - Wraps a regular dataloader so that only the process group leader - loads data, and then broadcasts the batch to other ranks in the group. - """ - - def __init__( - self, - data_loader: torch.utils.data.dataloader.DataLoader, - process_group: torch.distributed.ProcessGroup | None, - ): - self._data_loader = data_loader - self._rank = 0 if process_group is None else process_group.rank() - self.process_group = process_group - - def __iter__(self): - if self._rank == 0: - self.iterator = iter(self._data_loader) - if self.process_group is None: - return self.iterator - return self - - def __next__(self): - # TODO: - # Instead of broadcasting a general object, make this iterator yield an actual Batch class. - # Implement `get_state_dict` and `from_state_dict` in the Batch class so that we can - # efficiently broadcast tensors directly. This avoids using `broadcast_object` on the - # entire Batch object, which is inefficient for tensors because it serializes - # (pickles) them before sending. 
- - if self._rank == 0: - try: - data = next(self.iterator) # may raise StopIteration - except Exception as e: - data = e - data = broadcast_object(data, self.process_group, 0) - else: - data = broadcast_object(None, self.process_group, 0) - - if isinstance(data, Exception): - raise data - - return data diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 9e1574437..3af86652a 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -8,13 +8,12 @@ from fast_llm.core.distributed import safe_barrier from fast_llm.data.data.abstract import Data -from fast_llm.data.data.data_loader_wrapper import DistributedDataLoaderWrapper +from fast_llm.data.data.data_loader import DistributedDataLoaderWrapper, SampledDatasetIterator from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.config import SamplingParameters from fast_llm.data.dataset.gpt.config import GPTSamplingData from fast_llm.data.dataset.monitor import DatasetMonitor -from fast_llm.data.iterator import SampledDatasetIterator from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.sample.language_model import LanguageModelBatch from fast_llm.engine.config_utils.run import log_main_rank @@ -133,8 +132,7 @@ def get_iterator( multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None, ) - if False: - # TODO: ====== do ====== + if self._datasets[dataset_name].requires_broadcast: data_loader = DistributedDataLoaderWrapper(data_loader, self.distributed.model_and_sequence_data_group) return iter(data_loader) diff --git a/fast_llm/data/dataset/abstract.py b/fast_llm/data/dataset/abstract.py index 2efdf3841..1df24e92b 100644 --- a/fast_llm/data/dataset/abstract.py +++ b/fast_llm/data/dataset/abstract.py @@ -5,6 +5,7 @@ if typing.TYPE_CHECKING: from fast_llm.data.dataset.config import SamplingData + from fast_llm.data.dataset.sampled import SampledIterableDataset class Dataset[SampleType: Sample](abc.ABC): @@ -27,6 +28,14 @@ def __getstate__(self): del state["__orig_class__"] return state + @property + def requires_broadcast(self) -> bool: + """ + Some dataset schemes load the dataset on a batch-data-parallel group leaders, + then broadcast to the other devices. 
+ """ + return False + class SampledDataset[SampleType: Sample](Dataset[SampleType]): """ @@ -44,6 +53,18 @@ def __len__(self) -> int: class SamplableDataset[SampleType: Sample](Dataset[SampleType]): + @abc.abstractmethod def sample(self, config: "SamplingData") -> SampledDataset[SampleType]: pass + + +class SamplableIterableDataset[SampleType: Sample](SamplableDataset[SampleType]): + @abc.abstractmethod + def __iter__(self) -> typing.Iterator[SampleType]: + pass + + def sample(self, config: "SamplingData") -> "SampledIterableDataset[SampleType]": + from fast_llm.data.dataset.sampled import SampledIterableDataset + + return SampledIterableDataset(self, config) diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 1da79e214..d9ba5e7d2 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -11,7 +11,6 @@ from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset from fast_llm.data.preprocessing.abstract import PreprocessingConfig from fast_llm.data.sample.abstract import Sample -from fast_llm.redis.config import RedisConfig from fast_llm.utils import Assert, normalize_probabilities if typing.TYPE_CHECKING: @@ -108,7 +107,7 @@ class DatasetConfig[SampleType: Sample](Config): @config_class(registry=True) class SampledDatasetConfig[SampleType: Sample](DatasetConfig[SampleType]): """ - A sampled dataset containing a prepared list or iterable of samples to be indexed sequentially (as-is) during training. + A sampled dataset containing a prepared list of samples to be indexed sequentially (as-is) during training. """ def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]: @@ -302,19 +301,30 @@ def build(self, preprocessing: PreprocessingConfig) -> "IndexedDataset[SampleTyp raise FileNotFoundError(self.path) +@config_class() +class RedisConfig(Config): + # TODO: Move elsewhere? (Also used in trainer) Get it from the trainer in sampling config? + host: str = Field( + default="localhost", + desc="Hostname or IP address of the Redis server.", + hint=FieldHint.core, + ) + + port: int = Field( + default=6379, + desc="Port number on which the Redis server is running.", + hint=FieldHint.core, + ) + + @config_class(dynamic_type={SampledDatasetConfig: "streaming"}) -class StreamingDatasetConfig[SampleType: LanguageModelSample](SamplableDatasetConfig[SampleType]): +class StreamingDatasetConfig[SampleType: LanguageModelSample](RedisConfig, SamplableDatasetConfig[SampleType]): """ Configuration for a streaming dataset that reads training data from a Redis stream. 
""" _abstract = False - redis: RedisConfig = Field( - desc="Redis connection and stream settings used to fetch incoming training data.", - hint=FieldHint.core, - ) - acknowledge_interval: int = Field( default=10, desc="Number of messages after which the consumer acknowledges received IDs back to the Redis hash.", @@ -324,4 +334,6 @@ class StreamingDatasetConfig[SampleType: LanguageModelSample](SamplableDatasetCo def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]: from fast_llm.data.dataset.streaming import RedisStreamingDataset - return RedisStreamingDataset[SampleType](self, sampling.distributed.config).sample(sampling) + return RedisStreamingDataset[StreamingDatasetConfig, SampleType](self, sampling.distributed.config).sample( + sampling + ) diff --git a/fast_llm/data/dataset/sampled.py b/fast_llm/data/dataset/sampled.py index 8ce780558..8cf7d938a 100644 --- a/fast_llm/data/dataset/sampled.py +++ b/fast_llm/data/dataset/sampled.py @@ -8,7 +8,7 @@ import torch import yaml -from fast_llm.data.dataset.abstract import SampledDataset +from fast_llm.data.dataset.abstract import SamplableIterableDataset, SampledDataset from fast_llm.data.dataset.config import SamplingData, ShufflingType from fast_llm.data.dataset.indexed import IndexedDataset from fast_llm.data.sample.abstract import Sample @@ -111,6 +111,10 @@ def __init__( # No barrier yet to allow running in parallel. # There needs to be one before calling `__getitem__`, normally handled through `Data`. + @property + def requires_broadcast(self) -> bool: + return self._indexed_dataset.requires_broadcast + def _sample(self) -> None: """ Create a `SampledDataset` with the requested parameters. @@ -434,17 +438,27 @@ def _load_yaml_data(self, data: dict[str, typing.Any]) -> None: class SampledIterableDataset[SampleType: Sample](SampledDataset[SampleType]): def __init__( self, - iterable_dataset: typing.Iterator[SampleType], + dataset: SamplableIterableDataset[SampleType], sampling: SamplingData, ): - self._iterator = iterable_dataset + self._dataset = dataset self._config = sampling.config self._parameters = sampling.parameters self._documents: list[SampleType] = [] self._current_length = 0 self._sample_length = self._parameters.sequence_length + self._parameters.extra_tokens + # Delay iterator creation to avoid pickling issues. 
+ self._iterator: typing.Iterator[SampleType] | None = None + + @property + def requires_broadcast(self) -> bool: + # TODO: ====== fix ====== + # return self._iterator.requires_broadcast + return True def __getitem__(self, index: int) -> SampleType: + if self._iterator is None: + self._iterator = iter(self._dataset) while self._current_length < self._sample_length: document = next(self._iterator) if len(document) > self._sample_length: @@ -476,4 +490,4 @@ def __len__(self) -> int: @property def name(self) -> str: - return self._iterator.name + return self._dataset.name diff --git a/fast_llm/data/dataset/streaming.py b/fast_llm/data/dataset/streaming.py index 0898125ad..a1f0f32f9 100644 --- a/fast_llm/data/dataset/streaming.py +++ b/fast_llm/data/dataset/streaming.py @@ -4,9 +4,9 @@ import redis import torch.utils.data -from fast_llm.data.dataset.abstract import SamplableDataset -from fast_llm.data.dataset.config import SamplingData, StreamingDatasetConfig -from fast_llm.data.dataset.sampled import SampledIterableDataset +from fast_llm.config import Configurable +from fast_llm.data.dataset.abstract import SamplableIterableDataset +from fast_llm.data.dataset.config import StreamingDatasetConfig from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.data.sample.range import RangeSample from fast_llm.data.sample.token import TokenSample @@ -24,9 +24,11 @@ def dtype_from_string(name: str) -> torch.dtype: REDIS_GROUP_NAME = "fast_llm_group" -class RedisStreamingDataset[SampleType: LanguageModelSample](SamplableDataset[SampleType]): - def __init__(self, config: StreamingDatasetConfig, distributed_config: DistributedConfig): - super().__init__() +class RedisStreamingDataset[ConfigType: StreamingDatasetConfig, SampleType: LanguageModelSample]( + Configurable[ConfigType], SamplableIterableDataset[SampleType] +): + def __init__(self, config: ConfigType, distributed_config: DistributedConfig): + super().__init__(config) if distributed_config.pipeline_parallel > 1: # NOTE: It is not yet clear whether the issue comes from the streaming dataset # itself or from the distributed data-loader wrappers, but currently it @@ -34,49 +36,26 @@ def __init__(self, config: StreamingDatasetConfig, distributed_config: Distribut # the training step. raise NotImplementedError("Streaming dataset support is not implemented for pipeline-parallel training.") - self._name = f"redis[{config.redis.host}:{config.redis.port}]({REDIS_DATA_KEY}|{REDIS_GROUP_NAME})[data]" + self._name = f"redis[{config.host}:{config.port}]({REDIS_DATA_KEY}|{REDIS_GROUP_NAME})[data]" self._config = config - self.batch_data_rank = distributed_config.batch_data_rank - self.batch_data_parallel = distributed_config.batch_data_parallel + self._rank = distributed_config.batch_data_rank self.is_batch_data_group_leader = ( distributed_config.get_distributed_dim(DistributedDimNames.model_and_sequence_data).rank == 0 ) - if distributed_config.rank == 0: - redis_client = redis.Redis(host=self._config.redis.host, port=self._config.redis.port) - # TODO: Not needed? - redis_client.hset(f"{REDIS_DATA_KEY}:consumer_count", "0", str(self.batch_data_parallel)) + # TODO: Not needed? 
+ # if distributed_config.rank == 0: + # redis_client = redis.Redis(host=self._config.host, port=self._config.port) + # redis_client.hset(f"{REDIS_DATA_KEY}:consumer_count", "0", str(distributed_config.batch_data_parallel)) + + @property + def requires_broadcast(self) -> bool: + return True @property def name(self) -> str: return self._name - def sample(self, config: SamplingData) -> SampledIterableDataset[LanguageModelSample]: - return SampledIterableDataset(iter(self), config) - - # def __getstate__(self) -> tuple[str, StreamingDatasetConfig, int, int, bool, bytes, bytes]: - # return ( - # self._name, - # self._config, - # self.batch_data_parallel, - # self.batch_data_rank, - # self.is_batch_data_group_leader, - # self.payload_key_b, - # self.hash_key_b, - # ) - - # def __setstate__(self, state: tuple[str, StreamingDatasetConfig, int, bool, bytes, bytes]): - # name, config, batch_data_parallel, batch_data_rank, is_batch_data_group_leader, payload_key_b, hash_key_b = ( - # state - # ) - # self._name = name - # self._config = config - # self.batch_data_parallel = batch_data_parallel - # self.batch_data_rank = batch_data_rank - # self.is_batch_data_group_leader = is_batch_data_group_leader - # self.payload_key_b = payload_key_b - # self.hash_key_b = hash_key_b - def __iter__(self) -> typing.Iterator[LanguageModelSample]: worker_info = torch.utils.data.get_worker_info() if worker_info is not None and worker_info.num_workers > 1: @@ -85,7 +64,7 @@ def __iter__(self) -> typing.Iterator[LanguageModelSample]: if not self.is_batch_data_group_leader: raise RuntimeError("Must be only called on the batch data group leader") - client = redis.Redis(host=self._config.redis.host, port=self._config.redis.port) + client = redis.Redis(host=self._config.host, port=self._config.port) # Create the consumer group at the start of the stream ("0") # If the stream already exists, XGROUP CREATE will fail unless we add mkstream=True @@ -105,7 +84,7 @@ def __iter__(self) -> typing.Iterator[LanguageModelSample]: # BLOCK: wait for new messages (milliseconds) messages = client.xreadgroup( groupname=REDIS_GROUP_NAME, - consumername=f"fast_llm_consumer_{self.batch_data_rank}", + consumername=f"fast_llm_consumer_{self._rank}", # ">" reads only new messages that have not been delivered to any consumer streams={REDIS_DATA_KEY: ">"}, count=1, @@ -121,7 +100,7 @@ def __iter__(self) -> typing.Iterator[LanguageModelSample]: processed += 1 # TODO: or do it after processing all received messaged then count > 1? if processed % self._config.acknowledge_interval == 0: - client.hset(f"{REDIS_DATA_KEY}:ack", str(self.batch_data_rank), msg_id) + client.hset(f"{REDIS_DATA_KEY}:ack", str(self._rank), msg_id) yield self._read_document(json.loads(msg_data[b"data"])) diff --git a/fast_llm/data/iterator.py b/fast_llm/data/iterator.py deleted file mode 100644 index a407c0258..000000000 --- a/fast_llm/data/iterator.py +++ /dev/null @@ -1,25 +0,0 @@ -import typing - -import torch.utils.data - - -class SampledDatasetIterator(torch.utils.data.Sampler): - """ - A distributed sampler generating indices for a `SampledDataset` (i.e., the natural numbers). - To be used as the `batch_sampler` of a `torch.utils.data.DataLoader`. 
- """ - - def __init__(self, total_samples, begin_index, micro_batch_size, data_rank, data_parallel): - super().__init__() - self._total_samples = total_samples - self._begin_index = begin_index - self._batch_size = micro_batch_size * data_parallel - self._start_idx = data_rank * micro_batch_size - self._end_idx = (data_rank + 1) * micro_batch_size - - def __len__(self) -> int: - return self._total_samples - - def __iter__(self) -> typing.Iterator[list[int]]: - for idx in range(self._begin_index, self._total_samples - self._batch_size + 1, self._batch_size): - yield list(range(idx + self._start_idx, idx + self._end_idx)) diff --git a/fast_llm/engine/distributed/config.py b/fast_llm/engine/distributed/config.py index 532dfda25..48800db8f 100644 --- a/fast_llm/engine/distributed/config.py +++ b/fast_llm/engine/distributed/config.py @@ -101,7 +101,7 @@ def from_sizes_and_strides(cls, name: str, global_rank: int, *sizes_and_strides: if len(global_ranks) == 1 or ( isinstance(global_ranks, range) and stride == global_ranks.stop - global_ranks.start ): - global_ranks = range(start, start + size * stride, sizes_and_strides[0][0]) + global_ranks = range(start, start + size * stride, sizes_and_strides[0][1]) else: global_ranks = (rank0 + rank1 for rank1 in range(0, size, stride) for rank0 in global_ranks) global_ranks = global_ranks if isinstance(global_ranks, range) else list(global_ranks) @@ -349,7 +349,7 @@ def _validate(self) -> None: DistributedDimNames.model_and_sequence_data, (self.tensor_parallel, tensor_stride), (self.sequence_data_parallel, sequence_data_stride), - (self.pipeline_rank, pipeline_stride), + (self.pipeline_parallel, pipeline_stride), ) super()._validate() @@ -362,6 +362,7 @@ def _add_distributed_dim_from_sizes_and_strides(self, name: str, *sizes_and_stri self._add_distributed_dim(DistributedDim.from_sizes_and_strides(name, self.rank, *sizes_and_strides)) def _add_distributed_dim(self, distributed_dim: DistributedDim) -> None: + log("AAAAAA", distributed_dim, distributed_dim.global_ranks, distributed_dim.rank, self.rank, self.world_size) Assert.eq(distributed_dim.global_ranks[distributed_dim.rank], self.rank, msg=distributed_dim) try: diff --git a/fast_llm/engine/distributed/distributed.py b/fast_llm/engine/distributed/distributed.py index 7b95cecfb..19b50771b 100644 --- a/fast_llm/engine/distributed/distributed.py +++ b/fast_llm/engine/distributed/distributed.py @@ -171,16 +171,10 @@ def __init__(self, config: DistributedConfig, use_cpu: bool = False): self.tensor_group = self.add_group(self._config.distributed_dims[DistributedDimNames.tensor]) self.sequence_data_group = self.add_group(self._config.distributed_dims[DistributedDimNames.sequence_data]) self.batch_data_group = self.add_group(self._config.distributed_dims[DistributedDimNames.batch_data]) - - # Global ranks wrong with pipeline first, so we hide the dims as a safety check. 
- if not self._config.pipeline_first: - self.tensor_and_sequence_data_group = self.add_group( - self._config.distributed_dims[DistributedDimNames.tensor_and_sequence_data] - ) - self.tensor_and_data_group = self.add_group( - self._config.distributed_dims[DistributedDimNames.tensor_and_data] - ) - + self.tensor_and_sequence_data_group = self.add_group( + self._config.distributed_dims[DistributedDimNames.tensor_and_sequence_data] + ) + self.tensor_and_data_group = self.add_group(self._config.distributed_dims[DistributedDimNames.tensor_and_data]) self.model_and_sequence_data_group = self.add_group( self._config.distributed_dims[DistributedDimNames.model_and_sequence_data] ) diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 4795d80dc..02829c580 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -16,6 +16,7 @@ skip_valid_if_none, ) from fast_llm.data.data.config import DataConfig +from fast_llm.data.dataset.config import RedisConfig from fast_llm.engine.checkpoint.config import ( CheckpointLoadConfig, CheckpointSaveConfig, @@ -398,16 +399,11 @@ class TrainingFinishedEventConfig(TrainerEvent): @config_class() -class TrainerEventsConfig(Config): +class TrainerEventsConfig(RedisConfig): """ Aggregates all trainer-side Redis-based event configurations. """ - redis: TrainerEventsRedisConfig = Field( - desc="Redis connection and stream settings used to fetch incoming training data.", - hint=FieldHint.core, - ) - weights_broadcast: WeightsBroadcastEventConfig = Field( default=None, desc="Configuration for signaling weight-ready events via Redis.", diff --git a/fast_llm/engine/training/trainer_events.py b/fast_llm/engine/training/trainer_events.py index 93719615a..3bbd60df6 100644 --- a/fast_llm/engine/training/trainer_events.py +++ b/fast_llm/engine/training/trainer_events.py @@ -4,10 +4,10 @@ import redis import torch.distributed +from fast_llm.data.dataset.config import RedisConfig from fast_llm.engine.config_utils.run import is_main_rank from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.engine.training.config import TrainerEventsConfig, TrainingExportConfig -from fast_llm.redis.config import RedisConfig logger = logging.getLogger(__name__) diff --git a/fast_llm/redis/config.py b/fast_llm/redis/config.py index c36853787..e69de29bb 100644 --- a/fast_llm/redis/config.py +++ b/fast_llm/redis/config.py @@ -1,16 +0,0 @@ -from fast_llm.config import Config, Field, FieldHint, config_class - - -@config_class() -class RedisConfig(Config): - host: str = Field( - default="localhost", - desc="Hostname or IP address of the Redis server.", - hint=FieldHint.core, - ) - - port: int = Field( - default=6379, - desc="Port number on which the Redis server is running.", - hint=FieldHint.core, - ) diff --git a/tests/data/gptdata_streaming_test.py b/tests/data/gptdata_streaming_test.py index 461c9756d..859d90af0 100644 --- a/tests/data/gptdata_streaming_test.py +++ b/tests/data/gptdata_streaming_test.py @@ -23,7 +23,7 @@ def distributed_gptdata_streaming_test( redis_port, result_path, ): - stream_config = StreamingDatasetConfig.from_dict({"redis": {"port": redis_port}}) + stream_config = StreamingDatasetConfig(port=redis_port) distributed = Distributed( DistributedConfig( diff --git a/tests/data/test_streaming.py b/tests/data/test_streaming.py index 0bd2154ce..daf15c65e 100644 --- a/tests/data/test_streaming.py +++ b/tests/data/test_streaming.py @@ -122,7 +122,7 @@ def run_distributed_gptdata_streaming_test( 
pipeline_parallel = variant["pipeline_parallel"] sequence_data_parallel = variant["sequence_data_parallel"] total_gpus = variant["total_gpus"] - redis_port = stream_config.redis.port + redis_port = stream_config.port result_path = result_path / "distributed_gptdata_streaming_test" / request.node.name @@ -223,8 +223,7 @@ def test_streaming_dataset_reads_single_message(monkeypatched_redis, stream_conf """StreamingDataset should read a message and convert it into LanguageModelSample.""" fake_redis = monkeypatched_redis - distributed = Distributed(DistributedConfig(), use_cpu=True) - dataset = RedisStreamingDataset(stream_config, distributed) + dataset = RedisStreamingDataset(stream_config, DistributedConfig()) # Insert a message push_msg(fake_redis, [1, 2, 3]) @@ -244,8 +243,7 @@ def test_streaming_dataset_reads_multiple_messages(monkeypatched_redis, stream_c """StreamingDataset should read a message and convert it into LanguageModelSample.""" fake_redis = monkeypatched_redis - distributed = Distributed(DistributedConfig(), use_cpu=True) - dataset = RedisStreamingDataset(stream_config, distributed) + dataset = RedisStreamingDataset(stream_config, DistributedConfig()) # Insert a message push_msg(fake_redis, [1, 2, 3]) @@ -271,7 +269,7 @@ def test_sampling_1_doc_exact_fit(monkeypatched_redis, stream_config): push_msg(fake_redis, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) distributed = Distributed(DistributedConfig(), use_cpu=True) - sampler = RedisStreamingDataset(stream_config, distributed).sample(make_sampling(10, 0, 1, distributed)) + sampler = RedisStreamingDataset(stream_config, distributed.config).sample(make_sampling(10, 0, 1, distributed)) out = next(iter(sampler)) @@ -289,7 +287,7 @@ def test_sampling_2_docs_exact_fit(monkeypatched_redis, stream_config): push_msg(fake_redis, [5, 6, 7, 8, 9, 10]) distributed = Distributed(DistributedConfig(), use_cpu=True) - sampler = RedisStreamingDataset(stream_config, distributed).sample(make_sampling(10, 0, 1, distributed)) + sampler = RedisStreamingDataset(stream_config, distributed.config).sample(make_sampling(10, 0, 1, distributed)) out = next(iter(sampler)) @@ -306,7 +304,7 @@ def test_sampling_skips_too_long_doc_and_padding_final(monkeypatched_redis, stre push_msg(fake_redis, list(range(10))) # usable distributed = Distributed(DistributedConfig(), use_cpu=True) - sampler = RedisStreamingDataset(stream_config, distributed).sample(make_sampling(10, 0, 1, distributed)) + sampler = RedisStreamingDataset(stream_config, distributed.config).sample(make_sampling(10, 0, 1, distributed)) out = next(iter(sampler)) @@ -323,7 +321,7 @@ def test_sampling_overflow_creates_two(monkeypatched_redis, stream_config): push_msg(fake_redis, list(range(10))) distributed = Distributed(DistributedConfig(), use_cpu=True) - sampler = RedisStreamingDataset(stream_config, distributed).sample(make_sampling(10, 0, 2, distributed)) + sampler = RedisStreamingDataset(stream_config, distributed.config).sample(make_sampling(10, 0, 2, distributed)) sampler_iter = iter(sampler) out = [next(sampler_iter)] diff --git a/tests/trainer/test_events.py b/tests/trainer/test_events.py index 14e559c31..b14c4927d 100644 --- a/tests/trainer/test_events.py +++ b/tests/trainer/test_events.py @@ -356,8 +356,8 @@ def test_trainer_events_with_streaming(fake_redis_server, variant, run_distribut # so fake consumers can read them as well from this dict config model_config["events"] = { "redis": { - "host": stream_config.redis.host, - "port": stream_config.redis.port, + "host": stream_config.host, + "port": 
stream_config.port, "stream_key": "fast_llm_events", "payload_key": "event", }, diff --git a/tests/utils/redis.py b/tests/utils/redis.py index e34cb8173..cc820b84d 100644 --- a/tests/utils/redis.py +++ b/tests/utils/redis.py @@ -147,13 +147,13 @@ def make_sampling(sequence_length, extra_tokens, num_samples, distributed): @pytest.fixture def stream_config(): # TODO: ======= Not safe for parallel tests? ======= - return StreamingDatasetConfig.from_dict({"redis": {"port": find_free_port()}}) + return StreamingDatasetConfig(port=find_free_port()) @pytest.fixture def fake_redis_server(stream_config): # We search for free port as port from previous test can still be not free even after server shutdown - server_address = (stream_config.redis.host, stream_config.redis.port) + server_address = (stream_config.host, stream_config.port) # ----- Monkey-patch handler to suppress broken pipes ----- orig_handle = fakeredis._tcp_server.TCPFakeRequestHandler.handle From b525407733307fdf7ad7cbbcbfdd125b4fe21338 Mon Sep 17 00:00:00 2001 From: RaymondLi0 Date: Fri, 19 Dec 2025 14:50:03 -0500 Subject: [PATCH 03/12] Activation distillation: metrics and padding (#423) Co-authored-by: Torsten Scholak --- fast_llm/layers/block/config.py | 1 + fast_llm/layers/decoder/block.py | 68 +++++++++++++++++---- fast_llm/layers/decoder/stochastic_mixer.py | 22 +++++++ fast_llm/models/gpt/model.py | 25 ++++++++ 4 files changed, 103 insertions(+), 13 deletions(-) diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 261d54025..b06f69ee5 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -45,6 +45,7 @@ class BlockKwargs: device = "device" hidden_states = "hidden_states" output_hidden_states = "output_hidden_states" + activation_mask = "activation_mask" @config_class(registry=True) diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index 148dabd5c..0e3d6f0c0 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -136,9 +136,9 @@ def forward( fw_input = input_ hidden_states = self.norm_1(input_) self._debug(hidden_states, "norm_1", kwargs.get(BlockKwargs.hidden_dims), kwargs) - hidden_states, bias = self.mixer(hidden_states, kwargs) + hidden_states, bias = self.mixer(hidden_states, kwargs, metrics=metrics) - hidden_states, bias = self.activation_distillation_loss(hidden_states, bias, kwargs, losses) + hidden_states, bias = self.activation_distillation_loss(hidden_states, bias, kwargs, losses, metrics) with set_generator(generator): input_ = self._bias_dropout_add(hidden_states, bias, input_) @@ -154,7 +154,7 @@ def forward( hidden_states = torch.stack((fw_input, hidden_states), dim=0) return hidden_states - def activation_distillation_loss(self, hidden_states, bias, kwargs, losses): + def activation_distillation_loss(self, hidden_states, bias, kwargs, losses, metrics): """ Maybe apply activation distillation loss and setup backward hooks. """ @@ -178,26 +178,68 @@ def activation_distillation_loss(self, hidden_states, bias, kwargs, losses): # L2 loss activation_loss_factor = self._config.activation_distillation_factor # (batch, sequence, hidden) or (sequence, batch, hidden). Take the norm over hidden dim. - # TODO: handle possible padding? 
- local_loss_sum = torch.sum(torch.norm(mixer_output - teacher_tensor, p=2, dim=(2))) - # mixer_output.shape is (batch, sequence, hidden) or (sequence, batch, hidden) - # In either case, dims 0 and 1 are batch and sequence - total_count = mixer_output.shape[0] * mixer_output.shape[1] + + # Handle possible padding by using pre-computed activation mask + sequence_first = kwargs.get(BlockKwargs.sequence_first, False) + activation_mask = kwargs.get(BlockKwargs.activation_mask) + + if activation_mask is not None: + # Use pre-computed activation mask (bool tensor where True = valid token) + mask = activation_mask.to(dtype=mixer_output.dtype) + if sequence_first: + # (batch, sequence) -> (sequence, batch) + mask = mask.T + + # Compute masked L2 loss: norm over hidden dim, then apply mask + per_token_loss = torch.norm( + mixer_output - teacher_tensor, p=2, dim=-1 + ) # (batch, sequence) or (sequence, batch) + masked_loss = per_token_loss * mask + local_loss_sum = torch.sum(masked_loss) + total_count = int(mask.sum().item()) + else: + # No activation_mask available, compute loss on all tokens + per_token_loss = torch.norm( + mixer_output - teacher_tensor, p=2, dim=-1 + ) # (batch, sequence) or (sequence, batch) + local_loss_sum = torch.sum(per_token_loss) + # mixer_output.shape is (batch, sequence, hidden) or (sequence, batch, hidden) + # In either case, dims 0 and 1 are batch and sequence + total_count = mixer_output.shape[0] * mixer_output.shape[1] # All-reduce across tensor-parallel group if sequence-parallel is enabled if self._sequence_parallel and self._distributed.tensor_group is not None: all_reduce(local_loss_sum, group=self._distributed.tensor_group, op=ReduceOp.SUM) - # Assume all ranks contribute the same count (not the case if padding) - total_count *= self._distributed.tensor_group.size() + if activation_mask is not None: + # Different ranks may have different amounts of padding + total_count_tensor = torch.tensor(total_count, device=mixer_output.device, dtype=torch.int64) + all_reduce(total_count_tensor, group=self._distributed.tensor_group, op=ReduceOp.SUM) + total_count = int(total_count_tensor.item()) + else: + # All ranks contribute the same count + total_count *= self._distributed.tensor_group.size() - activation_loss = activation_loss_factor * (local_loss_sum / total_count) + activation_loss = local_loss_sum / total_count + scaled_activation_loss = activation_loss_factor * activation_loss # Backward hooks - hidden_states = AuxiliaryLoss.apply(hidden_states, activation_loss, 1.0) - bias = AuxiliaryLoss.apply(bias, activation_loss, 1.0) if bias is not None else None + hidden_states = AuxiliaryLoss.apply(hidden_states, scaled_activation_loss, 1.0) + bias = AuxiliaryLoss.apply(bias, scaled_activation_loss, 1.0) if bias is not None else None # Logging if losses is not None and self._activation_distillation_loss_name in losses: losses[self._activation_distillation_loss_name].append(activation_loss.detach()) + # Per-layer metrics + if metrics is not None: + metrics[f"{self.module_name}/activation_distillation_loss"] = activation_loss.detach() + + # If using stochastic mixer, also log per-mixer-type activation distillation loss + from fast_llm.layers.decoder.stochastic_mixer import StochasticMixer + + if isinstance(self.mixer, StochasticMixer): + selected_mixer = self.mixer._last_selected_mixer + metrics[f"{self.module_name}/activation_distillation_loss/{selected_mixer}"] = ( + activation_loss.detach() + ) return hidden_states, bias def get_compute_usage(self, input_: TensorMeta, 
kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: diff --git a/fast_llm/layers/decoder/stochastic_mixer.py b/fast_llm/layers/decoder/stochastic_mixer.py index 673c64034..984f34b80 100644 --- a/fast_llm/layers/decoder/stochastic_mixer.py +++ b/fast_llm/layers/decoder/stochastic_mixer.py @@ -94,6 +94,10 @@ def __init__( if hasattr(param, "allow_no_grad"): param.allow_no_grad = True + # Track mixer selection counts for logging actual proportions during training + self._mixer_counts_total = {name: 0 for name in self.mixers.keys()} + self._last_selected_mixer = None + def setup(self, distributed: Distributed) -> None: """Setup all mixers with the distributed context.""" super().setup(distributed) @@ -117,6 +121,24 @@ def _forward( ) -> tuple[torch.Tensor, torch.Tensor | None]: mixer_name = self._sample_mixer_name(kwargs) + if self.training: + self._mixer_counts_total[mixer_name] += 1 + self._last_selected_mixer = mixer_name + + if metrics is not None: + # Use module_name as prefix to distinguish between different layer indices + metric_prefix = f"{self.module_name}/stochastic" + + # Instantaneous metric: last selected mixer + metrics[f"{metric_prefix}/last_selected"] = mixer_name + + # Cumulative metrics: total counts and proportions + total_count = sum(self._mixer_counts_total.values()) + for name, count in self._mixer_counts_total.items(): + proportion = count / total_count if total_count > 0 else 0.0 + metrics[f"{metric_prefix}/{name}_count_total"] = count + metrics[f"{metric_prefix}/{name}_proportion_total"] = proportion + if get_model_debug_level() > 0: from fast_llm.layers.block.config import BlockKwargs diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 64e7f1cbd..2f43d1e41 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -217,12 +217,37 @@ def preprocess_batch( pasts = presents presents = None if i == len(preprocessed_meta) - 1 else [] + # Create activation mask for activation distillation + # This mask should: + # - Be 0 on padding tokens (added at the end when documents aren't truncated) + # - Be 1 on image placeholder tokens (token value -100 but not padding) + # - Be 1 on all other valid tokens (ignores loss-masking-spans) + # + # Note: Padding is added as a separate document with all tokens = -100 + # We detect padding by checking if all tokens in a document segment are -100 + activation_mask = torch.ones_like(cropped_tokens.tokens, dtype=torch.bool) + + for sample_index, sample_lengths in enumerate(cropped_tokens.lengths): + # Iterate through documents in this sample + pos = 0 + for doc_length in sample_lengths: + # Check if this document is padding (all tokens are -100) + doc_tokens = cropped_tokens.tokens[sample_index, pos : pos + doc_length] + is_padding_doc = torch.all(doc_tokens == -100).item() + + if is_padding_doc: + # This is a padding document, mask it out + activation_mask[sample_index, pos : pos + doc_length] = False + + pos += doc_length + kwargs: dict[str, typing.Any] = { **kwargs_meta, AttentionKwargs.past_key_values: pasts, AttentionKwargs.presents: presents, BlockKwargs.iteration: iteration, AttentionKwargs.sequence_lengths: cropped_tokens.lengths, + BlockKwargs.activation_mask: activation_mask, AttentionKwargs.device: self._distributed.device, BlockKwargs.hidden_states: {}, **reference_logits[i], From 4144317bdeeecb394889093b937e7e7d1468d5b4 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 22 Dec 2025 12:50:31 -0500 Subject: [PATCH 04/12] misc --- 
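Note on the `DistributedDim.from_sizes_and_strides` change below: the rewritten branches
enumerate the group's global ranks as the Cartesian product of per-dimension offsets
`{0, stride, ..., (size - 1) * stride}` (with the `range` branches as shortcuts for the
contiguous case). A minimal standalone sketch of that enumeration, for illustration only
(hypothetical helper, not part of the patch; the result is sorted for readability, while
the real code keeps its own nesting order and starting rank):

    def global_ranks_from_sizes_and_strides(*sizes_and_strides: tuple[int, int]) -> list[int]:
        # Cartesian product of the per-dimension offsets {0, stride, ..., (size - 1) * stride}.
        ranks = [0]
        for size, stride in sizes_and_strides:
            ranks = [rank + offset for rank in ranks for offset in range(0, size * stride, stride)]
        return sorted(ranks)

    # Example (assumed layout): tensor_parallel=2 (stride 1) combined with pipeline_parallel=2
    # (stride 4) on 8 GPUs gives the rank-0 group {0, 1, 4, 5}.
    assert global_ranks_from_sizes_and_strides((2, 1), (2, 4)) == [0, 1, 4, 5]
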
fast_llm/data/data/data_loader.py | 2 +- fast_llm/data/dataset/config.py | 5 + fast_llm/data/dataset/monitor.py | 8 + fast_llm/data/dataset/streaming.py | 12 +- fast_llm/engine/distributed/config.py | 25 +- fast_llm/engine/distributed/distributed.py | 3 +- fast_llm/engine/training/trainer_events.py | 4 +- setup.cfg | 2 +- tests/conftest.py | 5 +- tests/data/gptdata_streaming_test.py | 109 ---- tests/data/test_streaming.py | 523 +++++++------------- tests/models/distributed_test_checkpoint.py | 90 ---- tests/models/distributed_test_model.py | 54 -- tests/models/test_checkpoint.py | 66 ++- tests/models/test_model.py | 45 +- tests/trainer/events_fake_consumer.py | 4 +- tests/trainer/test_events.py | 7 +- tests/utils/redis.py | 137 ++--- tests/utils/run_test_script.py | 17 +- tests/utils/subtest.py | 254 ++++++++++ tests/utils/utils.py | 143 ------ 21 files changed, 607 insertions(+), 908 deletions(-) delete mode 100644 tests/data/gptdata_streaming_test.py delete mode 100644 tests/models/distributed_test_checkpoint.py delete mode 100644 tests/models/distributed_test_model.py create mode 100644 tests/utils/subtest.py diff --git a/fast_llm/data/data/data_loader.py b/fast_llm/data/data/data_loader.py index ad8ad3cf6..ba7e5e612 100644 --- a/fast_llm/data/data/data_loader.py +++ b/fast_llm/data/data/data_loader.py @@ -61,7 +61,7 @@ def __next__(self): # (pickles) them before sending. try: - data = next(self.iterator) # may raise StopIteration + data = next(self._iterator) # may raise StopIteration except Exception as e: data = e data = broadcast_object(data, self._process_group, 0) diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index d9ba5e7d2..003b1dfb0 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -316,6 +316,11 @@ class RedisConfig(Config): hint=FieldHint.core, ) + def get_client(self): + import redis + + return redis.Redis(self.host, self.port) + @config_class(dynamic_type={SampledDatasetConfig: "streaming"}) class StreamingDatasetConfig[SampleType: LanguageModelSample](RedisConfig, SamplableDatasetConfig[SampleType]): diff --git a/fast_llm/data/dataset/monitor.py b/fast_llm/data/dataset/monitor.py index 01f3195e4..27070f674 100644 --- a/fast_llm/data/dataset/monitor.py +++ b/fast_llm/data/dataset/monitor.py @@ -51,3 +51,11 @@ def __getitem__(self, index: int) -> SampleType: @property def name(self) -> str: return self._dataset.name + + @property + def requires_broadcast(self) -> bool: + """ + Some dataset schemes load the dataset on a batch-data-parallel group leaders, + then broadcast to the other devices. + """ + return self._dataset.requires_broadcast diff --git a/fast_llm/data/dataset/streaming.py b/fast_llm/data/dataset/streaming.py index a1f0f32f9..cff028f62 100644 --- a/fast_llm/data/dataset/streaming.py +++ b/fast_llm/data/dataset/streaming.py @@ -29,12 +29,12 @@ class RedisStreamingDataset[ConfigType: StreamingDatasetConfig, SampleType: Lang ): def __init__(self, config: ConfigType, distributed_config: DistributedConfig): super().__init__(config) - if distributed_config.pipeline_parallel > 1: - # NOTE: It is not yet clear whether the issue comes from the streaming dataset - # itself or from the distributed data-loader wrappers, but currently it - # interferes with pipeline-parallel training and causes a timeout during - # the training step. 
- raise NotImplementedError("Streaming dataset support is not implemented for pipeline-parallel training.") + # if distributed_config.pipeline_parallel > 1: + # NOTE: It is not yet clear whether the issue comes from the streaming dataset + # itself or from the distributed data-loader wrappers, but currently it + # interferes with pipeline-parallel training and causes a timeout during + # the training step. + # raise NotImplementedError("Streaming dataset support is not implemented for pipeline-parallel training.") self._name = f"redis[{config.host}:{config.port}]({REDIS_DATA_KEY}|{REDIS_GROUP_NAME})[data]" self._config = config diff --git a/fast_llm/engine/distributed/config.py b/fast_llm/engine/distributed/config.py index 48800db8f..624ac22d4 100644 --- a/fast_llm/engine/distributed/config.py +++ b/fast_llm/engine/distributed/config.py @@ -98,13 +98,13 @@ def from_sizes_and_strides(cls, name: str, global_rank: int, *sizes_and_strides: for size, stride in sizes_and_strides: if size == 1: continue - if len(global_ranks) == 1 or ( - isinstance(global_ranks, range) and stride == global_ranks.stop - global_ranks.start - ): - global_ranks = range(start, start + size * stride, sizes_and_strides[0][1]) + if len(global_ranks) == 1: + global_ranks = range(start, start + size * stride, stride) + elif isinstance(global_ranks, range) and stride == global_ranks.stop - global_ranks.start: + global_ranks = range(start, start + size * stride, global_ranks.step) else: - global_ranks = (rank0 + rank1 for rank1 in range(0, size, stride) for rank0 in global_ranks) - global_ranks = global_ranks if isinstance(global_ranks, range) else list(global_ranks) + global_ranks = [rank0 + rank1 for rank1 in range(0, size * stride, stride) for rank0 in global_ranks] + Assert.eq(len(global_ranks), world_size) return DistributedDim(name=name, size=world_size, rank=rank, global_ranks=global_ranks) @@ -348,8 +348,16 @@ def _validate(self) -> None: self._add_distributed_dim_from_sizes_and_strides( DistributedDimNames.model_and_sequence_data, (self.tensor_parallel, tensor_stride), - (self.sequence_data_parallel, sequence_data_stride), - (self.pipeline_parallel, pipeline_stride), + ( + (self.pipeline_parallel, pipeline_stride) + if self.pipeline_first + else (self.sequence_data_parallel, sequence_data_stride) + ), + ( + (self.sequence_data_parallel, sequence_data_stride) + if self.pipeline_first + else (self.pipeline_parallel, pipeline_stride) + ), ) super()._validate() @@ -362,7 +370,6 @@ def _add_distributed_dim_from_sizes_and_strides(self, name: str, *sizes_and_stri self._add_distributed_dim(DistributedDim.from_sizes_and_strides(name, self.rank, *sizes_and_strides)) def _add_distributed_dim(self, distributed_dim: DistributedDim) -> None: - log("AAAAAA", distributed_dim, distributed_dim.global_ranks, distributed_dim.rank, self.rank, self.world_size) Assert.eq(distributed_dim.global_ranks[distributed_dim.rank], self.rank, msg=distributed_dim) try: diff --git a/fast_llm/engine/distributed/distributed.py b/fast_llm/engine/distributed/distributed.py index 19b50771b..99871a850 100644 --- a/fast_llm/engine/distributed/distributed.py +++ b/fast_llm/engine/distributed/distributed.py @@ -27,6 +27,7 @@ def __init__( local_world_size: int | None = None, timeout: float = 60, use_cpu: bool = False, + init_method: str = "env://", ): self._rank = DistributedConfig.default_rank if rank is None else rank @@ -54,7 +55,7 @@ def __init__( # TODO: Allow other init methods? 
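        # Note: `init_method` is passed straight to `torch.distributed.rendezvous`, so it
        # follows the usual torch.distributed URL convention: the default "env://" reads
        # MASTER_ADDR/MASTER_PORT from the environment, while e.g. "tcp://127.0.0.1:<port>"
        # or "file:///tmp/rdzv" let a caller (presumably the new test helpers) rendezvous
        # without those environment variables.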
self.store, _, _ = next( torch.distributed.rendezvous( - "env://", + init_method, self._rank, self._world_size, timeout=datetime.timedelta(seconds=timeout), diff --git a/fast_llm/engine/training/trainer_events.py b/fast_llm/engine/training/trainer_events.py index 3bbd60df6..0937999fc 100644 --- a/fast_llm/engine/training/trainer_events.py +++ b/fast_llm/engine/training/trainer_events.py @@ -1,6 +1,6 @@ +import json import logging -import orjson import redis import torch.distributed @@ -34,7 +34,7 @@ def send(self, msg_type: str, payload: dict | None = None): payload = {} payload.update({"type": msg_type}) - self.client.xadd(REDIS_TRAINING_KEY, {"event": orjson.dumps(payload)}) + self.client.xadd(REDIS_TRAINING_KEY, {"event": json.dumps(payload)}) class TrainerEvents: diff --git a/setup.cfg b/setup.cfg index 495a9cf20..f4ad02c43 100644 --- a/setup.cfg +++ b/setup.cfg @@ -61,7 +61,7 @@ GENERATION = lm_eval>=0.4.9 STREAMING = - redis>=-7.1.0 + redis>=7.1.0 # Required for supporting vision inputs VISION = diff --git a/tests/conftest.py b/tests/conftest.py index df56c78ab..baef9d1de 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -27,14 +27,13 @@ from tests.utils.run_test_script import ( # isort: skip compare_results_for_all_models, run_distributed_script, - run_distributed_script_lean, run_test_script_base_path, run_test_script_for_all_models, ) -from tests.utils.redis import fake_redis_server, stream_config # isort: skip from tests.utils.model_configs import model_testing_config, ModelTestingConfig, testing_group_enabled # isort: skip -from tests.utils.utils import result_path, format_resource_report, report_subtest # isort: skip +from tests.utils.utils import result_path # isort: skip +from tests.utils.subtest import format_resource_report, report_subtest, run_parallel_script # isort: skip # Import all dynamic classes. 
import fast_llm.cli # isort: skip diff --git a/tests/data/gptdata_streaming_test.py b/tests/data/gptdata_streaming_test.py deleted file mode 100644 index 859d90af0..000000000 --- a/tests/data/gptdata_streaming_test.py +++ /dev/null @@ -1,109 +0,0 @@ -import argparse -import pathlib -import pickle - -from fast_llm.config import NoAutoValidate -from fast_llm.data.data.gpt.config import GPTDataConfig -from fast_llm.data.data.gpt.data import GPTData -from fast_llm.data.dataset.config import StreamingDatasetConfig -from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.models.gpt.config import GPTBatchConfig -from tests.utils.redis import make_sampling - - -def distributed_gptdata_streaming_test( - sequence_length, - micro_batch_size, - batch_size, - tensor_parallel, - pipeline_parallel, - sequence_data_parallel, - total_gpus, - redis_port, - result_path, -): - stream_config = StreamingDatasetConfig(port=redis_port) - - distributed = Distributed( - DistributedConfig( - tensor_parallel=tensor_parallel, - pipeline_parallel=pipeline_parallel, - sequence_data_parallel=sequence_data_parallel, - ), - use_cpu=total_gpus == 0, - ) - sampling_data = make_sampling(sequence_length, 0, micro_batch_size, distributed) - - data_config = {"datasets": {"streaming1": stream_config.to_dict()}, "sampling": {"shuffle": "disabled"}} - data_config = GPTDataConfig.from_dict(data_config) - - data = GPTData(data_config, distributed.config) - - data.setup( - distributed=distributed, - sampling_parameters={"streaming1": sampling_data.parameters}, - preprocessing={}, - cache_directory="/tmp", - ) - - with NoAutoValidate(): - batch_config = GPTBatchConfig( - micro_batch_size=micro_batch_size, batch_size=batch_size, sequence_length=sequence_length - ) - batch_config.setup(distributed_config=distributed.config) - batch_config.validate() - - data_iter = data.get_iterator(batch_config, "streaming1", consumed_samples=0, num_workers=1, prefetch_factor=1) - - batch = next(data_iter) - # TODO: save result per batch_data_group and rank - assert batch.tokens.tokens.shape == (micro_batch_size, sequence_length) - - result_path = ( - pathlib.Path(result_path) - / ( - f"{distributed.config.batch_data_rank}_" - f"{distributed.model_and_sequence_data_group.rank() if distributed.model_and_sequence_data_group is not None else 0}" - ) - / "batch.pkl" - ) - result_path.parent.mkdir(exist_ok=True, parents=True) - with result_path.open("wb") as f: - pickle.dump(batch, f) - - -def parse_args(): - parser = argparse.ArgumentParser(description="Run distributed GPT data streaming test.") - - parser.add_argument("--sequence-length", type=int, required=True, help="Sequence length of the model input.") - parser.add_argument("--micro-batch-size", type=int, required=True, help="Micro batch size.") - parser.add_argument("--batch-size", type=int, required=True, help="Global batch size.") - parser.add_argument("--tensor-parallel", type=int, required=True, help="Tensor parallel degree.") - parser.add_argument("--pipeline-parallel", type=int, required=True, help="Pipeline parallel degree.") - parser.add_argument("--sequence-data-parallel", type=int, required=True, help="Sequence data parallel degree.") - parser.add_argument("--total-gpus", type=int, required=True, help="Total number of GPUs available.") - parser.add_argument("--redis-port", type=int, required=True, help="Redis port to connect to.") - parser.add_argument("--result-path", type=str, required=True, help="Path 
to save test results.") - - return parser.parse_args() - - -def main(): - args = parse_args() - - distributed_gptdata_streaming_test( - sequence_length=args.sequence_length, - micro_batch_size=args.micro_batch_size, - batch_size=args.batch_size, - tensor_parallel=args.tensor_parallel, - pipeline_parallel=args.pipeline_parallel, - sequence_data_parallel=args.sequence_data_parallel, - total_gpus=args.total_gpus, - redis_port=args.redis_port, - result_path=args.result_path, - ) - - -if __name__ == "__main__": - main() diff --git a/tests/data/test_streaming.py b/tests/data/test_streaming.py index daf15c65e..e16583e8f 100644 --- a/tests/data/test_streaming.py +++ b/tests/data/test_streaming.py @@ -1,377 +1,236 @@ +import contextlib import logging -import os -import pickle +import pathlib +import typing import fakeredis import pytest +import redis import torch +from fast_llm.config import NoAutoValidate +from fast_llm.core.distributed import safe_barrier +from fast_llm.data.data.gpt.config import GPTDataConfig +from fast_llm.data.data.gpt.data import GPTData +from fast_llm.data.dataset.config import RedisConfig, SamplingParameters, StreamingDatasetConfig from fast_llm.data.dataset.streaming import RedisStreamingDataset +from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.sample.language_model import LanguageModelSample -from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed -from tests.utils.redis import make_sampling, push_msg, redis_batch_producer +from fast_llm.models.gpt.config import GPTBatchConfig +from fast_llm.utils import Assert +from tests.utils.redis import find_free_port, make_sampling, push_msg, redis_batch_producer +from tests.utils.subtest import DistributedTestContext from tests.utils.utils import requires_cuda logger = logging.getLogger(__name__) -# --------------------------------------------------------------------- -# Fixtures -# --------------------------------------------------------------------- - - -@pytest.fixture -def fake_redis(): - """Return a FakeRedis instance.""" - return fakeredis.FakeRedis() - - @pytest.fixture -def monkeypatched_redis(monkeypatch, fake_redis): - """Monkeypatch redis.Redis globally (works even for imports inside functions).""" - import redis - +def fake_redis(monkeypatch): + """Monkeypatch redis.Redis globally.""" + fake_redis = fakeredis.FakeRedis() monkeypatch.setattr(redis, "Redis", lambda *args, **kwargs: fake_redis) - return fake_redis - - -# --------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------- - - -def generate_parallelism_variants(total_gpus: int): - """ - Generate all valid variants of (data_groups, tensor_parallel, pipeline_parallel, sequence_parallel) - for a number of GPUs up to the total_gpus. - If total_gpus is odd and > 1, fallback to nearest lower even number for decomposable parallelism. 
- """ - if total_gpus > 1 and total_gpus % 2 == 1: - total_gpus = total_gpus - 1 - - if total_gpus < 2: - # No gpu and one gpu tests are the same, - # so no need of creation of variant for a single gpu - return [] + try: + yield fake_redis + finally: + fake_redis.close() - variants = [] - for gpus in range(2, total_gpus + 1, 2): - # try all possible numbers of data groups (1..total_gpus) - for data_groups in range(1, gpus + 1): - if gpus % data_groups != 0: - continue # cannot evenly split - - gpus_per_group = gpus // data_groups - - # now find all decompositions of gpus_per_group into tp*pp*sp - for tp in range(1, gpus_per_group + 1): - if gpus_per_group % tp != 0: - continue - rem_after_tp = gpus_per_group // tp - # TODO: currently streaming dataset does not support pipeline parallel setup - # for pp in range(1, rem_after_tp + 1): - for pp in range(1, 2): - if rem_after_tp % pp != 0: - continue - sp = rem_after_tp // pp - try: - # instead of repeating all safeguards here just try to - # instantiate distributed config to check if combination is valid - dist_config = DistributedConfig( - tensor_parallel=tp, - pipeline_parallel=pp, - sequence_data_parallel=sp, - world_size=gpus, - # TODO: works only on one node - local_world_size=gpus, - rank=0, - ) - except Exception: - continue - - variants.append( - { - "data_groups": data_groups, - "batch_data_parallel": dist_config.batch_data_parallel, - "tensor_parallel": tp, - "pipeline_parallel": pp, - "sequence_data_parallel": sp, - "total_gpus": gpus, - } - ) - return variants - - -def run_distributed_gptdata_streaming_test( - fake_redis_server, - variant, - run_distributed_script, - result_path, - request, -): - import tests.data.gptdata_streaming_test - - stream_config, fake_redis, fake_redis_server_killer = fake_redis_server - - sequence_length = 10 - micro_batch_size = 2 - batch_size = micro_batch_size * variant["batch_data_parallel"] - tensor_parallel = variant["tensor_parallel"] - pipeline_parallel = variant["pipeline_parallel"] - sequence_data_parallel = variant["sequence_data_parallel"] - total_gpus = variant["total_gpus"] - redis_port = stream_config.port - - result_path = result_path / "distributed_gptdata_streaming_test" / request.node.name - - with redis_batch_producer( - redis_client=fake_redis, - fake_redis_server_killer=fake_redis_server_killer, - batch_size=batch_size, - sequence_length=10, - ): - if total_gpus > 0: - script = [ - "-m", - tests.data.gptdata_streaming_test.__name__, - "--sequence-length", - str(sequence_length), - "--micro-batch-size", - str(micro_batch_size), - "--batch-size", - str(batch_size), - "--tensor-parallel", - str(tensor_parallel), - "--pipeline-parallel", - str(pipeline_parallel), - "--sequence-data-parallel", - str(sequence_data_parallel), - "--total-gpus", - str(total_gpus), - "--result-path", - str(result_path), - "--redis-port", - str(redis_port), - ] - # TODO: distributed_capture is ignored now inside the script - if request.config.getoption("distributed_capture"): - logger.warning( - "Capturing output and forwarding to associated tests. Run with `--no-distributed-capture` to disable." 
- ) - else: - script.append("--no-distributed-capture") - - env = os.environ.copy() - env["PYTHONHASHSEED"] = "42" - run_distributed_script(script, num_gpus=total_gpus, env=env) - else: - tests.data.gptdata_streaming_test.distributed_gptdata_streaming_test( - sequence_length=sequence_length, - micro_batch_size=micro_batch_size, - batch_size=batch_size, - tensor_parallel=tensor_parallel, - pipeline_parallel=pipeline_parallel, - sequence_data_parallel=sequence_data_parallel, - total_gpus=total_gpus, - redis_port=redis_port, - result_path=result_path, - ) - - check_distributed_gptdata_streaming_test_results( - result_path=result_path, - micro_batch_size=micro_batch_size, - batch_data_parallel=variant["batch_data_parallel"], - total_gpu=variant["total_gpus"], - ) - - -def check_distributed_gptdata_streaming_test_results( - result_path, - micro_batch_size, - batch_data_parallel, - total_gpu, +@pytest.mark.parametrize( + "messages", + [ + (range(3),), + (range(3), range(3, 7)), + (range(3), range(5), [], [9, 4]), + ], +) +def test_streaming_dataset( + fake_redis: fakeredis.FakeRedis, + messages: tuple[list[int], ...], ): - batch_data_parallel_size = total_gpu // batch_data_parallel if total_gpu > 0 else 1 - sample_idx = set() - for i in range(batch_data_parallel): - ref_batch = None - for j in range(batch_data_parallel_size): - with (result_path / f"{i}_{j}" / "batch.pkl").open("rb") as f: - batch = pickle.load(f) - if ref_batch is None: - ref_batch = batch - else: - # batches for same batch_data_parallel_group must be equal on all ranks - assert torch.equal(batch.tokens.tokens, ref_batch.tokens.tokens) - for j in range(micro_batch_size): - val = ref_batch.tokens.tokens[j, 0].item() - # all samples in batches between groups and in the batch must be unique - assert val not in sample_idx - sample_idx.add(val) - # unique sample count must be the same as global batch size - assert len(sample_idx) == micro_batch_size * batch_data_parallel - - -# --------------------------------------------------------------------- -# Tests -# --------------------------------------------------------------------- - - -def test_streaming_dataset_reads_single_message(monkeypatched_redis, stream_config): - """StreamingDataset should read a message and convert it into LanguageModelSample.""" - fake_redis = monkeypatched_redis - - dataset = RedisStreamingDataset(stream_config, DistributedConfig()) - - # Insert a message - push_msg(fake_redis, [1, 2, 3]) - - it = iter(dataset) - sample = next(it) - - assert isinstance(sample, LanguageModelSample) - assert torch.equal(sample.tokens.tokens, torch.tensor([1, 2, 3], dtype=torch.int64)) - assert sample.tokens.lengths == [3] - assert sample.loss_masking_spans is None - assert sample.chosen_spans is None - assert sample.rejected_spans is None - - -def test_streaming_dataset_reads_multiple_messages(monkeypatched_redis, stream_config): """StreamingDataset should read a message and convert it into LanguageModelSample.""" - fake_redis = monkeypatched_redis - - dataset = RedisStreamingDataset(stream_config, DistributedConfig()) - - # Insert a message - push_msg(fake_redis, [1, 2, 3]) - push_msg(fake_redis, [1, 2, 3]) - push_msg(fake_redis, [1, 2, 3]) - - it = iter(dataset) - for i in range(3): - sample = next(it) - + stream_config = StreamingDatasetConfig(port=find_free_port()) + dataset_iterator = iter(RedisStreamingDataset(stream_config, DistributedConfig())) + for message in messages: + push_msg(fake_redis, list(message)) + for message in messages: + sample = next(dataset_iterator) 
assert isinstance(sample, LanguageModelSample) - assert torch.equal(sample.tokens.tokens, torch.tensor([1, 2, 3], dtype=torch.int64)) - assert sample.tokens.lengths == [3] + Assert.eq(sample.tokens.tokens.tolist(), list(message)) + Assert.eq(sample.tokens.lengths, [len(message)]) assert sample.loss_masking_spans is None assert sample.chosen_spans is None assert sample.rejected_spans is None -def test_sampling_1_doc_exact_fit(monkeypatched_redis, stream_config): - """Docs exactly fill one sample.""" - fake_redis = monkeypatched_redis - - push_msg(fake_redis, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) - - distributed = Distributed(DistributedConfig(), use_cpu=True) - sampler = RedisStreamingDataset(stream_config, distributed.config).sample(make_sampling(10, 0, 1, distributed)) - - out = next(iter(sampler)) - - assert isinstance(out, LanguageModelSample) - assert len(out) == 10 - assert out.tokens.tokens.tolist() == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] - - -def test_sampling_2_docs_exact_fit(monkeypatched_redis, stream_config): - """Docs exactly fill one sample.""" - fake_redis = monkeypatched_redis - - # Two rollouts: lengths 4 and 6 -> exactly 10 - push_msg(fake_redis, [1, 2, 3, 4]) - push_msg(fake_redis, [5, 6, 7, 8, 9, 10]) - +@pytest.mark.parametrize( + ("messages", "expected_samples", "expected_lengths"), + [ + ((range(5),), (range(5),), ([5],)), # Single message, exact fit. + ((range(3), [3, 4]), (range(5),), ([3, 2],)), # Two messages, exact fit. + ((range(6), range(5)), (range(5),), ([5],)), # Two messages, one dropped. + ( + (range(3), range(5)), + ( + [0, 1, 2, -100, -100], + range(5), + ), + ( + [3, 2], + [5], + ), + ), # Two messages, one padded. + ], +) +def test_streaming_sampled_dataset( + fake_redis: fakeredis.FakeRedis, + messages: tuple[list[int], ...], + expected_samples: tuple[list[int], ...], + expected_lengths: tuple[int, ...], +): + """StreamingDataset should read a message and convert it into LanguageModelSample.""" + stream_config = StreamingDatasetConfig(port=find_free_port()) distributed = Distributed(DistributedConfig(), use_cpu=True) - sampler = RedisStreamingDataset(stream_config, distributed.config).sample(make_sampling(10, 0, 1, distributed)) - - out = next(iter(sampler)) - - assert isinstance(out, LanguageModelSample) - assert len(out) == 10 - assert out.tokens.tokens.tolist() == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] - - -def test_sampling_skips_too_long_doc_and_padding_final(monkeypatched_redis, stream_config): - """Rollout longer than sample_length must be dropped.""" - fake_redis = monkeypatched_redis + dataset_iterator = iter( + RedisStreamingDataset(stream_config, distributed.config).sample(make_sampling(5, 1, distributed)) + ) + for message in messages: + push_msg(fake_redis, list(message)) + for expected_sample, expected_lengths_ in zip(expected_samples, expected_lengths, strict=True): + sample = next(dataset_iterator) + assert isinstance(sample, LanguageModelSample) + Assert.eq(sample.tokens.tokens.tolist(), list(expected_sample)) + Assert.eq(sample.tokens.lengths, expected_lengths_) + assert sample.loss_masking_spans is None + assert sample.chosen_spans is None + assert sample.rejected_spans is None - push_msg(fake_redis, list(range(20))) # skip: too long - push_msg(fake_redis, list(range(10))) # usable - distributed = Distributed(DistributedConfig(), use_cpu=True) - sampler = RedisStreamingDataset(stream_config, distributed.config).sample(make_sampling(10, 0, 1, distributed)) +_NUM_BATCHES = 1 - out = next(iter(sampler)) - # too big message is skipped and next message 
is returned instead - assert len(out) == 10 - assert out.tokens.tokens.tolist() == list(range(10)) +def _get_distributed_and_batch_config( + distributed_config_dict: dict[str, typing.Any], world_size: int = 1 +) -> tuple[DistributedConfig, GPTBatchConfig]: + distributed_config = DistributedConfig.from_dict( + distributed_config_dict, {"world_size": world_size, "local_world_size": world_size} + ) + with NoAutoValidate(): + batch_config = GPTBatchConfig(micro_batch_size=2, sequence_length=10) + batch_config.setup(distributed_config=distributed_config) + batch_config.validate() + return distributed_config, batch_config -def test_sampling_overflow_creates_two(monkeypatched_redis, stream_config): - """A document overflowing the boundary triggers padding + next sample.""" - fake_redis = monkeypatched_redis +def _run_test_data_streaming( + path: pathlib.Path, distributed_config: DistributedConfig, batch_config: GPTBatchConfig, port: int +): + redis_config = RedisConfig(port=port + 100) - push_msg(fake_redis, list(range(6))) - push_msg(fake_redis, list(range(10))) + data = GPTData(GPTDataConfig(datasets={"train": {"type": "streaming", "port": port + 100}}), distributed_config) + distributed = Distributed(distributed_config) + with ( + redis_batch_producer(redis_config, batch_config) if distributed_config.rank == 0 else contextlib.nullcontext() + ): + data.setup( + distributed=distributed, + sampling_parameters={ + "train": SamplingParameters( + sequence_length=batch_config.sequence_length, + extra_tokens=0, + num_samples=batch_config.batch_size * _NUM_BATCHES, + truncate_documents=False, + ) + }, + preprocessing=LanguageModelPreprocessingConfig(), + cache_directory=path / "cache", + timeout=5, + ) + + data_iter = data.get_iterator(batch_config, "train", consumed_samples=0, num_workers=0, prefetch_factor=None) + batches = [next(data_iter) for _ in range(_NUM_BATCHES)] + path.mkdir(parents=True, exist_ok=True) + torch.save( + torch.stack([batch.tokens.tokens[:, 0] for batch in batches]), + path / f"rank_{distributed_config.batch_data_rank}_" + f"{distributed_config.get_distributed_dim(DistributedDimNames.model_and_sequence_data).rank}.pt", + ) + # Wait for other processes to finish before shutting down the server. + safe_barrier(distributed.world_group, "streaming test end") + + +def check_data_streaming_results( + path: pathlib.Path, + distributed_config: DistributedConfig, + batch_config: GPTBatchConfig, +): + sample_indexes = set() + for batch_data_rank in range(distributed_config.batch_data_parallel): + batches_tokens = torch.load(path / f"rank_{batch_data_rank}_0.pt") + Assert.eq(batches_tokens.shape, (_NUM_BATCHES, batch_config.micro_batch_size)) + for model_and_sequence_data_rank in range( + 1, distributed_config.get_distributed_dim(DistributedDimNames.model_and_sequence_data).size + ): + Assert.all_equal( + torch.load(path / f"rank_{batch_data_rank}_{model_and_sequence_data_rank}.pt"), batches_tokens + ) + sample_indexes.update(batches_tokens.flatten().tolist()) + Assert.eq(len(sample_indexes), _NUM_BATCHES * batch_config.batch_size) - distributed = Distributed(DistributedConfig(), use_cpu=True) - sampler = RedisStreamingDataset(stream_config, distributed.config).sample(make_sampling(10, 0, 2, distributed)) - sampler_iter = iter(sampler) - out = [next(sampler_iter)] - out.append(next(sampler_iter)) +def _run_test_data_streaming_distributed( + test_context: DistributedTestContext, base_path: pathlib.Path, port: int +) -> None: + # Import all dynamic classes. TODO: needed? 
+ import fast_llm.cli # noqa - # sample 1: 0..5 + pad(4) - assert out[0].tokens.tokens.tolist() == list(range(6)) + [-100, -100, -100, -100] + for name, num_gpus, distributed_config_dict in _DISTRIBUTED_TESTING_CONFIGS: + with test_context.subtest(base_path, name, num_gpus) as subtest: + if subtest.do_run: + distributed_config, batch_config = _get_distributed_and_batch_config(distributed_config_dict, num_gpus) + _run_test_data_streaming(base_path / name, distributed_config, batch_config, port) - # sample 2: 0..5 + pad(4) - assert out[1].tokens.tokens.tolist() == list(range(10)) +@requires_cuda +def test_data_streaming(result_path, worker_resources): + distributed_config, batch_config = _get_distributed_and_batch_config({}) + path = result_path / "data_streaming/single_gpu" + _run_test_data_streaming(path, distributed_config, batch_config, worker_resources.torchrun_port) + check_data_streaming_results(path, distributed_config, batch_config) + + +_DISTRIBUTED_TESTING_CONFIGS = [ + ("dp2", 2, {}), + ("sdp2", 2, {"sequence_data_parallel": 2}), + ("tp2", 2, {"tensor_parallel": 2}), + ("pp2", 2, {"pipeline_parallel": 2}), + ("dp2_sdp2", 4, {"sequence_data_parallel": 2}), + ("dp2_tp2", 4, {"tensor_parallel": 2}), + ("dp2_pp2", 4, {"pipeline_parallel": 2}), + ("sdp2_tp2", 4, {"sequence_data_parallel": 2, "tensor_parallel": 2}), + ("sdp2_pp2", 4, {"sequence_data_parallel": 2, "pipeline_parallel": 2}), + ("tp2_pp2", 4, {"tensor_parallel": 2, "pipeline_parallel": 2}), +] -def test_gptdata_streaming_single_consumer(fake_redis_server, run_distributed_script_lean, result_path, request): - run_distributed_gptdata_streaming_test( - fake_redis_server=fake_redis_server, - variant={ - "data_groups": 1, - "tensor_parallel": 1, - "pipeline_parallel": 1, - "sequence_data_parallel": 1, - "total_gpus": 0, - "batch_data_parallel": 1, - }, - run_distributed_script=run_distributed_script_lean, - result_path=result_path, - request=request, +@requires_cuda +@pytest.mark.depends_on(on=["test_data_streaming"]) +def test_run_data_streaming_distributed(run_parallel_script, result_path, worker_resources): + if torch.cuda.device_count() < 2: + pytest.skip(f"Not enough GPUs") + run_parallel_script( + _run_test_data_streaming_distributed, + (result_path / "data_streaming", worker_resources.torchrun_port), + world_size=torch.cuda.device_count(), ) -variants = generate_parallelism_variants(torch.cuda.device_count()) - - -@pytest.mark.slow @requires_cuda -@pytest.mark.parametrize( - "variant", - variants, - ids=[ - f"dg{v['data_groups']}_tp{v['tensor_parallel']}_pp{v['pipeline_parallel']}_sp{v['sequence_data_parallel']}_gpu{v['total_gpus']}" - for v in variants - ], -) -def test_gptdata_streamin_gpus(fake_redis_server, variant, run_distributed_script_lean, result_path, request): - # TODO: make tests on the same number of gpu as subtests - # similar to how it is done in the test_model for speed - run_distributed_gptdata_streaming_test( - fake_redis_server=fake_redis_server, - variant=variant, - run_distributed_script=run_distributed_script_lean, - result_path=result_path, - request=request, - ) +@pytest.mark.depends_on(on=["test_data_streaming"]) +@pytest.mark.parametrize(("name", "num_gpus", "distributed_config_dict"), _DISTRIBUTED_TESTING_CONFIGS) +def test_run_streaming_distributed(result_path, name, num_gpus, distributed_config_dict, report_subtest): + report_subtest(path := result_path / f"data_streaming/{name}", num_gpus) + distributed_config, batch_config = _get_distributed_and_batch_config(distributed_config_dict, 
num_gpus) + check_data_streaming_results(path, distributed_config, batch_config) diff --git a/tests/models/distributed_test_checkpoint.py b/tests/models/distributed_test_checkpoint.py deleted file mode 100644 index 001eb36da..000000000 --- a/tests/models/distributed_test_checkpoint.py +++ /dev/null @@ -1,90 +0,0 @@ -import gc -import logging - -import torch - -from fast_llm.cli import fast_llm_main_wrapper -from fast_llm.config import NoAutoValidate -from fast_llm.core.distributed import safe_barrier -from fast_llm.engine.checkpoint.config import ( - CheckpointLoadConfig, - CheckpointSaveConfig, - DistributedCheckpointFormat, - FastLLMCheckpointFormat, -) -from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.engine.distributed.distributed import ProcessGroupPool -from fast_llm.engine.multi_stage.config import StageMode -from fast_llm.utils import Assert, header -from tests.utils.model_configs import ModelTestingConfig -from tests.utils.run_test_script import parse_run_distributed_script -from tests.utils.save_load_configs import DISTRIBUTED_SAVE_LOAD_CONFIGS, DistributedSaveLoadConfig -from tests.utils.utils import DistributedSubtestContext - -logger = logging.getLogger(__name__) - - -def _test_load_and_save_parallel( - model_testing_config: ModelTestingConfig, - config: DistributedSaveLoadConfig, -): - logger.info(header(config.name)) - logger.info(f"Loading {config.load_format} checkpoint from {config.load_path}") - with NoAutoValidate(): - load_config = CheckpointLoadConfig(path=config.load_path, format=config.load_format) - load_config.setup(model_testing_config.model_config_class) - load_config.validate() - model = model_testing_config.model_class.from_pretrained( - load_config, - # The world size and rank are already set through environment variable. - {"distributed": config.distributed}, - mode=StageMode.inference, - ) - for save_format in (DistributedCheckpointFormat, FastLLMCheckpointFormat): - logger.info(f"Saving {save_format.name} checkpoint to {config.save_path / save_format.name}") - model.save_checkpoint(CheckpointSaveConfig(path=config.save_path / save_format.name, format=save_format)) - del model - gc.collect() - torch.cuda.empty_cache() - - -def main(args: list[str] | None = None) -> None: - base_path, model_testing_config, do_capture = parse_run_distributed_script(args) - - if do_capture: - logger.warning( - "Capturing output and forwarding to associated tests. Run with `--no-distributed-capture` to disable." - ) - - with ProcessGroupPool(timeout=20) as pool: - failures = [] - world_size = DistributedConfig.default_world_size - rank = DistributedConfig.default_rank - group = pool.get_process_group(range(world_size), rank) - - for config in DISTRIBUTED_SAVE_LOAD_CONFIGS.values(): - if config.load_format == "{checkpoint_format}" and model_testing_config.checkpoint_format is None: - continue - config = config.resolve(base_path, model_testing_config) - Assert.eq(world_size, config.num_gpus) - with DistributedSubtestContext(base_path, config.name, group, world_size, enabled=do_capture) as subtest: - _test_load_and_save_parallel( - model_testing_config=model_testing_config, - config=config, - ) - if not subtest.success: - failures.append(config.name) - - # Final barrier to ensure everything is done before torchrun potentially kills workers. - safe_barrier(group, "testing end") - # Let pytest know how things went. - # These should already be reported above, we repeat for convenience. 
- if failures: - raise RuntimeError(f"The following subtests failed: {", ".join(failures)}") - else: - logger.warning("All tests passed") - - -if __name__ == "__main__": - with fast_llm_main_wrapper(): - main() diff --git a/tests/models/distributed_test_model.py b/tests/models/distributed_test_model.py deleted file mode 100644 index 890a75077..000000000 --- a/tests/models/distributed_test_model.py +++ /dev/null @@ -1,54 +0,0 @@ -import logging - -from fast_llm.cli import fast_llm_main_wrapper -from fast_llm.core.distributed import safe_barrier -from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.engine.distributed.distributed import ProcessGroupPool -from tests.utils.distributed_configs import DISTRIBUTED_TESTING_CONFIGS -from tests.utils.run_test_script import do_run_test_script_for_all_models, parse_run_distributed_script -from tests.utils.utils import DistributedSubtestContext - -logger = logging.getLogger(__name__) - - -def main(args: list[str] | None = None) -> None: - base_path, model_testing_config, do_capture = parse_run_distributed_script(args) - - if do_capture: - logger.warning( - "Capturing output and forwarding to associated tests. Run with `--no-distributed-capture` to disable." - ) - - # TODO: Why are barriers needed? - with ProcessGroupPool(timeout=60) as pool: - failures = [] - world_size = DistributedConfig.default_world_size - rank = DistributedConfig.default_rank - group = pool.get_process_group(range(world_size), rank) - safe_barrier(group, "start") - - for name, config in DISTRIBUTED_TESTING_CONFIGS.items(): - if model_testing_config.should_skip(config): - continue - if world_size < config.num_gpus: - logger.warning(f"{name} {f"SKIPPED (not enough GPUs: {world_size} < {config.num_gpus})"})") - continue - with DistributedSubtestContext(base_path, name, group, config.num_gpus, enabled=do_capture) as subtest: - if rank < config.num_gpus: - do_run_test_script_for_all_models(config, model_testing_config, base_path) - if not subtest.success: - failures.append(name) - - # Final barrier to ensure everything is done before torchrun potentially kills workers. - safe_barrier(group, "testing end") - # Let pytest know how things went. - # These should already be reported above, we repeat for convenience. 
- if failures: - raise RuntimeError(f"The following subtests failed: {", ".join(failures)}") - else: - logger.warning("All tests passed") - - -if __name__ == "__main__": - with fast_llm_main_wrapper(): - main() diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index bb53de29e..83f1ed105 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -1,3 +1,4 @@ +import gc import logging import pathlib import shutil @@ -7,6 +8,7 @@ import torch import yaml +from fast_llm.config import NoAutoValidate from fast_llm.engine.checkpoint.config import ( CheckpointFormat, CheckpointLoadConfig, @@ -16,12 +18,13 @@ ModelConfigType, ) from fast_llm.engine.checkpoint.convert import ConvertConfig -from fast_llm.engine.multi_stage.config import FastLLMModelConfig, ShardName -from fast_llm.utils import Assert +from fast_llm.engine.multi_stage.config import FastLLMModelConfig, ShardName, StageMode +from fast_llm.utils import Assert, header from tests.utils.compare_tensor_logs import CompareConfig from tests.utils.distributed_configs import DistributedTestingConfig from tests.utils.model_configs import ModelTestingConfig, ModelTestingGroup from tests.utils.save_load_configs import DISTRIBUTED_SAVE_LOAD_CONFIGS, DistributedSaveLoadConfig +from tests.utils.subtest import DistributedTestContext from tests.utils.utils import requires_cuda logger = logging.getLogger(__name__) @@ -391,31 +394,54 @@ def test_huggingface_model(model_testing_config, get_convert_path): raise ValueError(f"Comparison failed ({len(errors)} errors)") +def _save_and_load_in_parallel( + test_context: DistributedTestContext, base_path: pathlib.Path, model_testing_config: ModelTestingConfig +) -> None: + # Import all dynamic classes. + import fast_llm.cli # noqa + + for config in DISTRIBUTED_SAVE_LOAD_CONFIGS.values(): + if config.load_format == "{checkpoint_format}" and model_testing_config.checkpoint_format is None: + continue + config = config.resolve(base_path, model_testing_config) + with test_context.subtest(base_path, config.name, config.num_gpus) as subtest: + if subtest.do_run: + logger.info(header(config.name)) + logger.info(f"Loading {config.load_format} checkpoint from {config.load_path}") + with NoAutoValidate(): + load_config = CheckpointLoadConfig(path=config.load_path, format=config.load_format) + load_config.setup(model_testing_config.model_config_class) + load_config.validate() + model = model_testing_config.model_class.from_pretrained( + load_config, + # The world size and rank are already set through environment variable. + {"distributed": config.distributed}, + mode=StageMode.inference, + ) + for save_format in (DistributedCheckpointFormat, FastLLMCheckpointFormat): + logger.info(f"Saving {save_format.name} checkpoint to {config.save_path / save_format.name}") + model.save_checkpoint( + CheckpointSaveConfig(path=config.save_path / save_format.name, format=save_format) + ) + del model + gc.collect() + torch.cuda.empty_cache() + + @requires_cuda @pytest.mark.depends_on(on=["test_load_pretrained[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) -def test_save_and_load_in_parallel(run_distributed_script, run_test_script_base_path, model_testing_config, request): +def test_save_and_load_in_parallel(run_parallel_script, run_test_script_base_path, model_testing_config): # Save and load checkpoints to and from various distributed configurations. # Combined in a single test to mitigate process creation overhead. 
# TODO: Test beyond 2 gpu configs? - import tests.models.distributed_test_checkpoint - if torch.cuda.device_count() < 2: - pytest.skip(f"Not enough GPUs: {torch.cuda.device_count()} < 2") - - script = [ - "-m", - tests.models.distributed_test_checkpoint.__name__, - str(run_test_script_base_path), - model_testing_config.name, - ] - if request.config.getoption("distributed_capture"): - logger.warning( - "Capturing output and forwarding to associated tests. Run with `--no-distributed-capture` to disable." - ) - else: - script.append("--no-distributed-capture") - run_distributed_script(script, num_gpus=2) + pytest.skip(f"Not enough GPUs2") + run_parallel_script( + _save_and_load_in_parallel, + (run_test_script_base_path, model_testing_config), + world_size=2, + ) @pytest.fixture(scope="module") diff --git a/tests/models/test_model.py b/tests/models/test_model.py index d14721142..84b4d99dc 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -1,4 +1,5 @@ import logging +import pathlib import pytest import torch @@ -8,8 +9,10 @@ SIMPLE_TESTING_CONFIG, SINGLE_GPU_TESTING_CONFIGS, ) -from tests.utils.model_configs import ModelTestingGroup -from tests.utils.utils import check_subtest_success, requires_cuda, set_subtest_success +from tests.utils.model_configs import ModelTestingConfig, ModelTestingGroup +from tests.utils.run_test_script import do_run_test_script_for_all_models +from tests.utils.subtest import DistributedTestContext, check_subtest_success, set_subtest_success +from tests.utils.utils import requires_cuda logger = logging.getLogger(__name__) @@ -49,27 +52,33 @@ def test_and_compare_model( compare_results_for_all_models(config) +def _run_model_distributed( + test_context: DistributedTestContext, base_path: pathlib.Path, model_testing_config: ModelTestingConfig +) -> None: + # Import all dynamic classes. + import fast_llm.cli # noqa + + for name, config in DISTRIBUTED_TESTING_CONFIGS.items(): + if model_testing_config.should_skip(config): + continue + with test_context.subtest(base_path, name, config.num_gpus) as subtest: + if subtest.do_run: + do_run_test_script_for_all_models(config, model_testing_config, base_path) + + @requires_cuda @pytest.mark.depends_on(on=["test_model_simple[{model_testing_config}]"]) @pytest.mark.model_testing_group( ModelTestingGroup.distributed, ) -def test_run_model_distributed(run_distributed_script, model_testing_config, run_test_script_base_path, request): - import tests.models.distributed_test_model - - script = [ - "-m", - tests.models.distributed_test_model.__name__, - str(run_test_script_base_path), - model_testing_config.name, - ] - if request.config.getoption("distributed_capture"): - logger.warning( - "Capturing output and forwarding to associated tests. Run with `--no-distributed-capture` to disable." - ) - else: - script.append("--no-distributed-capture") - run_distributed_script(script, num_gpus=torch.cuda.device_count()) +def test_run_model_distributed(run_parallel_script, model_testing_config, run_test_script_base_path, request): + if torch.cuda.device_count() < 2: + pytest.skip(f"Not enough GPUs") + run_parallel_script( + _run_model_distributed, + (run_test_script_base_path, model_testing_config), + world_size=torch.cuda.device_count(), + ) # We don't want to depend on `test_model_distributed` because we still want to run this in cas of failure. 
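For readers following the refactor above: the torchrun-based helper scripts (`distributed_test_model.py`, `distributed_test_checkpoint.py`) are replaced by in-process workers driven through the new `run_parallel_script` fixture and `DistributedTestContext`. The sketch below shows the intended pattern for adding another distributed test under this scheme. It is an illustration only; `MY_TESTING_CONFIGS`, the worker name and the test name are hypothetical stand-ins, while the fixture and context APIs are the ones introduced in this patch.

import pytest
import torch

from tests.utils.run_test_script import do_run_test_script_for_all_models
from tests.utils.subtest import DistributedTestContext
from tests.utils.utils import requires_cuda

# Hypothetical table mapping subtest name -> testing config, analogous to DISTRIBUTED_TESTING_CONFIGS.
MY_TESTING_CONFIGS: dict = {}


def _run_my_feature_distributed(test_context: DistributedTestContext, base_path, model_testing_config) -> None:
    # Import all dynamic classes, mirroring the other parallel workers in this patch.
    import fast_llm.cli  # noqa

    for name, config in MY_TESTING_CONFIGS.items():
        # Each subtest captures per-rank output and is skipped automatically when there are not enough GPUs.
        with test_context.subtest(base_path, name, config.num_gpus) as subtest:
            if subtest.do_run:
                do_run_test_script_for_all_models(config, model_testing_config, base_path)


@requires_cuda
def test_run_my_feature_distributed(run_parallel_script, run_test_script_base_path, model_testing_config):
    if torch.cuda.device_count() < 2:
        pytest.skip("Not enough GPUs")
    # Spawns `world_size` worker processes, each running the worker inside a DistributedTestContext.
    run_parallel_script(
        _run_my_feature_distributed,
        (run_test_script_base_path, model_testing_config),
        world_size=torch.cuda.device_count(),
    )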
diff --git a/tests/trainer/events_fake_consumer.py b/tests/trainer/events_fake_consumer.py index 4c2d30891..9692134db 100644 --- a/tests/trainer/events_fake_consumer.py +++ b/tests/trainer/events_fake_consumer.py @@ -1,7 +1,7 @@ +import json import sys from pathlib import Path -import orjson import redis import safetensors.torch import torch.distributed @@ -67,7 +67,7 @@ def main(): for event_id, msg in events: last_id = event_id assert msg_key in msg - msg = orjson.loads(msg[msg_key].decode()) + msg = json.loads(msg[msg_key].decode()) print(f"{consumer_id} msg received: {msg}") if msg["type"] == config["events"]["weights_broadcast"]["weights_ready_message_type"] or ( msg["type"] == config["events"]["weights_broadcast"]["initial_weights_step_message_type"] diff --git a/tests/trainer/test_events.py b/tests/trainer/test_events.py index b14c4927d..baa0526e3 100644 --- a/tests/trainer/test_events.py +++ b/tests/trainer/test_events.py @@ -11,6 +11,7 @@ import torch import yaml +from fast_llm.data.dataset.config import StreamingDatasetConfig from tests.utils.model_configs import MODEL_CONFIGS from tests.utils.redis import redis_batch_producer from tests.utils.utils import requires_cuda @@ -327,8 +328,8 @@ def generate_variants(num_gpus: int) -> list[dict[str, typing.Any]]: for v in variants ], ) -def test_trainer_events_with_streaming(fake_redis_server, variant, run_distributed_script_lean, result_path, request): - stream_config, fake_redis_client, fake_redis_server_killer = fake_redis_server +def test_trainer_events_with_streaming(variant, run_distributed_script, result_path, request): + stream_config = StreamingDatasetConfig(port=port) test_result_path = result_path / request.node.name test_result_path_fast_llm = test_result_path / "fast_llm" test_result_path_consumers = test_result_path / "consumers" @@ -394,7 +395,7 @@ def test_trainer_events_with_streaming(fake_redis_server, variant, run_distribut ): run_fast_llm_training( model_config=model_config, - run_distributed_script=run_distributed_script_lean, + run_distributed_script=run_distributed_script, assigned_gpus=fast_llm_assigned_gpus, ) check_events_results( diff --git a/tests/utils/redis.py b/tests/utils/redis.py index cc820b84d..7e8072aab 100644 --- a/tests/utils/redis.py +++ b/tests/utils/redis.py @@ -1,22 +1,23 @@ import contextlib +import itertools +import json import pathlib import socket import threading import time import fakeredis -import orjson -import pytest from fast_llm.data.dataset.config import ( + RedisConfig, SamplingConfig, SamplingData, SamplingParameters, - ShufflingType, StreamingDatasetConfig, ) from fast_llm.data.dataset.streaming import REDIS_DATA_KEY, REDIS_GROUP_NAME from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig +from fast_llm.models.gpt.config import GPTBatchConfig def find_free_port(): @@ -26,31 +27,9 @@ def find_free_port(): return s.getsockname()[1] -def push_msg(redis_client, tokens=None, stream_key_suffix=None, payload_key="data", stream_key=REDIS_DATA_KEY): +def push_msg(redis_client, tokens): """Push a message into FakeRedis stream.""" - msg = { - "tokens": tokens, - "tokens_dtype": "int64", - } - if stream_key_suffix is not None: - stream_key += stream_key_suffix - redis_client.xadd(stream_key, {payload_key: orjson.dumps(msg)}) - - -class FakeRedisServerKiller: - def __init__(self, server): - self._server = server - self._is_killed = False - self._lock = threading.Lock() - - def kill(self): - with self._lock: - if not self._is_killed: - try: - 
self._server.shutdown() - self._server.server_close() - finally: - self._is_killed = True + redis_client.xadd(REDIS_DATA_KEY, {"data": json.dumps({"tokens": tokens, "tokens_dtype": "int64"})}) def wait_until_stream_empty( @@ -86,57 +65,39 @@ def get_consumer_count(redis_client, stop_event, config: StreamingDatasetConfig) @contextlib.contextmanager -def redis_batch_producer(redis_client, fake_redis_server_killer, batch_size, sequence_length, num_batches=None): - stop_event = threading.Event() - thread_exc = [] - - def producer_loop(): - try: - batch_idx = 0 - while not stop_event.is_set(): - if num_batches is not None and batch_idx >= num_batches: +def redis_batch_producer(config: RedisConfig, batch_config: GPTBatchConfig): + with fake_redis_server(config): + stop_event = threading.Event() + client = config.get_client() + + def producer_loop(): + for sample_index in itertools.count(): + if stop_event.is_set(): break - for i in range(batch_size): - if stop_event.is_set(): - return - push_msg( - redis_client, - [batch_idx * batch_size + i] * sequence_length, - ) - wait_until_stream_empty( - redis_client, - REDIS_DATA_KEY, - REDIS_GROUP_NAME, - stop_event, - ) - batch_idx += 1 - except Exception as e: - # if failed to push messages kill redis server so waiting side in the test would unlock - fake_redis_server_killer.kill() - thread_exc.append(e) - raise + push_msg(client, [sample_index] * batch_config.sequence_length) + if sample_index % 5 == 0: + wait_until_stream_empty(client, REDIS_DATA_KEY, REDIS_GROUP_NAME, stop_event) - thread = threading.Thread(target=producer_loop, daemon=True) - thread.start() + thread = threading.Thread(target=producer_loop, daemon=True) + thread.start() - try: - yield - finally: - stop_event.set() - thread.join(timeout=10) - if thread_exc: - raise thread_exc[0] + try: + yield + finally: + stop_event.set() + thread.join(timeout=1) + client.close() -def make_sampling(sequence_length, extra_tokens, num_samples, distributed): +def make_sampling(sequence_length, num_samples, distributed): return SamplingData( parameters=SamplingParameters( sequence_length=sequence_length, - extra_tokens=extra_tokens, + extra_tokens=0, num_samples=num_samples, truncate_documents=False, ), - config=SamplingConfig(shuffle=ShufflingType.disabled), + config=SamplingConfig(), distributed=distributed, dataset_name="test", cache_directory=pathlib.Path("/tmp"), @@ -144,16 +105,9 @@ def make_sampling(sequence_length, extra_tokens, num_samples, distributed): ) -@pytest.fixture -def stream_config(): - # TODO: ======= Not safe for parallel tests? 
======= - return StreamingDatasetConfig(port=find_free_port()) - - -@pytest.fixture -def fake_redis_server(stream_config): +@contextlib.contextmanager +def fake_redis_server(config: RedisConfig): # We search for free port as port from previous test can still be not free even after server shutdown - server_address = (stream_config.host, stream_config.port) # ----- Monkey-patch handler to suppress broken pipes ----- orig_handle = fakeredis._tcp_server.TCPFakeRequestHandler.handle @@ -170,27 +124,14 @@ def safe_handle(self): fakeredis._tcp_server.TCPFakeRequestHandler.handle = safe_handle - server = fakeredis.TcpFakeServer(server_address, server_type="redis") - server_killer = FakeRedisServerKiller(server) - - # ----- Start server thread ----- - def serve(): - try: - server.serve_forever() - except Exception: - # Extra safety: catch anything from serve_forever - pass - - thread = threading.Thread(target=serve, daemon=True) + server = fakeredis.TcpFakeServer((config.host, config.port), server_type="redis") + thread = threading.Thread(target=server.serve_forever, daemon=True) thread.start() - # ----- reate a redis-py client pointing at the fake serve ----- - import redis - - client = redis.Redis(host=server_address[0], port=server_address[1]) - - yield stream_config, client, server_killer - - # ----- Teardown ----- - server_killer.kill() - thread.join() + try: + yield + finally: + # ----- Teardown ----- + server.shutdown() + server.server_close() + thread.join() diff --git a/tests/utils/run_test_script.py b/tests/utils/run_test_script.py index e880e67ef..d43855789 100644 --- a/tests/utils/run_test_script.py +++ b/tests/utils/run_test_script.py @@ -47,22 +47,7 @@ def do_run_distributed_script( @pytest.fixture(scope="session") -def run_distributed_script( - worker_resources: "WorkerResources", - run_test_script_base_path: pathlib.Path, - model_testing_config: ModelTestingConfig, -): - return functools.partial( - do_run_distributed_script, - rendezvous_port=worker_resources.rendezvous_port, - torchrun_port=worker_resources.torchrun_port, - ) - - -@pytest.fixture(scope="session") -def run_distributed_script_lean( - worker_resources: "WorkerResources", -): +def run_distributed_script(worker_resources: "WorkerResources"): return functools.partial( do_run_distributed_script, rendezvous_port=worker_resources.rendezvous_port, diff --git a/tests/utils/subtest.py b/tests/utils/subtest.py new file mode 100644 index 000000000..3ba613170 --- /dev/null +++ b/tests/utils/subtest.py @@ -0,0 +1,254 @@ +import functools +import json +import logging +import math +import pathlib +import sys +import time +import traceback +import typing + +import pytest +import torch + +from fast_llm.core.distributed import allreduce_scalar, safe_barrier +from fast_llm.engine.config_utils.logging import configure_logging +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.distributed import ProcessGroupPool +from fast_llm.utils import Assert, get_and_reset_memory_usage_mib, header + +logger = logging.getLogger(__name__) + + +class DistributedTestContext: + def __init__(self, do_capture: bool, timeout: float = 20.0, init_method: str = "env://"): + self._do_capture = do_capture + self._timeout = timeout + self._init_method = init_method + + def __enter__(self): + if self._do_capture: + logger.warning( + "Capturing output and forwarding to associated tests. Run with `--no-distributed-capture` to disable." 
+ ) + + self._pool = ProcessGroupPool(timeout=self._timeout, init_method=self._init_method).__enter__() + self._rank = self._pool.rank + self._world_size = self._pool.world_size + self._failures = [] + self._configure_logging() + self._group = self._pool.get_process_group(range(self._world_size), self._rank) + # TODO: Barriers needed? + safe_barrier(self._group, "start") + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + # Final barrier to ensure everything is done before torchrun potentially kills workers. + safe_barrier(self._group, "testing end") + # Let pytest know how things went. + # These should already be reported above, we repeat for convenience. + if self._failures: + raise RuntimeError(f"The following subtests failed: {", ".join(self._failures)}") + else: + logger.warning("All tests passed") + + def subtest(self, base_path: pathlib.Path, name: str, num_gpus: int): + return self.DistributedSubtestContext(self, base_path, name, num_gpus) + + def _configure_logging(self): + configure_logging(rank=self._rank, world_size=self._world_size) + + class DistributedSubtestContext: + def __init__( + self, test_context: "DistributedTestContext", base_path: pathlib.Path, name: str, num_gpus: int + ) -> None: + self._test_context = test_context + self._path = base_path / name + self._name = name + self._num_gpus = num_gpus + self._skip = self._test_context._world_size < self._num_gpus + self._do_run = self._test_context._rank < num_gpus and not self._skip + self._do_capture = self._test_context._do_capture and self._do_run + self._success = False + + def __enter__(self) -> typing.Self: + if self._do_capture: + self._sys_stdout = sys.stdout + self._sys_stderr = sys.stderr + self._path.mkdir(parents=True, exist_ok=True) + sys.stdout = self._path.joinpath(f"pytest_stdout_{self._test_context._rank}").open("w") + sys.stderr = self._path.joinpath(f"pytest_stderr_{self._test_context._rank}").open("w") + self._test_context._configure_logging() + # Logging is set to log to the old stdout, so we need to reconfigure. + self._start = time.perf_counter() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self._skip: + # Skipped tests should exit right away. + Assert.none(exc_val) + logger.warning( + f"{self._name} {f"SKIPPED (not enough GPUs: {self._test_context._world_size} < {self._num_gpus})"})" + ) + return + + if self._do_capture: + try: + stdout_handle = sys.stdout + stderr_handle = sys.stderr + sys.stdout = self._sys_stdout + sys.stderr = self._sys_stderr + stdout_handle.close() + stderr_handle.close() + finally: + assert DistributedConfig.default_world_size > 1 + self._test_context._configure_logging() + + if exc_type is None: + self._success = True + else: + self._path.mkdir(parents=True, exist_ok=True) + self._path.joinpath(f"pytest_traceback_{self._test_context._rank}").write_text(traceback.format_exc()) + + logger.warning(f"{self._name} done, waiting for other ranks ({"PASSED" if self._success else "FAILED"})") + + if (group := self._test_context._group) is not None: + # Barrier so `allreduce_scalar` doesn't go crazy in case of desync. + safe_barrier(group, self._name) + self._success = allreduce_scalar(self._success, dtype=torch.int64, group=group) == group.size() + + if self._do_capture: + # Free resources to limit memory usage. 
+ report = get_and_reset_memory_usage_mib(clear_cache=True, global_stats=True, reset_global_stats=True) + report["duration"] = time.perf_counter() - self._start + + json.dump(report, self._path.joinpath(f"pytest_report_{self._test_context._rank}").open("w")) + + if self._test_context._rank == 0: + set_subtest_success(self._path, self._success) + logger.warning(f"{self._name} {"PASSED" if self._success else "FAILED"}") + if not self._success: + self._test_context._failures.append(self._name) + + return True + + @property + def do_run(self) -> bool: + return self._do_run and not self._skip + + +def set_subtest_success(path: pathlib.Path, success: bool = True): + path.joinpath("pytest_success").write_text(str(int(success))) + + +def check_subtest_success(path: pathlib, fail: bool = True) -> bool: + if not path.is_dir(): + if fail: + pytest.fail(f"Test {path.name} did not run", pytrace=False) + else: + return False + try: + return bool(int(path.joinpath("pytest_success").read_text())) + except OSError: + return False + + +def format_resource_report(title: str, report: dict[str, float]) -> str: + return "".join( + [ + f"{title}:\n ", + f"Max Reserved: {report.get("max_reserved", math.nan):.0f} MiB", + f"| Max Allocated: {report.get("max_allocated", math.nan):.0f} MiB".ljust(26), + f"| End Reserved: {report.get("reserved", math.nan):.0f} MiB".ljust(25), + f"| End Allocated: {report.get("allocated", math.nan):.0f} MiB".ljust(26), + f"| Duration: {report.get("duration", math.nan):.2f}".ljust(18), + f"| GPUs: {report["gpus"]:.0f}" if "gpus" in report else "", + ] + ) + + +@pytest.fixture(scope="function") +def report_subtest(request: pytest.FixtureRequest): + verbose = request.config.getoption("verbose") + do_capture = request.config.getoption("distributed_capture") + + def do_report_subtest(path: pathlib.Path, world_size: int) -> None: + success = check_subtest_success(path) + if not do_capture: + logger.warning("Distributed capture is disabled. 
See distributed test for run output.") + elif verbose > 1 or not success: + for rank in range(world_size): + for fd, file_ in (("stdout", sys.stdout), ("stderr", sys.stdout), ("traceback", sys.stderr)): + print(header(f"{fd} rank {rank}", 80), file=file_) + file_path = path / f"pytest_{fd}_{rank}" + try: + print(file_path.read_text(), file=file_) + except OSError: + print(f"<<< not found {file_path}>>>", file=file_) + else: + print("Set verbose > 1 to show run output.") + + reports = {} + for rank in range(world_size): + try: + reports[f"rank_{rank}"] = json.load(path.joinpath(f"pytest_report_{rank}").open("r")) + except OSError: + reports[rank] = {} + keys = {key for report in reports.values() for key in report} + report = {key: max(report[key] for report in reports.values() if key in report) for key in keys} + report["gpus"] = world_size + reports["global"] = report + + print(header(f"Resource usage", 80), file=sys.stderr) + for name, report in reports.items(): + print(format_resource_report(name, report), file=sys.stderr) + setattr(request.node, "fast_llm_resource_report", report) + + if not success: + raise RuntimeError(f"test {path.name} failed") + + return do_report_subtest + + +def parallel_worker( + rank: int, + world_size: int, + init_method: str, + do_capture: bool, + fn: typing.Callable, + fn_args: typing.Sequence[typing.Any], +): + DistributedConfig.default_rank = rank + DistributedConfig.default_world_size = world_size + DistributedConfig.default_local_world_size = world_size + with DistributedTestContext(do_capture, 60, init_method) as test_context: + fn(test_context, *fn_args) + + +def do_run_parallel_script( + fn: typing.Callable, + fn_args: typing.Sequence[typing.Any], + port: int, + do_capture: bool, + world_size: int, + timeout: float = 240, +): + if do_capture: + logger.warning( + "Capturing output and forwarding to associated tests. Run with `--no-distributed-capture` to disable." 
+ ) + torch.multiprocessing.spawn( + parallel_worker, + args=(world_size, f"tcp://localhost:{port}", do_capture, fn, fn_args), + nprocs=world_size, + join=False, + ).join(timeout, grace_period=63) + + +@pytest.fixture(scope="session") +def run_parallel_script(worker_resources: "WorkerResources", request: pytest.FixtureRequest): + return functools.partial( + do_run_parallel_script, + port=worker_resources.rendezvous_port, + do_capture=request.config.getoption("distributed_capture"), + ) diff --git a/tests/utils/utils.py b/tests/utils/utils.py index 3b79f7607..f0ca20db8 100644 --- a/tests/utils/utils.py +++ b/tests/utils/utils.py @@ -1,23 +1,14 @@ -import json import logging -import math -import pathlib -import sys -import time -import traceback import typing import pytest import torch -from fast_llm.core.distributed import ProcessGroup, allreduce_scalar, safe_barrier from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.base_model.config import set_model_names -from fast_llm.engine.config_utils.logging import configure_logging from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageConfig from fast_llm.engine.multi_stage.stage import Stage -from fast_llm.utils import get_and_reset_memory_usage_mib, header from tests.utils.global_variables import TEST_RESULTS_PATH logger = logging.getLogger(__name__) @@ -65,137 +56,3 @@ def get_stage( stage.restore_parameters() stage.reset_gradients() return stage - - -class DistributedSubtestContext: - def __init__( - self, base_path: pathlib.Path, name: str, group: ProcessGroup | None, num_gpus: int, enabled: bool = True - ) -> None: - self._path = base_path / name - self._name = name - self._group = group - self._rank = 0 if group is None else group.rank() - self._rank_enabled = self._rank < num_gpus - self._enabled = enabled and self._rank_enabled - self.success = False - - def __enter__(self) -> typing.Self: - if self._enabled: - self._sys_stdout = sys.stdout - self._sys_stderr = sys.stderr - self._path.mkdir(parents=True, exist_ok=True) - sys.stdout = self._path.joinpath(f"pytest_stdout_{self._rank}").open("w") - sys.stderr = self._path.joinpath(f"pytest_stderr_{self._rank}").open("w") - # Logging is set to log to the old stdout, so we need to reconfigure. - configure_logging() - self._start = time.perf_counter() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - if self._enabled: - try: - stdout_handle = sys.stdout - stderr_handle = sys.stderr - sys.stdout = self._sys_stdout - sys.stderr = self._sys_stderr - stdout_handle.close() - stderr_handle.close() - finally: - configure_logging() - - if exc_type is None: - self.success = True - else: - self._path.mkdir(parents=True, exist_ok=True) - self._path.joinpath(f"pytest_traceback_{self._rank}").write_text(traceback.format_exc()) - - if self._group is not None: - # Barrier so `allreduce_scalar` doesn't go crazy in case of desync. - safe_barrier(self._group, self._name) - self.success = allreduce_scalar(self.success, dtype=torch.int64, group=self._group) == self._group.size() - - if self._rank_enabled: - # Free resources to limit memory usage. 
- report = get_and_reset_memory_usage_mib(clear_cache=True, global_stats=True, reset_global_stats=True) - report["duration"] = time.perf_counter() - self._start - - json.dump(report, self._path.joinpath(f"pytest_report_{self._rank}").open("w")) - - logger.warning(f"{self._name} {"PASSED" if self.success else "FAILED"})") - if self._rank == 0: - set_subtest_success(self._path, self.success) - - return True - - -def set_subtest_success(path: pathlib.Path, success: bool = True): - path.joinpath("pytest_success").write_text(str(int(success))) - - -def check_subtest_success(path: pathlib, fail: bool = True) -> bool: - if not path.is_dir(): - if fail: - pytest.fail(f"Test {path.name} did not run", pytrace=False) - else: - return False - try: - return bool(int(path.joinpath("pytest_success").read_text())) - except OSError: - return False - - -def format_resource_report(title: str, report: dict[str, float]) -> str: - return "".join( - [ - f"{title}:\n ", - f"Max Reserved: {report.get("max_reserved", math.nan):.0f} MiB", - f"| Max Allocated: {report.get("max_allocated", math.nan):.0f} MiB".ljust(26), - f"| End Reserved: {report.get("reserved", math.nan):.0f} MiB".ljust(25), - f"| End Allocated: {report.get("allocated", math.nan):.0f} MiB".ljust(26), - f"| Duration: {report.get("duration", math.nan):.2f}".ljust(18), - f"| GPUs: {report["gpus"]:.0f}" if "gpus" in report else "", - ] - ) - - -@pytest.fixture(scope="function") -def report_subtest(request: pytest.FixtureRequest): - verbose = request.config.getoption("verbose") - do_capture = request.config.getoption("distributed_capture") - - def do_report_subtest(path: pathlib.Path, world_size: int) -> None: - success = check_subtest_success(path) - if not do_capture: - logger.warning("Distributed capture is disabled. 
See distributed test for run output.") - elif verbose > 1 or not success: - for rank in range(world_size): - for fd, file_ in (("stdout", sys.stdout), ("stderr", sys.stdout), ("traceback", sys.stderr)): - print(header(f"{fd} rank {rank}", 80), file=file_) - file_path = path / f"pytest_{fd}_{rank}" - try: - print(file_path.read_text(), file=file_) - except OSError: - print(f"<<< not found {file_path}>>>", file=file_) - else: - print("Set verbose > 1 to show run output.") - - reports = {} - for rank in range(world_size): - try: - reports[f"rank_{rank}"] = json.load(path.joinpath(f"pytest_report_{rank}").open("r")) - except OSError: - reports[rank] = {} - keys = {key for report in reports.values() for key in report} - report = {key: max(report[key] for report in reports.values() if key in report) for key in keys} - report["gpus"] = world_size - reports["global"] = report - - print(header(f"Resource usage", 80), file=sys.stderr) - for name, report in reports.items(): - print(format_resource_report(name, report), file=sys.stderr) - setattr(request.node, "fast_llm_resource_report", report) - - if not success: - raise RuntimeError(f"test {path.name} failed") - - return do_report_subtest From 55c6d0e31185047ba3a04a9c70f77db0c1442d8d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 22 Dec 2025 12:51:36 -0500 Subject: [PATCH 05/12] Add metadata to dataset config files (#420) --- fast_llm/data/dataset/gpt/config.py | 7 +- fast_llm/data/dataset/memmap.py | 11 +++- .../data/preparator/gpt_memmap/prepare.py | 65 ++++++++++--------- fast_llm/data/sample/abstract.py | 17 +++++ fast_llm/data/sample/language_model.py | 54 +++++++++++++++ fast_llm/data/sample/patch.py | 34 ++++++++++ fast_llm/data/sample/range.py | 20 ++++++ fast_llm/data/sample/token.py | 39 ++++++++++- 8 files changed, 213 insertions(+), 34 deletions(-) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index fc326d366..41a2fe7ff 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -64,7 +64,12 @@ def build(self, preprocessing: PreprocessingConfig) -> SamplableDataset[SampleTy def _load_config(self) -> SampledDatasetConfig[SampleType]: assert self.path.is_file(), f"File {self.path} does not exist." - return SampledDatasetConfig[SampleType].from_dict(self._convert_paths(yaml.safe_load(self.path.open("r")))) + config = yaml.safe_load(self.path.open("r")) + Assert.eq(config.keys(), {"config", "metadata"}) + if config.keys() == {"config", "metadata"}: + # Newer format with metadata + config = config["config"] + return SampledDatasetConfig[SampleType].from_dict(self._convert_paths(config)) def _convert_paths(self, config): # Recursively convert paths relative to `self.path.parent` to make them relative to cwd. 
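To make the new on-disk layout concrete, here is a rough sketch of the round trip between `_save_dataset_config` (changed further down in this patch) and the loader above. The file name and the config/metadata values are made up for illustration; real metadata is produced by the reader configs' `get_metadata` methods and is simply dropped by the loader.

import yaml

# Hypothetical values; in practice these come from the prepared dataset and its readers.
dataset_config = {"type": "memmap", "path": "shard_0_0"}
metadata = {"num_tokens": 123456, "num_documents": 789}

# Newer format: the dataset config is wrapped together with its metadata.
with open("fast_llm_config_training.yaml", "w") as f:
    yaml.safe_dump({"config": dataset_config, "metadata": metadata}, f)

# Loader side, mirroring `_load_config` above: unwrap the config and ignore the metadata.
loaded = yaml.safe_load(open("fast_llm_config_training.yaml"))
if loaded.keys() == {"config", "metadata"}:
    loaded = loaded["config"]
assert loaded == dataset_config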
diff --git a/fast_llm/data/dataset/memmap.py b/fast_llm/data/dataset/memmap.py index f80a48b0a..9831f81ba 100644 --- a/fast_llm/data/dataset/memmap.py +++ b/fast_llm/data/dataset/memmap.py @@ -8,7 +8,12 @@ from fast_llm.data.dataset.config import SamplingParameters from fast_llm.data.dataset.indexed import IndexedDataset from fast_llm.data.preprocessing.abstract import PreprocessingConfig -from fast_llm.data.sample.abstract import MemmapIndexDatasetReaderConfig, MemmapWriter, Sample +from fast_llm.data.sample.abstract import ( + MemmapIndexDatasetReaderConfig, + MemmapIndexedDatasetReader, + MemmapWriter, + Sample, +) FILE_HEADER = b"fast_llm_prepared_dataset" @@ -82,6 +87,10 @@ def get_document_sizes(self) -> torch.Tensor: def get_document_size(self, index: int) -> int: return self._reader.get_document_size(index) + @property + def reader(self) -> MemmapIndexedDatasetReader: + return self._reader + @classmethod def write_dataset( cls, diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 2ea81d8a6..e0f5f02fc 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -39,7 +39,7 @@ from fast_llm.data.sample.token import TokenSample from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.engine.config_utils.run import log_main_rank -from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum +from fast_llm.utils import normalize_probabilities, padded_cumsum logger = logging.getLogger(__name__) @@ -346,16 +346,18 @@ def generate_config_yaml_for_sharded_dst( # Create the config file(s) on rank 0 dataset_configs, reader_configs = zip(*dataset_and_reader_configs) if self._config.splits: - for split_name, split_config in self._split_and_blend_dataset_configs( + for split_name, (split_config, metadata) in self._split_and_blend_dataset_configs( dataset_configs, reader_configs, self._config.splits, self._config.output_path ).items(): self._save_dataset_config( - split_config, self._config.output_path / f"fast_llm_config_{split_name}.yaml" + split_config, + metadata, + output_path=self._config.output_path / f"fast_llm_config_{split_name}.yaml", ) else: self._save_dataset_config( - self._blend_dataset_configs(dataset_configs, reader_configs), - self._config.output_path / f"fast_llm_config.yaml", + *self._blend_dataset_configs(dataset_configs, reader_configs), + output_path=self._config.output_path / f"fast_llm_config.yaml", ) # Save metadata on rank 0 @@ -368,29 +370,30 @@ def generate_config_yaml_for_sharded_dst( @classmethod def _save_dataset_config( - cls, dataset_config: IndexedDatasetConfig[_sample_type], output_path: pathlib.Path + cls, + dataset_config: IndexedDatasetConfig[_sample_type], + metadata: dict[str, typing.Any], + output_path: pathlib.Path, ) -> None: logger.info(f"Saving config to {output_path}") - yaml.safe_dump( - dataset_config.to_dict(), - output_path.open("w"), - ) + yaml.safe_dump({"config": dataset_config.to_dict(), "metadata": metadata}, output_path.open("w")) @classmethod def _blend_dataset_configs( cls, dataset_configs: list[MemmapDatasetConfig[_sample_type]], reader_configs: list[MemmapIndexDatasetReaderConfig], - ) -> IndexedDatasetConfig[_sample_type]: + ) -> tuple[IndexedDatasetConfig[_sample_type], dict[str, typing.Any]]: + datasets_metadata = [reader_config.get_metadata() for reader_config in reader_configs] if len(dataset_configs) == 1: - return dataset_configs[0] + return 
dataset_configs[0], datasets_metadata[0] return SampledDatasetConfig[cls._sample_type].from_dict( { "type": "blended", "datasets": dataset_configs, "weights": [reader_config.num_tokens for reader_config in reader_configs], } - ) + ), reader_configs[0].blend_metadata(datasets_metadata) def _split_and_blend_dataset_configs( self, @@ -398,7 +401,7 @@ def _split_and_blend_dataset_configs( reader_configs: list[MemmapIndexDatasetReaderConfig], splits: dict[str, int | float], output_path: pathlib.Path, - ) -> dict[str, SampledDatasetConfig[_sample_type]]: + ) -> dict[str, tuple[SampledDatasetConfig[_sample_type], dict[str, typing.Any]]]: split_cumsum = padded_cumsum(normalize_probabilities(list(splits.values()), return_array=True)).tolist() dataset_sizes = [reader_config.num_tokens for reader_config in reader_configs] dataset_probabilities = normalize_probabilities(dataset_sizes) @@ -407,7 +410,7 @@ def _split_and_blend_dataset_configs( for split_index, split_name in enumerate(splits): datasets_in_split = [] - dataset_tokens_in_split = [] + datasets_metadata = [] for dataset_index, (dataset_config, reader_config) in enumerate( zip(dataset_configs, reader_configs, strict=True) ): @@ -424,17 +427,17 @@ def _split_and_blend_dataset_configs( if split_begin_in_dataset == 0 and split_end_in_dataset == 1: # All the dataset belongs to the split. datasets_in_split.append(dataset_configs[dataset_index]) - dataset_tokens_in_split.append(dataset_sizes[dataset_index]) + datasets_metadata.append(reader_config.get_metadata()) + elif split_end_in_dataset > split_begin_in_dataset: # Part of the dataset belongs to the split. # TODO: Somehow getting a segfault when merging two lines below (numpy bug?). dataset = dataset_config.to_copy({"path": output_path / dataset_config.path}).build( self._preprocessing_config ) - sizes_cumsum = dataset.get_document_sizes().numpy().cumsum() - Assert.eq(sizes_cumsum[-1], reader_config.num_tokens) - begin_index = _get_nearest_split(sizes_cumsum, split_begin_in_dataset * reader_config.num_tokens) - end_index = _get_nearest_split(sizes_cumsum, split_end_in_dataset * reader_config.num_tokens) + begin_index, end_index, metadata = dataset.reader.get_split( + split_begin_in_dataset, split_end_in_dataset + ) if end_index > begin_index: datasets_in_split.append( DatasetSliceConfig[self._sample_type].from_dict( @@ -446,10 +449,7 @@ def _split_and_blend_dataset_configs( } ) ) - dataset_tokens_in_split.append( - sizes_cumsum[end_index - 1].item() - - (sizes_cumsum[begin_index - 1].item() if begin_index > 0 else 0) - ) + datasets_metadata.append(metadata) # [else] None of the dataset belongs to the split. @@ -457,14 +457,17 @@ def _split_and_blend_dataset_configs( # This is a big problem, but we don't want to crash the whole run. 
logger.error(f"Datasets split {split_name} is empty!") elif len(datasets_in_split) == 1: - dataset_splits[split_name] = datasets_in_split[0] + dataset_splits[split_name] = (datasets_in_split[0], datasets_metadata[0]) else: - dataset_splits[split_name] = BlendedDatasetConfig[self._sample_type].from_dict( - { - "type": "blended", - "datasets": datasets_in_split, - "weights": dataset_tokens_in_split, - } + dataset_splits[split_name] = ( + BlendedDatasetConfig[self._sample_type].from_dict( + { + "type": "blended", + "datasets": datasets_in_split, + "weights": [dataset_metadata["num_tokens"] for dataset_metadata in datasets_metadata], + } + ), + reader_configs[0].blend_metadata(datasets_metadata), ) return dataset_splits diff --git a/fast_llm/data/sample/abstract.py b/fast_llm/data/sample/abstract.py index 1d71363b7..494a5c4a5 100644 --- a/fast_llm/data/sample/abstract.py +++ b/fast_llm/data/sample/abstract.py @@ -73,6 +73,13 @@ def expected_buffer_size(self) -> int: """ raise NotImplementedError() + def get_metadata(self) -> dict[str, typing.Any]: + raise NotImplementedError() + + @classmethod + def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typing.Any]: + raise NotImplementedError() + @config_class(dynamic_type={MemmapReaderBaseConfig: "none"}) class NullReaderConfig(MemmapReaderBaseConfig): @@ -159,6 +166,13 @@ def reader_class(self) -> "type[MemmapIndexedDatasetReader]": def get_reader(self, buffer: memoryview, model_preprocessing: PreprocessingConfig) -> "MemmapIndexedDatasetReader": return self.reader_class(self, buffer, model_preprocessing) + def get_metadata(self) -> dict[str, typing.Any]: + return {"num_tokens": self.num_tokens} + + @classmethod + def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typing.Any]: + return {"num_tokens": sum(metadata_["num_tokens"] for metadata_ in metadata)} + class MemmapReaderBase[ConfigType: MemmapReaderBaseConfig](Configurable[ConfigType]): @abc.abstractmethod @@ -196,6 +210,9 @@ def get_document_sizes(self) -> "torch.Tensor": def get_document_size(self, index: int) -> int: pass + def get_split(self, begin_ratio: float, end_ratio: float) -> tuple[int, int, dict[str, typing.Any]]: + raise NotImplementedError() + class MemmapWriter(abc.ABC): def __init__( diff --git a/fast_llm/data/sample/language_model.py b/fast_llm/data/sample/language_model.py index 3183a9ec1..22b89acf1 100644 --- a/fast_llm/data/sample/language_model.py +++ b/fast_llm/data/sample/language_model.py @@ -23,6 +23,7 @@ from fast_llm.data.sample.patch import ( EmptyPatchReader, PatchBatch, + PatchReader, PatchReaderBaseConfig, PatchReaderConfig, PatchSample, @@ -31,6 +32,7 @@ from fast_llm.data.sample.range import ( EmptyRangeReader, RangeBatch, + RangeReader, RangeReaderBaseConfig, RangeReaderConfig, RangeSample, @@ -222,6 +224,41 @@ def _expected_buffer_size(self) -> int: + self.image_patches.expected_buffer_size ) + def get_metadata(self) -> dict[str, typing.Any]: + out = super().get_metadata() + out["tokens"] = self.tokens.get_metadata() + if not isinstance(self.loss_masking_spans, NullReaderConfig): + out["loss_masking_spans"] = self.loss_masking_spans.get_metadata() + if not isinstance(self.chosen_spans, NullReaderConfig): + out["chosen_spans"] = self.chosen_spans.get_metadata() + if not isinstance(self.rejected_spans, NullReaderConfig): + out["rejected_spans"] = self.rejected_spans.get_metadata() + if not isinstance(self.image_patches, NullReaderConfig): + out["image_patches"] = self.image_patches.get_metadata() + return out + 
+ @classmethod + def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typing.Any]: + out = super().blend_metadata(metadata) + out["tokens"] = TokenReaderConfig.blend_metadata([metadata_["tokens"] for metadata_ in metadata]) + if "loss_masking_spans" in metadata[0]: + out["loss_masking_spans"] = RangeReaderConfig.blend_metadata( + [metadata_["loss_masking_spans"] for metadata_ in metadata] + ) + if "chosen_spans" in metadata[0]: + out["chosen_spans"] = RangeReaderConfig.blend_metadata( + [metadata_["chosen_spans"] for metadata_ in metadata] + ) + if "rejected_spans" in metadata[0]: + out["image_patches"] = RangeReaderConfig.blend_metadata( + [metadata_["image_patches"] for metadata_ in metadata] + ) + if "image_patches" in metadata[0]: + out["image_patches"] = PatchReaderConfig.blend_metadata( + [metadata_["image_patches"] for metadata_ in metadata] + ) + return out + class LanguageModelReader[ConfigType: LanguageModelReaderConfig](MemmapIndexedDatasetReader[ConfigType]): _model_preprocessing: LanguageModelPreprocessingConfig @@ -305,6 +342,23 @@ def get_document_sizes(self) -> torch.Tensor: def get_document_size(self, index: int) -> int: return self._tokens.get_document_size(index) + def get_split(self, begin_ratio: float, end_ratio: float) -> tuple[int, int, dict[str, typing.Any]]: + begin_index, end_index, token_metadata = self._tokens.get_split(begin_ratio, end_ratio) + metadata = { + "num_tokens": token_metadata["num_tokens"], + "tokens": token_metadata, + } + if hasattr(self, "_loss_masking_spans") and isinstance(self._loss_masking_spans, RangeReader): + metadata["loss_masking_spans"] = self._loss_masking_spans.get_split(begin_index, end_index) + if hasattr(self, "_chosen_spans") and isinstance(self._chosen_spans, RangeReader): + metadata["chosen_spans"] = self._chosen_spans.get_split(begin_index, end_index) + if hasattr(self, "_rejected_spans") and isinstance(self._rejected_spans, RangeReader): + metadata["rejected_spans"] = self._rejected_spans.get_split(begin_index, end_index) + if hasattr(self, "_image_patches") and isinstance(self._image_patches, PatchReader): + metadata["image_patches"] = self._image_patches.get_split(begin_index, end_index) + + return begin_index, end_index, metadata + class LanguageModelWriter(MemmapWriter): _preprocessing_config: LanguageModelPreprocessingConfig diff --git a/fast_llm/data/sample/patch.py b/fast_llm/data/sample/patch.py index 221746752..7ae537104 100644 --- a/fast_llm/data/sample/patch.py +++ b/fast_llm/data/sample/patch.py @@ -192,6 +192,27 @@ def _expected_buffer_size(self) -> int: * torch.int32.itemsize ) + def get_metadata(self) -> dict[str, typing.Any]: + return { + "num_documents": self.num_documents, + "num_patches": self.num_patches, + "num_patch_groups": self.num_patch_groups, + "num_pixels": self.patch_size * self.num_patches, + "patch_shape": self.patch_shape, + "data_type": str(self.data_type), + } + + @classmethod + def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typing.Any]: + return { + "num_documents": sum(metadata_["num_documents"] for metadata_ in metadata), + "num_patches": sum(metadata_["num_patches"] for metadata_ in metadata), + "num_patch_groups": sum(metadata_["num_patch_groups"] for metadata_ in metadata), + "num_pixels": sum(metadata_["num_pixels"] for metadata_ in metadata), + "patch_shape": get_unique(metadata_["patch_shape"] for metadata_ in metadata), + "data_type": get_unique(metadata_["data_type"] for metadata_ in metadata), + } + class PatchReader[ConfigType: 
PatchReaderConfig](MemmapReader[ConfigType]): def __init__(self, config: ConfigType, buffer: memoryview, model_preprocessing: PreprocessingConfig | None = None): @@ -253,6 +274,19 @@ def get_document(self, index: int, begin: int, end: int) -> Sample: ), ) + def get_split(self, begin_index: int, end_index: int) -> dict[str, typing.Any]: + Assert.custom(lambda x: x == sorted(x), [0, begin_index, end_index, self._config.num_documents]) + num_patches = self._patch_count_cumsums[end_index].item() - self._patch_count_cumsums[begin_index].item() + return { + "num_documents": end_index - begin_index, + "num_patches": num_patches, + "num_patch_groups": self._group_count_cumsums[end_index].item() + - self._group_count_cumsums[begin_index].item(), + "num_pixels": self._config.patch_size * num_patches, + "patch_shape": self._config.patch_shape, + "data_type": str(self._config.data_type), + } + class EmptyPatchReader[ConfigType: PatchReaderBaseConfig](MemmapReaderBase[ConfigType]): def get_document(self, index: int, begin: int, end: int) -> Sample: diff --git a/fast_llm/data/sample/range.py b/fast_llm/data/sample/range.py index a77846725..53683342a 100644 --- a/fast_llm/data/sample/range.py +++ b/fast_llm/data/sample/range.py @@ -92,6 +92,19 @@ def writer_class(self) -> "type[RangeWriter]": def _expected_buffer_size(self) -> int: return self.num_ranges * torch.int32.itemsize * 2 + (self.num_documents + 1) * torch.int32.itemsize + def get_metadata(self) -> dict[str, typing.Any]: + return { + "num_documents": self.num_documents, + "num_ranges": self.num_ranges, + } + + @classmethod + def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typing.Any]: + return { + "num_documents": sum(metadata_["num_documents"] for metadata_ in metadata), + "num_ranges": sum(metadata_["num_ranges"] for metadata_ in metadata), + } + class RangeReader[ConfigType: RangeReaderConfig](MemmapReader[ConfigType]): def __init__(self, config: ConfigType, buffer: memoryview, model_preprocessing: PreprocessingConfig | None = None): @@ -116,6 +129,13 @@ def get_document(self, index: int, begin: int, end: int) -> Sample: ) return RangeSample([(begin_, end_) for begin_, end_ in cropped_ranges if end_ > begin_], sample_size) + def get_split(self, begin_index: int, end_index: int) -> dict[str, typing.Any]: + Assert.custom(lambda x: x == sorted(x), [0, begin_index, end_index, self._config.num_documents]) + return { + "num_documents": end_index - begin_index, + "num_ranges": self._count_cumsums[end_index].item() - self._count_cumsums[begin_index].item(), + } + class EmptyRangeReader[ConfigType: RangeReaderBaseConfig](MemmapReaderBase[ConfigType]): def get_document(self, index: int, begin: int, end: int) -> Sample: diff --git a/fast_llm/data/sample/token.py b/fast_llm/data/sample/token.py index 1bc9ef1a1..cd4d7fa02 100644 --- a/fast_llm/data/sample/token.py +++ b/fast_llm/data/sample/token.py @@ -14,7 +14,7 @@ Sample, ) from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.utils import Assert +from fast_llm.utils import Assert, get_unique def crop_lengths(lengths: list[int], begin: int, end: int) -> list[int]: @@ -110,6 +110,21 @@ def writer_class(self) -> "type[TokenWriter]": def _expected_buffer_size(self) -> int: return self.num_tokens * self.data_type.torch.itemsize + (self.num_documents + 1) * torch.int64.itemsize + def get_metadata(self) -> dict[str, typing.Any]: + return { + "num_tokens": self.num_tokens, + "num_documents": self.num_documents, + "data_type": str(self.data_type), + } + + 
@classmethod
+    def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typing.Any]:
+        return {
+            "num_tokens": sum(metadata_["num_tokens"] for metadata_ in metadata),
+            "num_documents": sum(metadata_["num_documents"] for metadata_ in metadata),
+            "data_type": get_unique(metadata_["data_type"] for metadata_ in metadata),
+        }
+
 
 class TokenReader[ConfigType: TokenReaderConfig](MemmapIndexedDatasetReader[ConfigType]):
     def __init__(self, config: ConfigType, buffer: memoryview, model_preprocessing: PreprocessingConfig | None = None):
@@ -135,6 +150,28 @@ def get_document_sizes(self) -> torch.Tensor:
     def get_document_size(self, index: int) -> int:
         return self._size_cumsums[index + 1].item() - self._size_cumsums[index].item()
 
+    def get_split(self, begin_ratio: float, end_ratio: float) -> tuple[int, int, dict[str, typing.Any]]:
+        Assert.custom(lambda x: x == sorted(x), [0, begin_ratio, end_ratio, 1])
+        begin_index = _get_nearest_split(self._size_cumsums[1:], begin_ratio * self.num_tokens)
+        end_index = _get_nearest_split(self._size_cumsums[1:], end_ratio * self.num_tokens)
+
+        return (
+            begin_index,
+            end_index,
+            {
+                "num_tokens": self._size_cumsums[end_index].item() - self._size_cumsums[begin_index].item(),
+                "num_documents": end_index - begin_index,
+                "data_type": str(self._config.data_type),
+            },
+        )
+
+
+def _get_nearest_split(cumsum: torch.Tensor, value: float) -> int:
+    # Return the document index whose cumulative token count is nearest to `value`.
+    left = int(torch.searchsorted(cumsum, value, side="right"))
+    lower = cumsum[left - 1].item() if left > 0 else 0
+    return left if left == len(cumsum) or value - lower <= cumsum[left].item() - value else left + 1
+
 
 class TokenWriter(MemmapWriter):
     def __enter__(self):

From d2789452546b92223d4761a6d2c75eb8fac272c8 Mon Sep 17 00:00:00 2001
From: Joel Lamy-Poirier
Date: Mon, 22 Dec 2025 14:08:11 -0500
Subject: [PATCH 06/12] Add support for gloo (#424)

---
 fast_llm/core/distributed.py                | 37 +++++++++++++++------
 fast_llm/engine/checkpoint/safe_load.py     |  6 ++--
 fast_llm/engine/distributed/config.py       | 22 +++++++++++-
 fast_llm/engine/distributed/distributed.py  | 15 +++++++--
 fast_llm/layers/language_model/embedding.py |  3 ++
 tests/conftest.py                           | 23 ++++++++++++-
 tests/models/distributed_test_checkpoint.py |  9 +++--
 tests/models/distributed_test_model.py      |  7 ++--
 tests/utils/distributed_configs.py          | 22 ++++++------
 tests/utils/model_configs.py                |  5 +++
 10 files changed, 115 insertions(+), 34 deletions(-)

diff --git a/fast_llm/core/distributed.py b/fast_llm/core/distributed.py
index 2a200688a..4dcc53d55 100644
--- a/fast_llm/core/distributed.py
+++ b/fast_llm/core/distributed.py
@@ -29,10 +29,15 @@
 logger = logging.getLogger(__name__)
 
 
-def add_ephemeral_timeout(group: ProcessGroup, timeout: float | None = None) -> None:
+@contextlib.contextmanager
+def set_timeout(group: ProcessGroup | None, timeout: float | None = None):
     if group is not None and timeout is not None:
-        # TODO: Only works for nccl?
- group._add_ephemeral_timeout(datetime.timedelta(seconds=timeout)) + timeout_ = group.options._timeout + group.set_timeout(datetime.timedelta(seconds=timeout)) + yield + group.set_timeout(timeout_) + else: + yield def broadcast( @@ -43,8 +48,8 @@ def broadcast( opts = torch.distributed.BroadcastOptions() opts.rootRank = src opts.rootTensor = 0 - add_ephemeral_timeout(group, timeout) - work = group.broadcast([tensor], opts) + with set_timeout(group, timeout): + work = group.broadcast([tensor], opts) if async_op: return work else: @@ -55,7 +60,7 @@ def broadcast( def check_parallel_match(tensor: torch.Tensor, group: ProcessGroup | None, name: str) -> None: # A utility function to check for tensor-parallel (or other) mismatches. all_tensors = tensor.new_empty((group.size(),) + tensor.shape) - all_gather_into_tensor(all_tensors, tensor, group) + all_gather_into_tensor(all_tensors, tensor.unsqueeze(0), group) mismatches = (all_tensors != tensor).any(dim=0) num_mismatches = mismatches.sum().item() @@ -84,8 +89,8 @@ def allreduce_scalar( ) -> float | int: if group: value = torch.full([1], value, dtype=dtype, device=torch.cuda.current_device()) - add_ephemeral_timeout(group, timeout) - torch.distributed.all_reduce(value, op=op, group=group) + with set_timeout(group, timeout): + torch.distributed.all_reduce(value, op=op, group=group) return value.item() else: return value @@ -99,9 +104,9 @@ def all_gather_scalar( ): if group: value = torch.full([1], value, dtype=dtype, device=torch.cuda.current_device()) - add_ephemeral_timeout(group, timeout) output_tensor = value.new_empty((group.size(),)) - torch.distributed.all_gather_into_tensor(output_tensor, value, group=group) + with set_timeout(group, timeout): + torch.distributed.all_gather_into_tensor(output_tensor, value, group=group) return output_tensor.tolist() else: return value @@ -147,6 +152,12 @@ def broadcast_object(input_object: typing.Any | None, group: ProcessGroup | None def send(tensor: torch.Tensor, dst: int, group: ProcessGroup, async_op=False, tag: int = 0) -> Work | None: assert group is not None + if isinstance(group, torch.distributed.ProcessGroupGloo) and tensor.device.type != "cpu": + # send not supported for gloo on GPU. + tensor_cpu = tensor.cpu() + group.send([tensor_cpu], dst, tag).wait() + tensor.copy_(tensor_cpu) + return None work = group.send([tensor], dst, tag) if async_op: return work @@ -157,6 +168,12 @@ def send(tensor: torch.Tensor, dst: int, group: ProcessGroup, async_op=False, ta def recv(tensor: torch.Tensor, src: int, group: ProcessGroup, async_op=False, tag: int = 0) -> Work | None: assert group is not None + if isinstance(group, torch.distributed.ProcessGroupGloo) and tensor.device.type != "cpu": + # recv not supported for gloo on GPU. 
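+        # Note: this fallback completes synchronously, waiting on the CPU receive and returning None
+        # even when async_op is requested.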
+ tensor_cpu = tensor.cpu() + group.recv([tensor_cpu], src, tag).wait() + tensor.copy_(tensor_cpu) + return None work = group.recv([tensor], src, tag) if async_op: return work diff --git a/fast_llm/engine/checkpoint/safe_load.py b/fast_llm/engine/checkpoint/safe_load.py index 2e2a01881..d3f72a47c 100644 --- a/fast_llm/engine/checkpoint/safe_load.py +++ b/fast_llm/engine/checkpoint/safe_load.py @@ -4,7 +4,7 @@ import torch from torch.distributed import all_reduce -from fast_llm.core.distributed import add_ephemeral_timeout +from fast_llm.core.distributed import set_timeout from fast_llm.engine.multi_stage.config import ShardName from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.functional.triton.pointwise import triton_fill @@ -171,8 +171,8 @@ def _check_parameters(self, errors: list[str]) -> None: if self._distributed.world_group is not None: counter_tensor = torch.tensor(counters, dtype=torch.int64).to(self._distributed.device) # This may be the first distributed barrier after loading, so we need to wait for everyone to finish. - add_ephemeral_timeout(self._distributed.world_group, self._timeout) - all_reduce(counter_tensor, group=self._distributed.world_group) + with set_timeout(self._distributed.world_group, self._timeout): + all_reduce(counter_tensor, group=self._distributed.world_group) counters = counter_tensor.tolist() # Compare global counts against expected values. diff --git a/fast_llm/engine/distributed/config.py b/fast_llm/engine/distributed/config.py index f4dab5a26..7f4b7bc38 100644 --- a/fast_llm/engine/distributed/config.py +++ b/fast_llm/engine/distributed/config.py @@ -40,7 +40,7 @@ MAX_SEED = 2**64 -class PhaseType(str, enum.Enum): +class PhaseType(enum.StrEnum): training = "Training" validation = "Validation" test = "Test" @@ -51,6 +51,21 @@ def is_training(self) -> bool: return self == PhaseType.training +class DistributedBackend(enum.StrEnum): + nccl = "nccl" + gloo = "gloo" + + @property + def process_group_class(self) -> type["ProcessGroup"]: + import torch + + return ( + torch.distributed.ProcessGroupNCCL + if self == DistributedBackend.nccl + else torch.distributed.ProcessGroupGloo + ) + + @dataclasses.dataclass class DistributedDim: """ @@ -175,6 +190,11 @@ class DistributedConfig(Config): desc="Prioritize the pipeline groups for placement of nearby ranks over data groups.", hint=FieldHint.expert, ) + backend: DistributedBackend = Field( + default=DistributedBackend.nccl, + desc="The distributed backend to use.", + hint=FieldHint.expert, + ) timeout: float = Field( default=60, desc="Timeout for distributed operations.", diff --git a/fast_llm/engine/distributed/distributed.py b/fast_llm/engine/distributed/distributed.py index 302cfcdce..aa2be6ce7 100644 --- a/fast_llm/engine/distributed/distributed.py +++ b/fast_llm/engine/distributed/distributed.py @@ -8,6 +8,7 @@ from fast_llm.core.distributed import ProcessGroup from fast_llm.engine.distributed.config import ( MAX_SEED, + DistributedBackend, DistributedConfig, DistributedDim, DistributedDimNames, @@ -27,6 +28,7 @@ def __init__( local_world_size: int | None = None, timeout: float = 60, use_cpu: bool = False, + backend: DistributedBackend = DistributedBackend.nccl, ): self._rank = DistributedConfig.default_rank if rank is None else rank @@ -36,10 +38,12 @@ def __init__( ) self._timeout = timeout self._use_cpu = use_cpu + self._backend = backend self._process_groups = {} if self._use_cpu: - Assert.eq(self._world_size, 1) + if backend == DistributedBackend.nccl: + 
Assert.eq(self._world_size, 1)
             self._device = torch.device("cpu")
         else:
             Assert.in_range_incl(self._local_world_size, 1, torch.cuda.device_count())
@@ -77,6 +81,10 @@ def local_world_size(self):
     def device(self):
         return self._device
 
+    @property
+    def backend(self):
+        return self._backend
+
     def get_process_group(self, global_ranks: range | tuple, group_rank: int) -> ProcessGroup | None:
         """
         Get the requested process group from the pool, or create it if it doesn't exist.
@@ -100,8 +108,7 @@ def get_process_group(self, global_ranks: range | tuple, group_rank: int) -> Pro
             if isinstance(global_ranks, range)
             else f"ranks_{"_".join(str(rank) for rank in global_ranks)}"
         )
-
-        group = torch.distributed.ProcessGroupNCCL(
+        group = self._backend.process_group_class(
            torch.distributed.PrefixStore(prefix + "/", self.store),
            group_rank,
            group_size,
@@ -157,6 +164,7 @@ def __init__(self, config: DistributedConfig, use_cpu: bool = False):
                 self._config.local_world_size,
                 self._config.timeout,
                 use_cpu,
+                self._config.backend,
             )
         else:
             self._pool = _default_pool
@@ -164,6 +172,7 @@ def __init__(self, config: DistributedConfig, use_cpu: bool = False):
         Assert.eq(self._pool.rank, self._config.rank)
         Assert.geq(self._pool.local_world_size, self._config.local_world_size)
         Assert.eq(self._pool.device.type, "cpu" if use_cpu else "cuda")
+        Assert.eq(self._pool.backend, self._config.backend)
 
         self.world_group = self.add_group(self._config.distributed_dims[DistributedDimNames.world])
         self.data_group = self.add_group(self._config.distributed_dims[DistributedDimNames.data])
diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py
index fc8794b5a..93850d24c 100644
--- a/fast_llm/layers/language_model/embedding.py
+++ b/fast_llm/layers/language_model/embedding.py
@@ -106,6 +106,9 @@ def _forward(
         if self._sequence_parallel:
             embeddings = split(embeddings, group=group, dim=0)
+            if isinstance(group, torch.distributed.ProcessGroupGloo):
+                # Somehow needed in some rare cases to prevent autograd from complaining, e.g. in `stp2_pp2s1_bf4`.
+ embeddings = embeddings.clone() else: if self._sequence_parallel: token_ids = split(token_ids, group=group, dim=0) diff --git a/tests/conftest.py b/tests/conftest.py index 58301919f..ba2927c64 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -47,7 +47,18 @@ def pytest_addoption(parser): group = parser.getgroup("fast_llm") group.addoption("--skip-slow", action="store_true") group.addoption("--show-skipped", action="store_true") - group.addoption("--show-gpu-memory", type=int, default=10) + group.addoption( + "--show-gpu-memory", + type=int, + default=10, + help="Show resource usage stats for the tests and distributed subtests with the highest GPU memory usage.", + ) + group.addoption( + "--show-durations", + type=int, + default=None, + help="Show resource usage stats for the slowest tests and distributed subtests.", + ) group.addoption("--no-distributed-capture", dest="distributed_capture", action="store_false") group.addoption("--models", nargs="*") group.addoption( @@ -229,6 +240,16 @@ def pytest_terminal_summary(terminalreporter): for nodeid in sorted_nodeids[: terminalreporter.config.getoption("--show-gpu-memory")]: terminalreporter.write_line(format_resource_report(nodeid, resource_reports[nodeid])) + if (show_durations := terminalreporter.config.getoption("--show-durations")) is not None: + terminalreporter.write_sep("=", "Highest durations", bold=True) + sorted_nodeids = sorted( + resource_reports.keys(), + key=lambda nodeid: (resource_reports[nodeid]["duration"] if "duration" in resource_reports[nodeid] else 0), + reverse=True, + ) + for nodeid in sorted_nodeids[:show_durations]: + terminalreporter.write_line(format_resource_report(nodeid, resource_reports[nodeid])) + def pytest_runtest_call(item: pytest.Function): if torch.cuda.is_available(): diff --git a/tests/models/distributed_test_checkpoint.py b/tests/models/distributed_test_checkpoint.py index 001eb36da..407946545 100644 --- a/tests/models/distributed_test_checkpoint.py +++ b/tests/models/distributed_test_checkpoint.py @@ -12,7 +12,7 @@ DistributedCheckpointFormat, FastLLMCheckpointFormat, ) -from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.config import DistributedBackend, DistributedConfig from fast_llm.engine.distributed.distributed import ProcessGroupPool from fast_llm.engine.multi_stage.config import StageMode from fast_llm.utils import Assert, header @@ -37,7 +37,7 @@ def _test_load_and_save_parallel( model = model_testing_config.model_class.from_pretrained( load_config, # The world size and rank are already set through environment variable. - {"distributed": config.distributed}, + {"distributed": {**config.distributed, "backend": model_testing_config.distributed_backend}}, mode=StageMode.inference, ) for save_format in (DistributedCheckpointFormat, FastLLMCheckpointFormat): @@ -56,7 +56,10 @@ def main(args: list[str] | None = None) -> None: "Capturing output and forwarding to associated tests. Run with `--no-distributed-capture` to disable." 
) - with ProcessGroupPool(timeout=20) as pool: + with ProcessGroupPool( + timeout=20, + backend=DistributedBackend(model_testing_config.distributed_backend), + ) as pool: failures = [] world_size = DistributedConfig.default_world_size rank = DistributedConfig.default_rank diff --git a/tests/models/distributed_test_model.py b/tests/models/distributed_test_model.py index 890a75077..29b68366d 100644 --- a/tests/models/distributed_test_model.py +++ b/tests/models/distributed_test_model.py @@ -2,7 +2,7 @@ from fast_llm.cli import fast_llm_main_wrapper from fast_llm.core.distributed import safe_barrier -from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.config import DistributedBackend, DistributedConfig from fast_llm.engine.distributed.distributed import ProcessGroupPool from tests.utils.distributed_configs import DISTRIBUTED_TESTING_CONFIGS from tests.utils.run_test_script import do_run_test_script_for_all_models, parse_run_distributed_script @@ -20,7 +20,10 @@ def main(args: list[str] | None = None) -> None: ) # TODO: Why are barriers needed? - with ProcessGroupPool(timeout=60) as pool: + with ProcessGroupPool( + timeout=60, + backend=DistributedBackend(model_testing_config.distributed_backend), + ) as pool: failures = [] world_size = DistributedConfig.default_world_size rank = DistributedConfig.default_rank diff --git a/tests/utils/distributed_configs.py b/tests/utils/distributed_configs.py index 83ed6836a..9c1cc9369 100644 --- a/tests/utils/distributed_configs.py +++ b/tests/utils/distributed_configs.py @@ -222,6 +222,17 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon num_gpus=2, compare_config=_compare_layer_match, ), + # Depth-first micro-batches, tensor-parallel + DistributedTestingConfig( + name="tp2_df4", + compare="df4", + config_args=[ + "model.distributed.tensor_parallel=2", + "batch.depth_first_micro_batches=4", + ], + num_gpus=2, + compare_config=_compare_layer_match, + ), # Cross-entropy splits DistributedTestingConfig( name="stp2_ce4", @@ -247,17 +258,6 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon num_gpus=4, compare_config=_compare_layer_match, ), - # Depth-first micro-batches, tensor-parallel - DistributedTestingConfig( - name="tp2_df4", - compare="df4", - config_args=[ - "model.distributed.tensor_parallel=2", - "batch.depth_first_micro_batches=4", - ], - num_gpus=4, - compare_config=_compare_layer_match, - ), # Breadth-first micro-batches DistributedTestingConfig( name="sdp2_stp2_bf4", diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 6156cb709..1248a1117 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -147,6 +147,10 @@ def model_class(self): def base_model_config_class(self): return self.model_config_class.get_base_model_config_class() + @functools.cached_property + def distributed_backend(self): + return self.config_dict["model"]["distributed"]["backend"] + def should_skip(self, distributed_config: DistributedTestingConfig) -> bool: return any(re.search(pattern, distributed_config.name) for pattern in self.skip_tests) @@ -254,6 +258,7 @@ def _update_and_add_testing_config( "distributed": { "reproducible_init": True, "timeout": 20, + "backend": "nccl", }, }, "batch": {"batch_size": 8, "sequence_length": 512}, From c8a73df7fa005c1a85598fd53e28b75d5648f405 Mon Sep 17 00:00:00 2001 From: Oleksiy Ostapenko Date: Tue, 23 Dec 2025 07:48:41 -0800 Subject: [PATCH 07/12] Fix for layer distillation with 
sequence parallel training (#431) --- fast_llm/layers/decoder/block.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index 0e3d6f0c0..f5abd1f6d 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -194,6 +194,29 @@ def activation_distillation_loss(self, hidden_states, bias, kwargs, losses, metr per_token_loss = torch.norm( mixer_output - teacher_tensor, p=2, dim=-1 ) # (batch, sequence) or (sequence, batch) + + # Slice mask to match per_token_loss shape (for sequence parallelism) + # When sequence_tensor_parallel is enabled, per_token_loss only has local sequence length + if mask.shape != per_token_loss.shape: + # Calculate the sequence offset for this rank using the hidden_dims parallel rank + hidden_dims = kwargs.get(BlockKwargs.hidden_dims) + seq_dim_idx = 0 if sequence_first else 1 + hidden_seq_dim = hidden_dims[seq_dim_idx] if hidden_dims else None + + if hidden_seq_dim and hidden_seq_dim.parallel_dim: + # Use the rank from the actual parallel dimension used by hidden states + local_seq_length = per_token_loss.shape[0] if sequence_first else per_token_loss.shape[1] + seq_offset = hidden_seq_dim.parallel_dim.rank * local_seq_length + else: + seq_offset = 0 + + if sequence_first: + # mask: (sequence, batch), per_token_loss: (local_sequence, batch) + mask = mask[seq_offset : seq_offset + per_token_loss.shape[0], :] + else: + # mask: (batch, sequence), per_token_loss: (batch, local_sequence) + mask = mask[:, seq_offset : seq_offset + per_token_loss.shape[1]] + masked_loss = per_token_loss * mask local_loss_sum = torch.sum(masked_loss) total_count = int(mask.sum().item()) From 44b14accca7e8b25bbf7929512d10edb4b32210f Mon Sep 17 00:00:00 2001 From: Oleksiy Ostapenko Date: Tue, 23 Dec 2025 09:45:55 -0800 Subject: [PATCH 08/12] Reverse KL: more efficient implementation + normalisation by sequence length (#430) --- fast_llm/functional/cross_entropy.py | 42 ++++++++++++-------------- tests/functional/test_cross_entropy.py | 4 ++- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 8c9ea9399..a12516b5d 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -227,6 +227,7 @@ def distributed_log_softmax( return logits_norm - sum_exp_logits.log() # log_softmax +@torch.compile def _reverse_kl_forward_backward( logits: torch.Tensor, target: torch.Tensor, @@ -259,24 +260,21 @@ def _reverse_kl_forward_backward( if loss_mask is not None: Assert.eq(loss_mask.shape, logits.shape[:-1]) - # Compute log probabilities teacher_log_probs = distributed_log_softmax(target.float(), group=group) - student_log_probs = distributed_log_softmax(logits, group=group) - - # Reverse KL: input=teacher_log_probs, target=student_probs - loss_terms = torch.nn.functional.kl_div( - teacher_log_probs, # input = log(p) - student_log_probs, # target = log(q) - reduction="none", - log_target=True, - ).sum(dim=-1) + log_ratio = distributed_log_softmax(logits, group=group) + + student_probs = log_ratio.exp() + log_ratio = log_ratio - teacher_log_probs # In-place: log_ratio = student_log_probs - teacher_log_probs + del teacher_log_probs + # Compute loss terms: student_probs * log_ratio, then sum over vocab + # This is equivalent to kl_div(..., log_target=True) but more memory efficient + loss_terms = (student_probs * log_ratio).sum(dim=-1) + if loss_mask is not None: # loss 
mask is the same on all ranks for TP over vocab. valid = loss_mask.to(loss_terms.dtype) loss_terms = loss_terms * valid - valid_tokens = valid.sum() - else: - valid_tokens = torch.prod(torch.tensor(loss_terms.shape, device=loss_terms.device, dtype=loss_terms.dtype)) + valid_tokens = torch.prod(torch.tensor(loss_terms.shape, device=loss_terms.device, dtype=loss_terms.dtype)) loss = loss_terms.sum() # sums over batch and seq. len. if group is not None: @@ -284,20 +282,20 @@ def _reverse_kl_forward_backward( loss /= valid_tokens if grad_output is not None: - # need to calculate gradient manually, backprop through all reduce can be problematic, see https://github.com/pytorch/pytorch/issues/58005 - log_ratio = student_log_probs - teacher_log_probs - expected = torch.sum(torch.exp(student_log_probs) * log_ratio, dim=-1, keepdim=True) - # expected E_q(log s - log t) -- this is actually dependent on the full vocab! + # Gradient: d/d(logits) KL(q||p) = q * (log(q/p) - E_q[log(q/p)]) + # where E_q[log(q/p)] is the expected log ratio under the student distribution + expected = torch.sum(student_probs * log_ratio, dim=-1, keepdim=True) if group is not None: all_reduce(expected, op=ReduceOp.SUM, group=group) - grad_base = torch.exp(student_log_probs) * (log_ratio - expected) + log_ratio = log_ratio - expected + log_ratio = log_ratio * student_probs + del student_probs # Free after use if loss_mask is not None: - valid = loss_mask.to(logits.dtype).unsqueeze(-1) - grad_base = grad_base * valid + log_ratio = log_ratio * loss_mask.to(logits.dtype).unsqueeze(-1) - grad = grad_base.mul(grad_output / valid_tokens) - grad = grad.to(logits.dtype) + log_ratio = log_ratio * (grad_output / valid_tokens) + grad = log_ratio.to(logits.dtype) else: grad = None diff --git a/tests/functional/test_cross_entropy.py b/tests/functional/test_cross_entropy.py index 72644d061..20d16bb96 100644 --- a/tests/functional/test_cross_entropy.py +++ b/tests/functional/test_cross_entropy.py @@ -104,7 +104,9 @@ def _reverse_kl_forward_backward_torch(logits: torch.Tensor, target: torch.Tenso reduction="none", log_target=True, ).sum(dim=-1) - output = per_sample.mean() if loss_mask is None else (per_sample * loss_mask).sum() / loss_mask.sum() + if loss_mask is not None: + per_sample = per_sample * loss_mask + output = per_sample.mean() output.backward() return output, logits.grad From 4543b399a907901fff85ffd9be5723e9073326ca Mon Sep 17 00:00:00 2001 From: Torsten Scholak Date: Mon, 5 Jan 2026 13:54:02 -0500 Subject: [PATCH 09/12] Refactor Apriel2 cache, add Qwen2 converter, and conversation format for SFT (#422) Co-authored-by: Claude Opus 4.5 Co-authored-by: bigximik --- .github/ISSUE_TEMPLATE/feature_request.md | 20 +- fast_llm/data/preparator/gpt_memmap/config.py | 124 +- .../data/preparator/gpt_memmap/prepare.py | 187 +-- fast_llm/data/preprocessing/tokenizer.py | 102 ++ fast_llm/models/gpt/conversion/apriel2.py | 131 +- fast_llm/models/gpt/conversion/qwen2.py | 69 + .../models/multimodal/conversion/apriel2.py | 9 +- fast_llm_external_models/apriel2/cache.py | 312 ++-- .../apriel2/conversion/__init__.py | 195 +-- .../apriel2/conversion/config.py | 251 +++- .../apriel2/conversion/converters.py | 604 +++++--- .../apriel2/conversion/executor.py | 11 +- .../apriel2/conversion/expr.py | 22 +- .../apriel2/conversion/io.py | 7 +- .../apriel2/conversion/llava/plan.py | 7 +- .../apriel2/conversion/qwen2/__init__.py | 6 + .../apriel2/conversion/qwen2/config.py | 79 ++ .../apriel2/conversion/qwen2/plan.py | 100 ++ 
.../apriel2/conversion/render.py | 28 +- fast_llm_external_models/apriel2/convert.py | 55 +- .../apriel2/examples/prepare_tulu3.yaml | 103 ++ .../examples/train_supernet_qwen2.yaml | 193 +++ .../examples/train_supernet_small.yaml | 2 +- .../apriel2/modeling_apriel2.py | 130 +- .../tests/test_apriel2/conftest.py | 503 ++++++- .../tests/test_apriel2/test_cache.py | 1258 ----------------- .../test_cache_apriel2_specific.py | 341 +++++ .../test_apriel2/test_cache_contracts.py | 591 ++++++++ .../tests/test_apriel2/test_causal_conv1d.py | 23 +- .../test_apriel2/test_compose_configs.py | 426 +++--- ...tion_torture.py => test_conversion_e2e.py} | 621 +++----- .../test_apriel2/test_convert_from_llava.py | 18 +- .../tests/test_apriel2/test_equivalence.py | 15 +- .../tests/test_apriel2/test_expr_plan.py | 405 ++++-- .../tests/test_apriel2/test_integration.py | 330 +++++ .../test_apriel2/test_mixer_equivalence.py | 108 +- .../test_apriel2/test_model_structure.py | 69 +- .../tests/test_apriel2/test_modeling.py | 58 +- .../tests/test_apriel2/test_plan_execution.py | 598 ++++++++ setup.py | 6 +- tests/data/test_tokenizer.py | 260 ++++ 41 files changed, 5567 insertions(+), 2810 deletions(-) create mode 100644 fast_llm_external_models/apriel2/conversion/qwen2/__init__.py create mode 100644 fast_llm_external_models/apriel2/conversion/qwen2/config.py create mode 100644 fast_llm_external_models/apriel2/conversion/qwen2/plan.py create mode 100644 fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml create mode 100644 fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml delete mode 100644 fast_llm_external_models/tests/test_apriel2/test_cache.py create mode 100644 fast_llm_external_models/tests/test_apriel2/test_cache_apriel2_specific.py create mode 100644 fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py rename fast_llm_external_models/tests/test_apriel2/{test_plan_composition_torture.py => test_conversion_e2e.py} (78%) create mode 100644 fast_llm_external_models/tests/test_apriel2/test_integration.py create mode 100644 fast_llm_external_models/tests/test_apriel2/test_plan_execution.py diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md index 50c5a2c1c..a09f78c6c 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.md +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -8,26 +8,26 @@ assignees: '' --- # 🎯 **Goal (What & Why)** -> **Clearly state the purpose of this feature.** +> **Clearly state the purpose of this feature.** > _(Example: Add FP8 support using torchao to improve training throughput by 1.5x.)_ # 🚀 **Execution Plan** -> _(This section may start as an incomplete draft but must be defined before implementation begins.)_ +> _(This section may start as an incomplete draft but must be defined before implementation begins.)_ ### **Step 1: What is the smallest working version?** -> _(Describe the simplest way to implement this feature with minimal effort.)_ +> _(Describe the simplest way to implement this feature with minimal effort.)_ -### **Step 2: What additional optimizations are possible (but optional)?** -> _(List potential refinements that can be added in later PRs if needed.)_ +### **Step 2: What additional optimizations are possible (but optional)?** +> _(List potential refinements that can be added in later PRs if needed.)_ # 📌 **Acceptance Criteria** (Must-Haves for Completion) -* The feature must be **functional and tested**. -* The implementation must be **documented in practical terms**. 
-* The PR must include a **performance/impact summary**. -* **No refactors unless directly necessary** for feature completion. +* The feature must be **functional and tested**. +* The implementation must be **documented in practical terms**. +* The PR must include a **performance/impact summary**. +* **No refactors unless directly necessary** for feature completion. # 🛠️ **Project Management** - [ ] **Assign the project to the Fast-LLM project.** - [ ] **Set the `Estimate` field (in days) in the GitHub project.** - [ ] **Use the `Size` field to categorize the PR size (Small/Medium/Large).** -- [ ] **Assign an owner when opening the issue.** +- [ ] **Assign an owner when opening the issue.** diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 503b400c3..a1aadf40a 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -15,30 +15,81 @@ from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator -@config_class() +@config_class(registry=True) class LanguageModelSourceConfig(Config): """ - A schema holding the name of each relevant column in the dataset. - Setting optional entries will enable the associated feature. + Abstract base class for data source schemas. + + Use `type: document` (default) for documents with text, optional span annotations, and optional images. + Use `type: conversation` for structured chat/conversation datasets. + """ + + @classmethod + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: + if cls is LanguageModelSourceConfig and cls.get_subclass(default.get("type")) is None: + # Default to DocumentSourceConfig when type is not specified + return DocumentSourceConfig._from_dict(default, strict) + return super()._from_dict(default, strict=strict) + + @functools.cached_property + def columns(self) -> list[str]: + """Columns to read from the dataset.""" + raise NotImplementedError + + @functools.cached_property + def has_loss_masking_span(self) -> bool: + return False + + @functools.cached_property + def has_preference_spans(self) -> bool: + return False + + @functools.cached_property + def has_images(self) -> bool: + return False + + +@config_class(dynamic_type={LanguageModelSourceConfig: "document"}) +class DocumentSourceConfig(LanguageModelSourceConfig): + """ + Source schema for document datasets with text, optional span annotations, and optional images. + + The dataset should have a text column containing the document text. 
+ Optionally, it can have additional columns for: + - Loss masking spans: character ranges to mask from loss computation + - Preference spans: chosen/rejected text for DPO training + - Images: image data with character positions for multimodal training """ text: str = Field( default="text", - desc="Field of the dataset to use.", + desc="Field containing the document text.", + hint=FieldHint.optional, + ) + loss_masking_spans: str | None = Field( + default=None, + desc="Field containing character spans to mask for loss computation.", hint=FieldHint.optional, ) - loss_masking_spans: None | str = Field( - default=None, desc="Field containing character spans to mask for loss computation", hint=FieldHint.optional + chosen_span: str | None = Field( + default=None, + desc="Field containing chosen text for preference optimization.", + hint=FieldHint.optional, ) - chosen_span: None | str = Field( - default=None, desc="Field containing chosen text for preference optimization", hint=FieldHint.optional + rejected_span: str | None = Field( + default=None, + desc="Field containing rejected text for preference optimization.", + hint=FieldHint.optional, ) - rejected_span: None | str = Field( - default=None, desc="Field containing rejected text for preference optimization", hint=FieldHint.optional + images: str | None = Field( + default=None, + desc="Field containing images.", + hint=FieldHint.optional, ) - images: None | str = Field(default=None, desc="Field containing images", hint=FieldHint.optional) - image_positions: None | str = Field( - default=None, desc="Field containing image positions in the text.", hint=FieldHint.optional + image_positions: str | None = Field( + default=None, + desc="Field containing image positions in the text.", + hint=FieldHint.optional, ) @functools.cached_property @@ -48,6 +99,8 @@ def columns(self) -> list[str]: columns.append(self.loss_masking_spans) if self.has_preference_spans: columns.extend([self.chosen_span, self.rejected_span]) + if self.has_images: + columns.extend([self.images, self.image_positions]) return columns @functools.cached_property @@ -67,7 +120,50 @@ def has_images(self) -> bool: def _validate(self): super()._validate() if self.has_preference_spans and self.has_loss_masking_span: - raise ValueError(f"Can not enable both loss masking and preference spans.") + raise ValueError("Cannot enable both loss masking and preference spans.") + + +@config_class(dynamic_type={LanguageModelSourceConfig: "conversation"}) +class ConversationSourceConfig(LanguageModelSourceConfig): + """ + Source schema for chat/conversation datasets (e.g., Tulu 3, ShareGPT, OpenAI format). + + The dataset should have a messages column containing a list of message dicts, + where each message has 'role' and 'content' keys. Common roles include: + - 'system': System prompt + - 'user': User input + - 'assistant': Model response (trained on by default) + - 'tool': Tool/function results + - 'ipython': Code execution results + + The conversation is formatted using the tokenizer's chat template, which must + contain {% generation %}...{% endgeneration %} markers to define which content + to train on. Loss masking spans are automatically computed from these markers. + + Example dataset format: + { + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello!"}, + {"role": "assistant", "content": "Hi there!"}, + ] + } + """ + + messages: str = Field( + default="messages", + desc="Field containing the conversation messages list. 
Each message should have 'role' and 'content' keys.", + hint=FieldHint.core, + ) + + @functools.cached_property + def columns(self) -> list[str]: + return [self.messages] + + @functools.cached_property + def has_loss_masking_span(self) -> bool: + # Conversation format always generates loss masking spans from chat template markers + return True @config_class() diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index e0f5f02fc..325d33c43 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -28,7 +28,12 @@ ) from fast_llm.data.dataset.memmap import MemmapDataset from fast_llm.data.preparator.config import DatasetPreparator -from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig, LanguageModelSourceConfig +from fast_llm.data.preparator.gpt_memmap.config import ( + ConversationSourceConfig, + DocumentSourceConfig, + GPTMemmapDatasetPreparatorConfig, + LanguageModelSourceConfig, +) from fast_llm.data.preprocessing.abstract import NullPreprocessingConfig from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.preprocessing.tokenizer import Tokenizer @@ -132,6 +137,10 @@ def run(self) -> None: # Load tokenizer self._tokenizer = self._config.tokenizer.get_tokenizer() + # Validate chat template for conversation format + if isinstance(self._source_schema, ConversationSourceConfig): + self._tokenizer.validate_chat_template() + # Decide the datatype based on the tokenizer vocabulary size self._data_type = ( get_unsigned_integer_type(self._tokenizer.vocab_size) @@ -216,92 +225,112 @@ def _preprocessing_config(self) -> LanguageModelPreprocessingConfig: ) def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample: - text = sample[self._source_schema.text] - all_spans = [] - if self._source_schema.has_loss_masking_span: - # Spans are typically stored in the (begin, last) format. We convert to (begin, end) range format. 
- loss_masking_spans = _sort_spans( - (SpanType.loss_masking, (begin, last + 1)) - for begin, last in np.array(sample[self._source_schema.loss_masking_spans], dtype=np.int32) - .reshape(-1, 2) - .tolist() + token_spans_by_type = collections.defaultdict(list) + image_patches = image_token_maps = image_position_ids = patch_counts = None + + if isinstance(self._source_schema, ConversationSourceConfig): + # Conversation format: tokenize messages and get loss masking spans from chat template + tokens, loss_masking_spans = self._tokenizer.tokenize_chat( + sample[self._source_schema.messages], + True, + True, + data_type=self._data_type, ) - all_spans.extend(loss_masking_spans) - - if self._source_schema.has_preference_spans: - full_chosen_text = text + sample[self._source_schema.chosen_span] + self._tokenizer.tokenizer.eos_token - full_rejected_text = self._tokenizer.tokenizer.bos_token + text + sample[self._source_schema.rejected_span] - # compute chosen span - chosen_spans = [(SpanType.chosen, (len(text), len(full_chosen_text)))] - - # compute rejected span - rejected_span = [ - ( - SpanType.rejected, - ( - len(full_chosen_text) + len(self._tokenizer.tokenizer.bos_token) + len(text), - len(full_chosen_text) + len(full_rejected_text), - ), + token_spans_by_type[SpanType.loss_masking] = loss_masking_spans + elif isinstance(self._source_schema, DocumentSourceConfig): + # Document format: use the text-spans pipeline + text = sample[self._source_schema.text] + all_spans = [] + + if self._source_schema.has_loss_masking_span: + # Spans are typically stored in the (begin, last) format. We convert to (begin, end) range format. + loss_masking_spans = _sort_spans( + (SpanType.loss_masking, (begin, last + 1)) + for begin, last in np.array(sample[self._source_schema.loss_masking_spans], dtype=np.int32) + .reshape(-1, 2) + .tolist() + ) + all_spans.extend(loss_masking_spans) + + if self._source_schema.has_preference_spans: + full_chosen_text = text + sample[self._source_schema.chosen_span] + self._tokenizer.tokenizer.eos_token + full_rejected_text = ( + self._tokenizer.tokenizer.bos_token + text + sample[self._source_schema.rejected_span] ) - ] - # pack texts - text = full_chosen_text + full_rejected_text - all_spans.extend(chosen_spans + rejected_span) - - if self._source_schema.has_images: - # Get the images and positions, sorted by position. - images, image_positions = ( - zip( - *sorted( - zip( - sample[self._source_schema.images], - sample[self._source_schema.image_positions], - strict=True, + # compute chosen span + chosen_spans = [(SpanType.chosen, (len(text), len(full_chosen_text)))] + + # compute rejected span + rejected_span = [ + ( + SpanType.rejected, + ( + len(full_chosen_text) + len(self._tokenizer.tokenizer.bos_token) + len(text), + len(full_chosen_text) + len(full_rejected_text), ), - key=lambda x: x[1], ) + ] + # pack texts + text = full_chosen_text + full_rejected_text + all_spans.extend(chosen_spans + rejected_span) + + if self._source_schema.has_images: + # Get the images and positions, sorted by position. + images, image_positions = ( + zip( + *sorted( + zip( + sample[self._source_schema.images], + sample[self._source_schema.image_positions], + strict=True, + ), + key=lambda x: x[1], + ) + ) + if len(sample[self._source_schema.images]) > 0 + else ([], []) ) - if len(sample[self._source_schema.images]) > 0 - else ([], []) - ) - # Get the image patches and associated data. 
- image_patches, image_position_ids, image_token_maps, image_token_ids, patch_counts = ( - self._config.image_patches.get_patches_from_images(images, self._data_type) + # Get the image patches and associated data. + image_patches, image_position_ids, image_token_maps, image_token_ids, patch_counts = ( + self._config.image_patches.get_patches_from_images(images, self._data_type) + ) + patch_count_cumsum = padded_cumsum(patch_counts).tolist() + # Add an empty "span" at each image position so we know where to insert them in the tokenized sequence. + all_spans.extend([(SpanType.image, (position, position)) for position in image_positions]) + + # Sort the spans by location (begin), keeping track of their type. + # Note: overlapping spans are not supported (explicit assertion in the tokenizer). + span_types, spans = zip(*_sort_spans(all_spans)) if all_spans else ([], []) + # Tokenize the text, and determine the span locations in the tokenized text. + tokens, token_spans = self._tokenizer.tokenize_with_spans( + text, True, True, text_spans=spans, data_type=self._data_type ) - patch_count_cumsum = padded_cumsum(patch_counts).tolist() - # Add an empty "span" at each image position so we know where to insert them in the tokenized sequence. - all_spans.extend([(SpanType.image, (position, position)) for position in image_positions]) - - # Sort the spans by location (begin), keeping track of their type. - # Note: overlapping spans are not supported (explicit assertion in the tokenizer). - span_types, spans = zip(*_sort_spans(all_spans)) if all_spans else ([], []) - # Tokenize the text, and determine the span locations in the tokenized text. - tokens, token_spans = self._tokenizer.tokenize_with_spans( - text, True, True, text_spans=spans, data_type=self._data_type - ) - # Gather token spans by type. - token_spans_by_type = collections.defaultdict(list) - if self._source_schema.has_images: - # Insert the image token ids in the token sequence and shift the spans accordingly. - tokens_shift = 0 - image_index = 0 - for span_type, (begin, end) in zip(span_types, token_spans, strict=True): - # Account for the tokens already inserted. - begin = begin + tokens_shift - end = end + tokens_shift - if span_type == SpanType.image: - # Shift the token map to the image location. - image_token_maps[patch_count_cumsum[image_index] : patch_count_cumsum[image_index + 1]] += begin - # Insert the placeholder and image break tokens. - tokens = torch.cat([tokens[:begin], image_token_ids[image_index], tokens[begin:]]) - tokens_shift += len(image_token_ids[image_index]) - image_index += 1 - else: - token_spans_by_type[span_type].append((begin, end)) + # Gather token spans by type. + if self._source_schema.has_images: + # Insert the image token ids in the token sequence and shift the spans accordingly. + tokens_shift = 0 + image_index = 0 + for span_type, (begin, end) in zip(span_types, token_spans, strict=True): + # Account for the tokens already inserted. + begin = begin + tokens_shift + end = end + tokens_shift + if span_type == SpanType.image: + # Shift the token map to the image location. + image_token_maps[ + patch_count_cumsum[image_index] : patch_count_cumsum[image_index + 1] + ] += begin + # Insert the placeholder and image break tokens. 
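+                        # The inserted image tokens lengthen the sequence, so the begins/ends of later
+                        # spans are corrected through `tokens_shift` on subsequent iterations.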
+ tokens = torch.cat([tokens[:begin], image_token_ids[image_index], tokens[begin:]]) + tokens_shift += len(image_token_ids[image_index]) + image_index += 1 + else: + token_spans_by_type[span_type].append((begin, end)) + else: + for span_type, token_span in zip(span_types, token_spans, strict=True): + token_spans_by_type[span_type].append(token_span) else: - for span_type, token_span in zip(span_types, token_spans, strict=True): - token_spans_by_type[span_type].append(token_span) + raise NotImplementedError(f"Unsupported source schema type: {type(self._source_schema)}") sample_size = len(tokens) diff --git a/fast_llm/data/preprocessing/tokenizer.py b/fast_llm/data/preprocessing/tokenizer.py index abfb5b3d2..157744f51 100644 --- a/fast_llm/data/preprocessing/tokenizer.py +++ b/fast_llm/data/preprocessing/tokenizer.py @@ -213,3 +213,105 @@ def _remove_delimiters( @property def eod(self): return self.eod_id + + @staticmethod + def _has_generation_markers(template: str | None) -> bool: + """Check if a template has generation markers.""" + return template is not None and "{% generation %}" in template + + def validate_chat_template(self) -> None: + """ + Validate the tokenizer's chat template has generation markers. + + Raises: + ValueError: If the tokenizer lacks a chat template or generation markers. + """ + template = self.tokenizer.chat_template + + if template is None: + raise ValueError( + "Tokenizer does not have a chat template. " + "Conversation format requires a tokenizer with a built-in chat template " + "containing {% generation %}...{% endgeneration %} markers." + ) + + if not self._has_generation_markers(template): + raise ValueError( + "Tokenizer's chat template does not contain {% generation %}...{% endgeneration %} markers. " + "These markers are required to determine which tokens to train on. " + "Please use a tokenizer with generation markers in its chat template." + ) + + def tokenize_chat( + self, + messages: list[dict[str, str]], + begin: bool = True, + end: bool = True, + data_type: DataType = DataType.int64, + ) -> tuple["torch.Tensor", list[tuple[int, int]]]: + """ + Apply chat template and return (tokens, loss_masking_spans). + + The loss_masking_spans mark token ranges to EXCLUDE from training (where the model + should not learn). These are derived from the chat template's generation markers - + tokens outside {% generation %}...{% endgeneration %} blocks are masked. + """ + import torch + + result = self.tokenizer.apply_chat_template( + messages, + tokenize=True, + return_assistant_tokens_mask=True, + return_dict=True, + add_generation_prompt=False, + ) + tokens = result["input_ids"] + train_mask = result["assistant_masks"] + + # Prepend BOS / append EOS if not already present anywhere in the sequence. + # We check anywhere (not just first/last) because some chat templates add trailing + # whitespace after the final EOS token, e.g. "<|im_end|>\n". 
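+        # With such a template the EOS id is not the final token, so a check limited to the last
+        # position could append a duplicate EOS.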
+ prepend_bos = begin and self.bod_id not in tokens + append_eos = end and self.eod_id not in tokens + tokens = [self.bod_id] * prepend_bos + list(tokens) + [self.eod_id] * append_eos + train_mask = [False] * prepend_bos + [bool(m) for m in train_mask] + [False] * append_eos + + # Convert boolean train mask to loss masking spans (spans where train_mask[i] == False) + loss_masking_spans = _train_mask_to_loss_spans(train_mask) + + if self._config.max_vocab_size is not None: + tokens = ( + torch.tensor( + tokens, + dtype=torch.int64 if len(self.tokenizer) > torch.iinfo(data_type.torch).max else data_type.torch, + ) + % self._config.max_vocab_size + ).to(data_type.torch) + else: + tokens = torch.tensor(tokens, dtype=data_type.torch) + return tokens, loss_masking_spans + + +def _train_mask_to_loss_spans(train_mask: list[bool]) -> list[tuple[int, int]]: + """ + Convert a boolean train mask to loss masking spans. + + Args: + train_mask: Boolean list where True = train on this token, False = don't train + + Returns: + List of (begin, end) spans marking token ranges to EXCLUDE from training + (i.e., where train_mask[i] == False). + """ + spans = [] + start = None + for i, should_train in enumerate(train_mask): + if not should_train: + if start is None: + start = i + elif start is not None: + spans.append((start, i)) + start = None + if start is not None: + spans.append((start, len(train_mask))) + return spans diff --git a/fast_llm/models/gpt/conversion/apriel2.py b/fast_llm/models/gpt/conversion/apriel2.py index 4ed588ed5..91e3be508 100644 --- a/fast_llm/models/gpt/conversion/apriel2.py +++ b/fast_llm/models/gpt/conversion/apriel2.py @@ -39,8 +39,20 @@ def import_config(cls, config: dict) -> dict: "head_groups": config["head_groups"], "head_size": config["head_size"], "rotary": rotary, - "add_linear_biases": config["add_linear_biases"], } + # Per-layer bias configuration mirroring Fast-LLM structure + # If per-layer configs exist, use them; otherwise fall back to add_linear_biases + if "query_layer" in config: + result["query_layer"] = config["query_layer"] + if "key_layer" in config: + result["key_layer"] = config["key_layer"] + if "value_layer" in config: + result["value_layer"] = config["value_layer"] + if "dense_layer" in config: + result["dense_layer"] = config["dense_layer"] + # add_linear_biases serves as default for layers without explicit config + if "add_linear_biases" in config: + result["add_linear_biases"] = config["add_linear_biases"] if "window_size" in config: result["window_size"] = config["window_size"] return result @@ -58,18 +70,37 @@ def export_config(cls, config: AttentionConfig) -> dict: else: raise NotImplementedError(f"Unsupported rotary type: {type(config.rotary).__name__}") - return { + result = { "type": "attention", "heads": config.heads, "head_groups": config.head_groups, "head_size": config.head_size, - "add_linear_biases": config.add_linear_biases, "rotary": { "type": rotary_type, "theta": config.rotary.theta, }, "window_size": config.window_size, } + # Export per-layer bias configuration + # Only include if explicitly set (not None) + if config.query_layer.bias.enabled is not None: + result["query_layer"] = {"bias": {"enabled": config.query_layer.bias.enabled}} + if config.key_layer.bias.enabled is not None: + result["key_layer"] = {"bias": {"enabled": config.key_layer.bias.enabled}} + if config.value_layer.bias.enabled is not None: + result["value_layer"] = {"bias": {"enabled": config.value_layer.bias.enabled}} + if config.dense_layer.bias.enabled is not None: + 
result["dense_layer"] = {"bias": {"enabled": config.dense_layer.bias.enabled}} + # add_linear_biases as fallback default + result["add_linear_biases"] = config.add_linear_biases + return result + + @classmethod + def _get_effective_bias(cls, layer_config, default: bool) -> bool: + """Get effective bias setting: use layer-specific if set, else default.""" + if layer_config.bias.enabled is not None: + return layer_config.bias.enabled + return default @classmethod def get_converters( @@ -79,11 +110,20 @@ def get_converters( hf_prefix: str, drop_on_export: bool = False, ) -> list[WeightConverter]: + # Determine effective bias for each projection + q_bias = cls._get_effective_bias(config.query_layer, config.add_linear_biases) + k_bias = cls._get_effective_bias(config.key_layer, config.add_linear_biases) + v_bias = cls._get_effective_bias(config.value_layer, config.add_linear_biases) + o_bias = cls._get_effective_bias(config.dense_layer, config.add_linear_biases) + # For key_value, both k and v must have same bias setting + # (they're combined in Fast-LLM's key_value layer) + kv_bias = k_bias and v_bias + return [ *get_weight_and_bias_converters( f"{fast_llm_prefix}.query", f"{hf_prefix}.q_proj", - config.add_linear_biases, + q_bias, QueryWeightConverter, config, drop_on_export=drop_on_export, @@ -91,7 +131,7 @@ def get_converters( *get_weight_and_bias_converters( f"{fast_llm_prefix}.key_value", (f"{hf_prefix}.k_proj", f"{hf_prefix}.v_proj"), - config.add_linear_biases, + kv_bias, KeyValueWeightConverter, config, drop_on_export=drop_on_export, @@ -99,7 +139,7 @@ def get_converters( *get_weight_and_bias_converters( f"{fast_llm_prefix}.dense", f"{hf_prefix}.o_proj", - config.add_linear_biases, + o_bias, drop_on_export=drop_on_export, ), ] @@ -524,6 +564,12 @@ def import_config(cls, config: dict, block_config: dict) -> dict: "gated": mlp_config["gated"], "add_linear_biases": mlp_config["add_linear_biases"], } + # Import per-layer MLP bias settings (layer_1, layer_2) + for layer_name in ("layer_1", "layer_2"): + if layer_name in mlp_config: + layer_cfg = mlp_config[layer_name] + if "bias" in layer_cfg: + mlp[layer_name] = {"bias": layer_cfg["bias"]} normalization = block_config["normalization"] @@ -578,6 +624,11 @@ def export_config(cls, config: DecoderBlockConfig) -> dict: "gated": config.mlp.gated, "add_linear_biases": config.mlp.add_linear_biases, } + # Export per-layer MLP bias settings (layer_1, layer_2) + if config.mlp.layer_1.bias.enabled is not None: + mlp["layer_1"] = {"bias": {"enabled": config.mlp.layer_1.bias.enabled}} + if config.mlp.layer_2.bias.enabled is not None: + mlp["layer_2"] = {"bias": {"enabled": config.mlp.layer_2.bias.enabled}} normalization = {"type": norm_type_str, "epsilon": config.normalization.epsilon} @@ -624,24 +675,56 @@ def get_converters( ) ) - converters.extend( - [ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_1", - (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), - config.mlp.add_linear_biases, - SplitWeightConverter, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_2", - f"{hf_prefix}.mlp.down_proj", - config.mlp.add_linear_biases, - MLPLayer2Converter, - drop_on_export=drop_on_export, - ), - ] - ) + # Per-layer MLP bias: use layer-specific setting if set, else default + def get_mlp_layer_bias(layer_config, default: bool) -> bool: + if layer_config.bias.enabled is not None: + return layer_config.bias.enabled + return default + + layer_1_bias = 
get_mlp_layer_bias(config.mlp.layer_1, config.mlp.add_linear_biases) + layer_2_bias = get_mlp_layer_bias(config.mlp.layer_2, config.mlp.add_linear_biases) + + if config.mlp.gated: + # Gated MLP: gate_proj + up_proj -> layer_1 (split), down_proj -> layer_2 + converters.extend( + [ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_1", + (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), + layer_1_bias, + SplitWeightConverter, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_2", + f"{hf_prefix}.mlp.down_proj", + layer_2_bias, + MLPLayer2Converter, + drop_on_export=drop_on_export, + ), + ] + ) + else: + # Non-gated MLP: up_proj -> layer_1, down_proj -> layer_2 + # Note: layer_2 still needs MLPLayer2Converter for the transpose + converters.extend( + [ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_1", + f"{hf_prefix}.mlp.up_proj", + layer_1_bias, + WeightConverter, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_2", + f"{hf_prefix}.mlp.down_proj", + layer_2_bias, + MLPLayer2Converter, + drop_on_export=drop_on_export, + ), + ] + ) converters.extend( [ diff --git a/fast_llm/models/gpt/conversion/qwen2.py b/fast_llm/models/gpt/conversion/qwen2.py index a8bc33454..4ebf18c3a 100644 --- a/fast_llm/models/gpt/conversion/qwen2.py +++ b/fast_llm/models/gpt/conversion/qwen2.py @@ -1,15 +1,21 @@ import typing from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm.engine.checkpoint.external import WeightConverter from fast_llm.layers.attention.config import AttentionConfig +from fast_llm.layers.decoder.mlp.config import MLPConfig from fast_llm.models.gpt.conversion.config import Qwen2CheckpointFormat from fast_llm.models.gpt.conversion.llama import ( + KeyValueWeightConverter, LlamaAttentionConverter, LlamaBaseModelConverter, LlamaBlockConverter, LlamaDecoderConverter, LlamaHeadConverter, LlamaHuggingfaceCheckpointHandler, + LlamaMLPConverter, + QueryWeightConverter, + get_weight_and_bias_converters, ) from fast_llm.utils import Assert @@ -17,6 +23,22 @@ class Qwen2AttentionConverter(LlamaAttentionConverter): # TODO: Support sliding window with max_window_layers (need 2 kinds of block?) 
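+    # Qwen2 checkpoints use biases on the q/k/v projections but not on the output projection.
+    # The overrides below hard-code that convention, so the imported config gains, e.g.,
+    # {"query_layer": {"bias": {"enabled": True}}, ..., "dense_layer": {"bias": {"enabled": False}}}.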
+ @classmethod + def import_config(cls, config: dict) -> dict: + config["attention_bias"] = True + out = super().import_config(config) + out["query_layer"] = {"bias": {"enabled": True}} + out["key_layer"] = {"bias": {"enabled": True}} + out["value_layer"] = {"bias": {"enabled": True}} + out["dense_layer"] = {"bias": {"enabled": False}} + return out + + @classmethod + def export_config(cls, config: AttentionConfig) -> dict: + out = super().export_config(config) + del out["attention_bias"] + return out + @classmethod def _check_config(cls, config: AttentionConfig) -> None: Assert.is_(type(config), AttentionConfig) @@ -32,9 +54,56 @@ def _check_config(cls, config: AttentionConfig) -> None: Assert.is_(config.value_layer.bias.enabled, True) Assert.incl(config.dense_layer.bias.enabled, (None, False)) + @classmethod + def get_converters( + cls, + config: AttentionConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + return [ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.query", + f"{hf_prefix}.q_proj", + True, + QueryWeightConverter, + config, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.key_value", + (f"{hf_prefix}.k_proj", f"{hf_prefix}.v_proj"), + True, + KeyValueWeightConverter, + config, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.dense", + f"{hf_prefix}.o_proj", + False, + drop_on_export=drop_on_export, + ), + ] + + +class Qwen2MLPConverter(LlamaMLPConverter): + @classmethod + def import_config(cls, config: dict) -> dict: + config["mlp_bias"] = False + return super().import_config(config) + + @classmethod + def export_config(cls, config: MLPConfig) -> dict: + out = super().export_config(config) + del out["mlp_bias"] + return out + class Qwen2BlockConverter(LlamaBlockConverter): mixer_converter_class: typing.ClassVar[type[Qwen2AttentionConverter]] = Qwen2AttentionConverter + mlp_converter_class: typing.ClassVar[type[Qwen2MLPConverter]] = Qwen2MLPConverter class Qwen2DecoderConverter(LlamaDecoderConverter): diff --git a/fast_llm/models/multimodal/conversion/apriel2.py b/fast_llm/models/multimodal/conversion/apriel2.py index b4147a8bf..307a67c63 100644 --- a/fast_llm/models/multimodal/conversion/apriel2.py +++ b/fast_llm/models/multimodal/conversion/apriel2.py @@ -326,9 +326,7 @@ class Apriel2MultimodalBaseModelConverter: @classmethod def import_config(cls, config: dict) -> dict: text_config = Apriel2BaseModelConverter.import_config(config) - vision_config = ( - cls.vision_model_converter_class.import_config(config) if "vision_encoder" in config else None - ) + vision_config = cls.vision_model_converter_class.import_config(config) if "vision_encoder" in config else None result = safe_merge_dicts( text_config, @@ -388,10 +386,7 @@ def get_transformers_configuration_class(cls): @classmethod def get_model_files(cls) -> tuple[str, str, str | None]: - from fast_llm_external_models.apriel2 import ( - configuration_apriel2, - modeling_apriel2, - ) + from fast_llm_external_models.apriel2 import configuration_apriel2, modeling_apriel2 return configuration_apriel2.__file__, modeling_apriel2.__file__, None diff --git a/fast_llm_external_models/apriel2/cache.py b/fast_llm_external_models/apriel2/cache.py index 86c67a085..f83ae87d6 100644 --- a/fast_llm_external_models/apriel2/cache.py +++ b/fast_llm_external_models/apriel2/cache.py @@ -1,17 +1,22 @@ from __future__ import annotations + import torch from transformers.cache_utils import 
Cache class _AttentionCache: - __slots__ = ["key", "value", "window"] + __slots__ = ["key", "value", "window", "cumulative_length"] def __init__(self, window=None): self.key = None self.value = None self.window = window + self.cumulative_length = 0 def update(self, key, value): + new_tokens = key.shape[-2] + self.cumulative_length += new_tokens + if self.key is None: if self.window and key.shape[-2] > self.window: self.key = key[..., -self.window :, :].contiguous() @@ -35,6 +40,40 @@ def _window(self, cache, new): return cache return torch.cat([cache, new], -2)[..., -self.window :, :].contiguous() + def reset(self): + self.key = None + self.value = None + self.cumulative_length = 0 + + def reorder(self, beam_idx): + if self.key is not None: + self.key = self.key.index_select(0, beam_idx.to(self.key.device)) + self.value = self.value.index_select(0, beam_idx.to(self.value.device)) + + def crop(self, max_length): + if self.key is not None: + self.key = self.key[..., :max_length, :] + self.value = self.value[..., :max_length, :] + self.cumulative_length = self.key.shape[-2] + + def batch_repeat(self, repeats): + if self.key is not None: + self.key = self.key.repeat_interleave(repeats, dim=0) + self.value = self.value.repeat_interleave(repeats, dim=0) + + def batch_select(self, indices): + if self.key is not None: + self.key = self.key.index_select(0, indices.to(self.key.device)) + self.value = self.value.index_select(0, indices.to(self.value.device)) + + @property + def is_initialized(self): + return self.key is not None + + @property + def batch_size(self): + return self.key.shape[0] if self.key is not None else None + class _SSMCache: __slots__ = ["conv", "recurrent"] @@ -43,6 +82,52 @@ def __init__(self): self.conv = None self.recurrent = None + def reset(self): + self.conv = None + self.recurrent = None + + def reorder(self, beam_idx): + if self.conv is not None: + if isinstance(self.conv, tuple): + self.conv = tuple(c.index_select(0, beam_idx.to(c.device)) for c in self.conv) + else: + self.conv = self.conv.index_select(0, beam_idx.to(self.conv.device)) + if self.recurrent is not None: + self.recurrent = self.recurrent.index_select(0, beam_idx.to(self.recurrent.device)) + + def crop(self, max_length): + pass # SSM caches don't have sequence dimension to crop + + def batch_repeat(self, repeats): + if self.conv is not None: + if isinstance(self.conv, tuple): + self.conv = tuple(c.repeat_interleave(repeats, dim=0) for c in self.conv) + else: + self.conv = self.conv.repeat_interleave(repeats, dim=0) + if self.recurrent is not None: + self.recurrent = self.recurrent.repeat_interleave(repeats, dim=0) + + def batch_select(self, indices): + if self.conv is not None: + if isinstance(self.conv, tuple): + self.conv = tuple(c.index_select(0, indices.to(c.device)) for c in self.conv) + else: + self.conv = self.conv.index_select(0, indices.to(self.conv.device)) + if self.recurrent is not None: + self.recurrent = self.recurrent.index_select(0, indices.to(self.recurrent.device)) + + @property + def is_initialized(self): + return self.conv is not None + + @property + def batch_size(self): + if self.conv is None: + return None + if isinstance(self.conv, tuple): + return self.conv[0].shape[0] + return self.conv.shape[0] + class _DummyCacheLayer: pass @@ -93,14 +178,19 @@ def set_active_mixer(self, layer_idx, mixer_name): self.active_mixers[layer_idx] = mixer_name def get_seq_length(self, layer_idx=0): + """Returns the cumulative sequence length of tokens seen by the cache. 
+ + For sliding window caches, this returns the total tokens seen (not just cached). + This matches HuggingFace's DynamicSlidingWindowLayer behavior. + """ layer = self.layers[layer_idx] if isinstance(layer, dict): mixer = self.active_mixers[layer_idx] if mixer and isinstance(layer[mixer], _AttentionCache): - return layer[mixer].key.shape[-2] if layer[mixer].key is not None else 0 + return layer[mixer].cumulative_length return 0 if isinstance(layer, _AttentionCache): - return layer.key.shape[-2] if layer.key is not None else 0 + return layer.cumulative_length return 0 def get_max_cache_shape(self, layer_idx=0): @@ -114,22 +204,61 @@ def get_max_cache_shape(self, layer_idx=0): return None def get_mask_sizes(self, cache_position, layer_idx): + """Return the length and offset of the cache, used to generate the attention mask. + + For standard (non-sliding) attention: + kv_offset = 0 (KV[0] corresponds to sequence position 0) + kv_length = cumulative_length + query_length + + For sliding window attention: + kv_offset = max(cumulative_length - window + 1, 0) + kv_length = min(cumulative_length, window - 1) + query_length + + For SSM/linear layers: + kv_offset = 0, kv_length = query_length (no KV cache to attend to) + """ query_length = cache_position.shape[0] - past_seen_tokens = self.get_seq_length(layer_idx) - kv_length = query_length + past_seen_tokens - kv_offset = past_seen_tokens - return kv_length, kv_offset + layer = self.layers[layer_idx] + + # Handle stochastic layers by getting the active mixer's cache + if isinstance(layer, dict): + mixer = self.active_mixers[layer_idx] + if mixer is None: + # No active mixer set, return defaults + return query_length, 0 + cache = layer[mixer] + else: + cache = layer + + # SSM layers don't have KV cache for attention mask purposes + if isinstance(cache, _SSMCache): + return query_length, 0 + + # Attention cache - check if sliding window + if isinstance(cache, _AttentionCache): + cumulative = cache.cumulative_length + window = cache.window + + if window is not None: + # Sliding window attention + kv_offset = max(cumulative - window + 1, 0) + if cumulative >= window: + kv_length = window - 1 + query_length + else: + kv_length = cumulative + query_length + else: + # Full attention + kv_offset = 0 + kv_length = cumulative + query_length + + return kv_length, kv_offset + + # Fallback + return query_length, 0 @property def has_previous_state(self): - for layer in self.layers: - if isinstance(layer, dict): - for cache in layer.values(): - if isinstance(cache, _SSMCache) and cache.conv is not None: - return True - elif isinstance(layer, _SSMCache) and layer.conv is not None: - return True - return False + return any(isinstance(cache, _SSMCache) and cache.conv is not None for cache in self._iter_caches()) @property def key_cache(self): @@ -147,101 +276,33 @@ def conv_states(self): def recurrent_states(self): return _LayerListAccessor(self, "recurrent") - def reorder_cache(self, beam_idx): - for i, layer in enumerate(self.layers): + def _iter_caches(self): + """Iterate over all leaf cache objects (flattening stochastic layer dicts).""" + for layer in self.layers: if isinstance(layer, dict): - for cache in layer.values(): - self._reorder_cache_obj(cache, beam_idx) + yield from layer.values() else: - self._reorder_cache_obj(layer, beam_idx) + yield layer - def _reorder_cache_obj(self, cache, beam_idx): - if isinstance(cache, _AttentionCache): - if cache.key is not None: - cache.key = cache.key.index_select(0, beam_idx.to(cache.key.device)) - cache.value = 
cache.value.index_select(0, beam_idx.to(cache.value.device)) - elif isinstance(cache, _SSMCache): - if cache.conv is not None: - # Handle both single tensor (GDN/Mamba) and tuple (KDA) conv states - if isinstance(cache.conv, tuple): - cache.conv = tuple(c.index_select(0, beam_idx.to(c.device)) for c in cache.conv) - else: - cache.conv = cache.conv.index_select(0, beam_idx.to(cache.conv.device)) - if cache.recurrent is not None: - cache.recurrent = cache.recurrent.index_select(0, beam_idx.to(cache.recurrent.device)) + def reorder_cache(self, beam_idx): + for cache in self._iter_caches(): + cache.reorder(beam_idx) def reset(self): - for layer in self.layers: - if isinstance(layer, dict): - for cache in layer.values(): - self._reset_cache_obj(cache) - else: - self._reset_cache_obj(layer) - - def _reset_cache_obj(self, cache): - if isinstance(cache, _AttentionCache): - cache.key = None - cache.value = None - elif isinstance(cache, _SSMCache): - cache.conv = None - cache.recurrent = None + for cache in self._iter_caches(): + cache.reset() def crop(self, max_length): - for layer in self.layers: - if isinstance(layer, dict): - for cache in layer.values(): - if isinstance(cache, _AttentionCache) and cache.key is not None: - cache.key = cache.key[..., :max_length, :] - cache.value = cache.value[..., :max_length, :] - elif isinstance(layer, _AttentionCache) and layer.key is not None: - layer.key = layer.key[..., :max_length, :] - layer.value = layer.value[..., :max_length, :] + for cache in self._iter_caches(): + cache.crop(max_length) def batch_repeat_interleave(self, repeats): - for layer in self.layers: - if isinstance(layer, dict): - for cache in layer.values(): - self._batch_repeat_cache_obj(cache, repeats) - else: - self._batch_repeat_cache_obj(layer, repeats) - - def _batch_repeat_cache_obj(self, cache, repeats): - if isinstance(cache, _AttentionCache): - if cache.key is not None: - cache.key = cache.key.repeat_interleave(repeats, dim=0) - cache.value = cache.value.repeat_interleave(repeats, dim=0) - elif isinstance(cache, _SSMCache): - if cache.conv is not None: - # Handle both single tensor (GDN/Mamba) and tuple (KDA) conv states - if isinstance(cache.conv, tuple): - cache.conv = tuple(c.repeat_interleave(repeats, dim=0) for c in cache.conv) - else: - cache.conv = cache.conv.repeat_interleave(repeats, dim=0) - if cache.recurrent is not None: - cache.recurrent = cache.recurrent.repeat_interleave(repeats, dim=0) + for cache in self._iter_caches(): + cache.batch_repeat(repeats) def batch_select_indices(self, indices): - for layer in self.layers: - if isinstance(layer, dict): - for cache in layer.values(): - self._batch_select_cache_obj(cache, indices) - else: - self._batch_select_cache_obj(layer, indices) - - def _batch_select_cache_obj(self, cache, indices): - if isinstance(cache, _AttentionCache): - if cache.key is not None: - cache.key = cache.key.index_select(0, indices.to(cache.key.device)) - cache.value = cache.value.index_select(0, indices.to(cache.value.device)) - elif isinstance(cache, _SSMCache): - if cache.conv is not None: - # Handle both single tensor (GDN/Mamba) and tuple (KDA) conv states - if isinstance(cache.conv, tuple): - cache.conv = tuple(c.index_select(0, indices.to(c.device)) for c in cache.conv) - else: - cache.conv = cache.conv.index_select(0, indices.to(cache.conv.device)) - if cache.recurrent is not None: - cache.recurrent = cache.recurrent.index_select(0, indices.to(cache.recurrent.device)) + for cache in self._iter_caches(): + cache.batch_select(indices) @property 
def is_compileable(self): @@ -249,19 +310,7 @@ def is_compileable(self): @property def is_initialized(self): - for layer in self.layers: - if isinstance(layer, dict): - for cache in layer.values(): - if isinstance(cache, _AttentionCache) and cache.key is not None: - return True - if isinstance(cache, _SSMCache) and cache.conv is not None: - return True - else: - if isinstance(layer, _AttentionCache) and layer.key is not None: - return True - if isinstance(layer, _SSMCache) and layer.conv is not None: - return True - return False + return any(cache.is_initialized for cache in self._iter_caches()) @property def is_sliding(self): @@ -280,39 +329,20 @@ def is_sliding(self): @property def max_batch_size(self): - for layer in self.layers: - if isinstance(layer, dict): - for cache in layer.values(): - if isinstance(cache, _AttentionCache) and cache.key is not None: - return cache.key.shape[0] - if isinstance(cache, _SSMCache) and cache.conv is not None: - # Handle both single tensor and tuple conv states - if isinstance(cache.conv, tuple): - return cache.conv[0].shape[0] - return cache.conv.shape[0] - else: - if isinstance(layer, _AttentionCache) and layer.key is not None: - return layer.key.shape[0] - if isinstance(layer, _SSMCache) and layer.conv is not None: - # Handle both single tensor and tuple conv states - if isinstance(layer.conv, tuple): - return layer.conv[0].shape[0] - return layer.conv.shape[0] + for cache in self._iter_caches(): + bs = cache.batch_size + if bs is not None: + return bs return None @property def max_cache_len(self): - max_len = None - for layer in self.layers: - if isinstance(layer, dict): - for cache in layer.values(): - if isinstance(cache, _AttentionCache): - if cache.window is not None: - max_len = cache.window if max_len is None else min(max_len, cache.window) - elif isinstance(layer, _AttentionCache): - if layer.window is not None: - max_len = layer.window if max_len is None else min(max_len, layer.window) - return max_len + windows = [ + cache.window + for cache in self._iter_caches() + if isinstance(cache, _AttentionCache) and cache.window is not None + ] + return min(windows) if windows else None def __len__(self): return len(self.layers) diff --git a/fast_llm_external_models/apriel2/conversion/__init__.py b/fast_llm_external_models/apriel2/conversion/__init__.py index 983a632e0..2c28d1e87 100644 --- a/fast_llm_external_models/apriel2/conversion/__init__.py +++ b/fast_llm_external_models/apriel2/conversion/__init__.py @@ -1,88 +1,138 @@ """Weight conversion system for Apriel2 models. -Architecture Overview -===================== +Overview +======== -This package implements a declarative weight transformation system with two -orthogonal concerns: +This package implements a declarative weight transformation system. The core +abstraction separates config composition (structural) from plan execution (weights). -1. **Config Composition** - Structural transformations of model configs -2. **Plan Building & Execution** - Weight transformations between configs +Conceptual Types +================ -These concerns are intentionally separated: -- Config composition determines WHAT the target architecture looks like -- Plan building determines HOW weights are transformed to match -- The `init` field bridges them: it's config metadata consumed by the plan builder +All configs are ``dict``, but we distinguish three conceptual types: -Key Design Decisions -==================== +**State (S)** - A complete model config without ``init`` fields. 
+ What you load from disk or save after conversion. -**Declarative Plans** - Plans are DATA (JSON-serializable expressions), not functions. This enables: - - Inspection and debugging of transformations - - Serialization for distributed execution - - Composition via substitution rather than function composition - -**Separation of Config and Weights** - The `init` field in surgery specs controls weight handling (transfer vs random) - but does NOT affect config composition. Config composition is purely structural. - After composition, `init` fields are stripped from complete configs. - -**Composition Semantics** - Surgery specs use declarative (merge) composition, not operational (function) - composition. For "additive" surgeries (modifying existing structure), the - monoid action law holds. For "replacement" surgeries (defining complete new - structure), sequential application differs from composed application by design. - -**Cross-Type Derivation** - When converting between mixer types (e.g., attention → mamba), geometric - parameters are derived where possible: - - attention.heads → mamba dimensions (MIL conversion) - - attention.heads → gdn heads (DIL conversion) +**Partial Surgery (P)** - An incomplete config specifying changes. + May contain ``init`` fields (``transfer`` or ``random``). -Module Structure -================ +**Transition Spec (T)** - A complete config WITH ``init`` fields. + The result of applying surgery to a state. Describes both target + structure and weight initialization mode. + +Algebraic Structure +=================== + +**Monoid**: Partial surgeries compose via deep merge:: + + compose_configs : P × P → P + +**Action**: Surgeries act on states to produce transition specs:: + + compose_configs : S × P → T + compose_configs : T × P → T + +**Extraction**: Strip init to get a state:: + + strip_init_fields : T → S + +**Planning**: Build weight transformation from source state + transition spec:: + + plan_surgery : S × T → Plan -- `config.py` - Config composition (compose_configs, apply_surgery) -- `converters.py` - Plan builders (plan_surgery, plan_mil_attention_to_mamba, etc.) -- `expr.py` - Expression types and plan class (Ref, Slice, Concat, Init, ExprPlan) -- `executor.py` - Plan execution (StreamingExecutor, execute) -- `io.py` - Streaming I/O (SafetensorLoader, ShardedSafetensorWriter) -- `llava/` - Source-specific converter for Llava → Apriel2 +The ``init`` Field +================== -Example Usage +The ``init`` field in surgeries specifies weight initialization: + +- ``init: transfer`` → transfer/convert weights from source +- ``init: random`` → randomly initialize weights + +This field is preserved through ``compose_configs`` so ``plan_surgery`` can read it. +Use ``strip_init_fields`` before saving configs to disk. + +Typical Usage ============= +:: + from fast_llm_external_models.apriel2.conversion import ( compose_configs, plan_surgery, + strip_init_fields, execute, ) - # 1. Compose configs to get target architecture - target_config = compose_configs(source_config, surgery_spec) + # Load source state + source_state = load_config(...) # S - # 2. Build plan for weight transformation - plan = plan_surgery(source_config, surgery_spec) + # Apply surgery + surgery = {"decoder": {"block": {"mixer": {"type": "gdn", "init": "random"}}}} # P + transition = compose_configs(source_state, surgery) # T - # 3. 
Execute plan to transform weights - target_weights = execute(plan, source_weights, seed=42) + # Build and execute plan + plan = plan_surgery(source_state, transition) + weights = execute(plan, source_weights, seed=42) -For streaming I/O with large models: + # Save (strip init first) + target_state = strip_init_fields(transition) # S + save_config(target_state) - from fast_llm_external_models.apriel2.conversion import ( - StreamingExecutor, - SafetensorLoader, - ShardedSafetensorWriter, - ) +For chained surgeries:: + + current_state = source_state # S + current_plan = identity_plan + + for surgery in surgery_chain: # each P + transition = compose_configs(current_state, surgery) # T + plan = plan_surgery(current_state, transition) + current_plan = compose(current_plan, plan) + current_state = strip_init_fields(transition) # S <- IMPORTANT! + +**Note**: The ``strip_init_fields`` call is critical. It ensures that ``init: random`` +applies only to the surgery that introduces a component. Without stripping, subsequent +surgeries would re-randomize existing components. See ``config.py`` docstring for details. + +Key Design Decisions +==================== + +**Declarative Plans** + Plans are data (expressions), not functions. Enables inspection, + serialization, and composition via substitution. + +**Inheritance Semantics** + When S × P → T, unspecified fields inherit from source. + Cross-type derivation maps geometry (attention.heads → gdn.value_heads). - with SafetensorLoader(source_files) as loader: - executor = StreamingExecutor(plan, loader) - with ShardedSafetensorWriter(output_dir) as writer: - for key, tensor in executor.execute(seed=42): - writer.add(key, tensor) +**Additive vs Replacement Surgeries** + Additive surgeries (no ``type:`` declaration) satisfy the action law. + Replacement surgeries (explicit ``type:``) use last-write-wins. 
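For intuition, a minimal sketch (illustrative only, not part of this patch) of the two surgery styles, using the mixer and ``init`` field names that appear elsewhere in this package::

    # Additive surgery: no "type:" on the stochastic mixer, so sub-mixers already
    # present in the source are preserved; the surgery only adds/modifies "gdn".
    additive_surgery = {
        "decoder": {"block": {"mixer": {
            "mixers": {"gdn": {"type": "gdn", "init": "random"}},
        }}}
    }

    # Replacement surgery: explicit "type: stochastic" means only the sub-mixers
    # listed here survive (last-write-wins), dropping any others from the source.
    replacement_surgery = {
        "decoder": {"block": {"mixer": {
            "type": "stochastic",
            "mixers": {"attention": {"init": "transfer"}},
        }}}
    }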
+ +Module Structure +================ + +- ``config.py`` - Config composition (compose_configs, strip_init_fields) +- ``converters.py`` - Plan builders (plan_surgery, plan_mil_attention_to_mamba) +- ``expr.py`` - Expression types (Ref, Slice, Concat, Init, ExprPlan) +- ``executor.py`` - Plan execution (StreamingExecutor, execute) +- ``io.py`` - Streaming I/O (SafetensorLoader, ShardedSafetensorWriter) """ +# Config composition +from fast_llm_external_models.apriel2.conversion.config import compose_configs, strip_init_fields + +# Plan builders (generic) +from fast_llm_external_models.apriel2.conversion.converters import ( + plan_dil_attention_to_gdn, + plan_kil_attention_to_kda, + plan_mil_attention_to_mamba, + plan_surgery, +) + +# Execution +from fast_llm_external_models.apriel2.conversion.executor import MAX_SEED, StreamingExecutor, execute + # Core types and plan operations from fast_llm_external_models.apriel2.conversion.expr import ( Concat, @@ -104,13 +154,6 @@ substitute, ) -# Execution -from fast_llm_external_models.apriel2.conversion.executor import ( - MAX_SEED, - StreamingExecutor, - execute, -) - # I/O utilities from fast_llm_external_models.apriel2.conversion.io import ( DEFAULT_MAX_SHARD_SIZE, @@ -118,22 +161,9 @@ ShardedSafetensorWriter, ) -# Plan builders (generic) -from fast_llm_external_models.apriel2.conversion.converters import ( - plan_mil_attention_to_mamba, - plan_dil_attention_to_gdn, - plan_kil_attention_to_kda, - plan_surgery, -) - -# Config composition -from fast_llm_external_models.apriel2.conversion.config import compose_configs - # Source-specific converters -from fast_llm_external_models.apriel2.conversion.llava import ( - convert_config as convert_llava_config, - plan_llava_to_apriel2, -) +from fast_llm_external_models.apriel2.conversion.llava import convert_config as convert_llava_config +from fast_llm_external_models.apriel2.conversion.llava import plan_llava_to_apriel2 # Rendering (optional, imported lazily by ExprPlan.render_tree) # from fast_llm_external_models.apriel2.conversion.render import render_tree @@ -175,6 +205,7 @@ "plan_kil_attention_to_kda", # Config composition "compose_configs", + "strip_init_fields", # Source-specific converters "convert_llava_config", "plan_llava_to_apriel2", diff --git a/fast_llm_external_models/apriel2/conversion/config.py b/fast_llm_external_models/apriel2/conversion/config.py index 48f8ff44b..3752688c1 100644 --- a/fast_llm_external_models/apriel2/conversion/config.py +++ b/fast_llm_external_models/apriel2/conversion/config.py @@ -1,56 +1,136 @@ """Config composition for Apriel2 architecture transformations. -This module handles STRUCTURAL composition of configs, independent of weight handling. -The `init` field in surgery specs is preserved as metadata for the plan builder but -does not affect how configs are composed. +Conceptual Types +================ -Composition Cases -================= +The system operates on three conceptual types, all represented as ``dict``: -compose_configs(base, overlay) handles four cases based on completeness: +**State (S)** + A complete structural description of a model. Has ``hidden_size`` and ``decoder``. + Does NOT contain ``init`` fields. Represents WHAT a model looks like. -1. **Complete + Partial** → Apply surgery semantics (inheritance, cross-type derivation) -2. **Partial + Partial** → Deep merge (monoid operation on surgery specs) -3. **Partial + Complete** → Overlay wins (complete config replaces partial) -4. 
**Complete + Complete** → Deep merge, then strip `init` fields + Example: A saved config.json, or a model you're about to transform. -A config is "complete" if it has `hidden_size` and `decoder` (i.e., it's a full model -config, not a surgery spec). +**Partial Surgery (P)** + An incomplete config specifying fields to change. Missing ``hidden_size`` or + ``decoder``. May contain ``init`` fields specifying weight initialization mode. -Surgery Semantics -================= + Example: ``{"decoder": {"block": {"mixer": {"type": "gdn", "init": "random"}}}}`` -When applying a surgery spec to a complete config: +**Transition Spec (T)** + A complete config WITH ``init`` fields. Describes both the target structure + AND how to initialize weights. This is the output of applying a surgery to + a state - it's a complete specification of the transformation. -**Inheritance** - Unspecified parameters inherit from the source config. New blocks inherit - from the "default" block (first block in pattern, or the single fixed block). + Example: The result of ``compose_configs(state, surgery)`` before stripping. -**Cross-Type Derivation** - When changing mixer types, geometric parameters are derived where possible: - - attention → sliding_window: preserve heads, head_groups, head_size - - attention → gdn: heads → value_heads, head_groups → key_heads - - attention → mamba: derive d_inner, d_xb, dt_rank from hidden_size - - attention → kda: preserve heads, head_size → head_dim +The distinction between S and T is semantic (presence of ``init``), not structural. +Both are "complete" in the sense of having ``hidden_size`` and ``decoder``. -**Stochastic Mixer Composition** - Two semantics based on whether surgery declares `type: stochastic`: - - Replacement: surgery declares type → only surgery's sub-mixers included - - Additive: surgery omits type → source sub-mixers preserved, surgery adds/modifies +Algebraic Structure +=================== - This distinction means the monoid action law holds for additive surgeries but - intentionally fails for replacement surgeries (they have "last-write-wins" semantics). +**Partial Surgeries form a Monoid (P, ∘, {})**:: -The `init` Field -================ + compose_configs : P × P → P (deep merge, overlay wins) + + Identity: compose_configs(p, {}) = compose_configs({}, p) = p + Associativity: compose_configs(compose_configs(a, b), c) + = compose_configs(a, compose_configs(b, c)) + +**Surgeries act on States to produce Transition Specs**:: + + compose_configs : S × P → T (apply surgery with inheritance) + compose_configs : T × P → T (extend transition with more surgery) + +**Action Law (for additive surgeries)**:: + + compose_configs(compose_configs(s, p₁), p₂) = compose_configs(s, compose_configs(p₁, p₂)) + +This law holds when surgeries are "additive" (modifying existing structure without +declaring new types). For "replacement" surgeries (explicitly declaring ``type:``), +the action law intentionally fails - this is last-write-wins semantics. + +**State Extraction**:: + + strip_init_fields : T → S (remove init metadata for saving) + +Operations Summary +================== + +``compose_configs(base, overlay)`` dispatches based on completeness: + +1. **S × P → T** : Apply surgery to state (inheritance, cross-type derivation) +2. **T × P → T** : Extend transition spec with more surgery +3. **P × P → P** : Merge partial surgeries (monoid operation) +4. **S × S → S** : Merge states (deep merge, rare) +5. 
**P × S → S** : Overlay wins (complete replaces partial) + +``strip_init_fields(config)`` removes all ``init`` fields, converting T → S. + +Inheritance Semantics +===================== + +When applying a surgery (S × P → T): + +- Unspecified fields inherit from source state +- New decoder blocks inherit from the "default" block +- Cross-type derivation maps geometry (attention.heads → gdn.value_heads, etc.) +- Stochastic mixers: additive surgery preserves source mixers, replacement replaces + +The ``init`` Field +================== + +The ``init`` field specifies weight initialization mode for ``plan_surgery()``: + +- ``init: transfer`` → transfer weights from source (possibly with conversion) +- ``init: random`` → randomly initialize weights + +**Key invariant**: ``init`` is preserved through composition so ``plan_surgery()`` +can read it. Use ``strip_init_fields()`` to obtain a pure state for: + +- Saving to disk (config.json should not contain ``init``) +- Starting the next surgery iteration (current_state should be S, not T) + +Typical Usage Pattern +===================== + +:: + + current_state: S = load_config(...) + + for surgery: P in surgery_chain: + transition: T = compose_configs(current_state, surgery) # S × P → T + plan = plan_surgery(current_state, transition) # plan reads init from T + current_state: S = strip_init_fields(transition) # T → S for next iteration + + save_config(current_state) # S has no init fields + +Sequential vs Merged Surgery Application +======================================== + +**IMPORTANT**: Applying surgeries sequentially (with stripping) differs from merging +surgeries first then applying once. This affects ``init`` semantics: + +**Sequential** (recommended):: -The `init` field is metadata for the plan builder, NOT for config composition: -- `init: transfer` → plan builder creates weight transfer mappings -- `init: random` → plan builder creates random initialization + t1 = compose_configs(s, p1) # GDN gets init: random + s1 = strip_init_fields(t1) # GDN loses init + t2 = compose_configs(s1, p2) # GDN has init: None → transfer mode -After surgery is applied to produce a complete config, ALL `init` fields are stripped. -This ensures configs are purely structural and plan creation is Markovian (depends only -on current config + surgery, not on history). +**Merged**:: + + merged = compose_configs(p1, p2) # GDN keeps init: random from p1 + t = compose_configs(s, merged) # GDN has init: random → random mode + +The sequential approach means ``init: random`` applies **only to the surgery that +introduces a component**. Subsequent surgeries transfer existing weights by default. + +This is the intended behavior: if surgery 1 adds GDN with random init, and surgery 2 +adds sliding window (not mentioning GDN), GDN keeps its weights from surgery 1. + +The merged approach would re-randomize GDN in every execution, which is rarely desired. +Always use the sequential pattern shown in "Typical Usage Pattern" above. """ from __future__ import annotations @@ -65,14 +145,42 @@ def is_complete(config: dict) -> bool: def compose_configs(base: dict, overlay: dict | None) -> dict: - """Compose two configs. + """Compose configs. Dispatches based on completeness of arguments. 
+ + Type Signatures (see module docstring for S, P, T definitions):: + + S × P → T Apply surgery to state, get transition spec + T × P → T Extend transition spec with more surgery + P × P → P Merge partial surgeries (monoid operation) + S × S → S Merge states (deep merge) + P × S → S Overlay wins + + The ``init`` field is preserved in all cases. Use ``strip_init_fields()`` + to convert T → S for saving or iteration. Args: - base: Base config (complete or partial surgery spec). - overlay: Overlay config (complete or partial surgery spec). + base: State (S), transition spec (T), or partial surgery (P). + overlay: Partial surgery (P) or state (S). Returns: - Composed config. + Composed config. Type depends on inputs (see signatures above). + + Algebraic Properties: + Monoid: ``compose(compose(p1, p2), p3) == compose(p1, compose(p2, p3))`` + + Action law (additive surgeries): + ``compose(compose(s, p1), p2) == compose(s, compose(p1, p2))`` + + Example:: + + # S × P → T (apply surgery to state) + state = {"hidden_size": 256, "decoder": {...}} + surgery = {"decoder": {"block": {"mixer": {"init": "random"}}}} + transition = compose_configs(state, surgery) # T, has init + + # Build plan, then extract state + plan = plan_surgery(state, transition) + new_state = strip_init_fields(transition) # S, no init """ if not overlay: return copy.deepcopy(base) @@ -94,9 +202,8 @@ def compose_configs(base: dict, overlay: dict | None) -> dict: if not base_complete and overlay_complete: return copy.deepcopy(overlay) - # Case 4: Both complete -> deep merge + # Case 4: Both complete -> deep merge (init preserved for plan_surgery) result = _deep_merge(base, overlay) - _strip_keys(result, {"init"}) return result @@ -128,26 +235,53 @@ def _strip_keys(config: Any, keys_to_strip: set[str]) -> None: _strip_keys(item, keys_to_strip) +def strip_init_fields(config: dict) -> dict: + """Return a copy of config with all ``init`` fields stripped (T → S). + + Converts a transition spec (T) to a state (S) by removing ``init`` metadata. + Use this: + + 1. Before saving configs to disk (config.json should be purely structural) + 2. Between surgery iterations (so subsequent surgeries don't re-randomize) + + See module docstring section "Sequential vs Merged Surgery Application" for + why stripping between iterations is critical. + + Args: + config: Config dict (not modified). Typically a transition spec (T). + + Returns: + A deep copy with all ``init`` fields recursively removed (a state S). + """ + result = copy.deepcopy(config) + _strip_keys(result, {"init"}) + return result + + # ============================================================================= # Surgery application with full semantics # ============================================================================= def apply_surgery(source_config: dict, surgery_config: dict | None) -> dict: - """Apply surgery specification to a complete source config. + """Apply surgery spec to complete config (the monoid action). - This handles: - - Top-level scalar overrides - - Decoder composition (fixed vs pattern) - - Stochastic mixer sub-mixer inheritance - - Cross-type derivation (attention → gdn, attention → mamba) + This is the internal implementation of the monoid action: surgery specs + acting on complete configs. Called by compose_configs when base is complete + and overlay is partial. + + Implements inheritance semantics: + - Unspecified fields inherit from source + - Cross-type derivation maps geometry (attention → gdn, etc.) 
+ - Stochastic sub-mixers inherit from source's main mixer + - `init` fields are PRESERVED for plan_surgery() to see Args: - source_config: Complete Apriel2 config. - surgery_config: Partial surgery specification. + source_config: Complete Apriel2 config (the state being acted on). + surgery_config: Partial surgery spec (the monoid element acting). Returns: - Complete Apriel2 config with surgery applied. + Complete config with surgery applied. `init` fields preserved. """ if not surgery_config: return copy.deepcopy(source_config) @@ -189,8 +323,9 @@ def apply_surgery(source_config: dict, surgery_config: dict | None) -> dict: surgery_config["vision_encoder"], ) - # Strip init keys from final result - _strip_keys(result, {"init"}) + # NOTE: We do NOT strip init keys here. The `init` field is preserved through + # composition so that plan_surgery() can see it and decide between transfer + # vs random initialization. The caller (convert.py) strips init before saving. return result @@ -392,6 +527,12 @@ def _compose_single_mixer(source: dict, surgery: dict, hidden_size: int) -> dict result[key] = surgery[key] elif key in source: result[key] = source[key] + # Copy per-layer bias settings (query_layer, key_layer, value_layer, dense_layer) + for key in ["query_layer", "key_layer", "value_layer", "dense_layer", "add_linear_biases"]: + if key in surgery: + result[key] = surgery[key] + elif key in source: + result[key] = copy.deepcopy(source[key]) # Preserve init if "init" in surgery: result["init"] = surgery["init"] diff --git a/fast_llm_external_models/apriel2/conversion/converters.py b/fast_llm_external_models/apriel2/conversion/converters.py index 6d1350c54..9c9238bb0 100644 --- a/fast_llm_external_models/apriel2/conversion/converters.py +++ b/fast_llm_external_models/apriel2/conversion/converters.py @@ -61,16 +61,7 @@ from __future__ import annotations -from fast_llm_external_models.apriel2.conversion.expr import ( - Concat, - Expr, - ExprPlan, - Init, - Ref, - Slice, - W, -) - +from fast_llm_external_models.apriel2.conversion.expr import Concat, Expr, ExprPlan, Init, Ref, Slice, W # ============================================================================= # SECTION 1: Per-Mixer Plan Functions @@ -79,6 +70,21 @@ # This is the single source of truth for each mixer's weight schema. +def _get_attention_bias_enabled(config: dict, layer_name: str) -> bool: + """Get whether bias is enabled for an attention layer. + + Checks per-layer bias config (e.g., query_layer.bias.enabled). + Falls back to add_linear_biases if not set. + """ + layer_cfg = config.get(layer_name, {}) + bias_cfg = layer_cfg.get("bias", {}) + enabled = bias_cfg.get("enabled") + if enabled is not None: + return enabled + # Fall back to add_linear_biases + return config.get("add_linear_biases", False) + + def _plan_attention_mixer( *, prefix: W, @@ -90,9 +96,13 @@ def _plan_attention_mixer( Weight schema: - q_proj.weight: (q_size, hidden_size) + - q_proj.bias: (q_size,) [if query_layer.bias.enabled] - k_proj.weight: (kv_size, hidden_size) + - k_proj.bias: (kv_size,) [if key_layer.bias.enabled] - v_proj.weight: (kv_size, hidden_size) + - v_proj.bias: (kv_size,) [if value_layer.bias.enabled] - o_proj.weight: (hidden_size, q_size) + - o_proj.bias: (hidden_size,) [if dense_layer.bias.enabled] Args: prefix: Target weight path prefix. @@ -100,12 +110,28 @@ def _plan_attention_mixer( hidden_size: Model hidden size. source_prefix: If provided, passthrough from source. If None, random init. 
""" + # Check per-layer bias configuration + q_bias = _get_attention_bias_enabled(config, "query_layer") + k_bias = _get_attention_bias_enabled(config, "key_layer") + v_bias = _get_attention_bias_enabled(config, "value_layer") + o_bias = _get_attention_bias_enabled(config, "dense_layer") + if source_prefix is not None: - # Passthrough - return ExprPlan(mappings={ + # Passthrough weights + mappings: dict[W, Expr] = { prefix / proj / "weight": Ref(key=source_prefix / proj / "weight") for proj in ["q_proj", "k_proj", "v_proj", "o_proj"] - }) + } + # Passthrough biases if enabled + if q_bias: + mappings[prefix / "q_proj" / "bias"] = Ref(key=source_prefix / "q_proj" / "bias") + if k_bias: + mappings[prefix / "k_proj" / "bias"] = Ref(key=source_prefix / "k_proj" / "bias") + if v_bias: + mappings[prefix / "v_proj" / "bias"] = Ref(key=source_prefix / "v_proj" / "bias") + if o_bias: + mappings[prefix / "o_proj" / "bias"] = Ref(key=source_prefix / "o_proj" / "bias") + return ExprPlan(mappings=mappings) # Random init heads = config["heads"] @@ -114,12 +140,22 @@ def _plan_attention_mixer( q_size = heads * head_size kv_size = head_groups * head_size - return ExprPlan(mappings={ + mappings = { prefix / "q_proj" / "weight": Init(shape=(q_size, hidden_size), init_type="kaiming"), prefix / "k_proj" / "weight": Init(shape=(kv_size, hidden_size), init_type="kaiming"), prefix / "v_proj" / "weight": Init(shape=(kv_size, hidden_size), init_type="kaiming"), prefix / "o_proj" / "weight": Init(shape=(hidden_size, q_size), init_type="kaiming"), - }) + } + # Random init biases if enabled + if q_bias: + mappings[prefix / "q_proj" / "bias"] = Init(shape=(q_size,), init_type="zeros") + if k_bias: + mappings[prefix / "k_proj" / "bias"] = Init(shape=(kv_size,), init_type="zeros") + if v_bias: + mappings[prefix / "v_proj" / "bias"] = Init(shape=(kv_size,), init_type="zeros") + if o_bias: + mappings[prefix / "o_proj" / "bias"] = Init(shape=(hidden_size,), init_type="zeros") + return ExprPlan(mappings=mappings) def _plan_mamba_mixer( @@ -150,20 +186,22 @@ def _plan_mamba_mixer( """ if source_prefix is not None: # Passthrough - include all possible weights - return ExprPlan(mappings={ - prefix / name: Ref(key=source_prefix / name) - for name in [ - "in_proj.weight", - "out_proj.weight", - "dt_in_proj.weight", - "dt_proj.weight", - "dt_proj.bias", - "conv1d.weight", - "conv1d.bias", - "A_log", - "D", - ] - }) + return ExprPlan( + mappings={ + prefix / name: Ref(key=source_prefix / name) + for name in [ + "in_proj.weight", + "out_proj.weight", + "dt_in_proj.weight", + "dt_proj.weight", + "dt_proj.bias", + "conv1d.weight", + "conv1d.bias", + "A_log", + "D", + ] + } + ) # Random init d_inner = config["d_inner"] @@ -181,9 +219,7 @@ def _plan_mamba_mixer( conv_channels = d_inner if repeat_kv_before_conv else d_xb mappings: dict[W, Expr] = { - prefix / "in_proj" / "weight": Init( - shape=(2 * d_inner + 2 * d_xb, hidden_size), init_type="kaiming" - ), + prefix / "in_proj" / "weight": Init(shape=(2 * d_inner + 2 * d_xb, hidden_size), init_type="kaiming"), prefix / "out_proj" / "weight": Init(shape=(hidden_size, d_inner), init_type="kaiming"), prefix / "dt_in_proj" / "weight": Init(shape=(dt_rank, hidden_size), init_type="kaiming"), prefix / "dt_proj" / "weight": Init(shape=(d_inner, dt_rank), init_type="kaiming"), @@ -230,18 +266,20 @@ def _plan_gdn_mixer( """ if source_prefix is not None: # Passthrough - return ExprPlan(mappings={ - prefix / name: Ref(key=source_prefix / name) - for name in [ - "in_proj_qkvz.weight", - 
"in_proj_ba.weight", - "out_proj.weight", - "convolution.weight", - "A_log", - "dt_bias", - "norm.weight", - ] - }) + return ExprPlan( + mappings={ + prefix / name: Ref(key=source_prefix / name) + for name in [ + "in_proj_qkvz.weight", + "in_proj_ba.weight", + "out_proj.weight", + "convolution.weight", + "A_log", + "dt_bias", + "norm.weight", + ] + } + ) # Random init num_v_heads = config["value_heads"] @@ -255,17 +293,19 @@ def _plan_gdn_mixer( conv_dim = key_dim * 2 + value_dim qkvz_size = key_dim * 2 + value_dim * 2 # Q, K both key_dim; V, Z both value_dim - return ExprPlan(mappings={ - prefix / "in_proj_qkvz" / "weight": Init(shape=(qkvz_size, hidden_size), init_type="kaiming"), - prefix / "in_proj_ba" / "weight": Init(shape=(num_v_heads * 2, hidden_size), init_type="zeros"), - prefix / "out_proj" / "weight": Init(shape=(hidden_size, value_dim), init_type="kaiming"), - prefix / "convolution" / "weight": Init( - shape=(conv_dim, 1, conv_kernel_size), init_type="scaled_identity_conv" - ), - prefix / "A_log": Init(shape=(num_v_heads,), init_type="slow_decay"), - prefix / "dt_bias": Init(shape=(num_v_heads,), init_type="zeros"), - prefix / "norm" / "weight": Init(shape=(head_v_dim,), init_type="ones"), - }) + return ExprPlan( + mappings={ + prefix / "in_proj_qkvz" / "weight": Init(shape=(qkvz_size, hidden_size), init_type="kaiming"), + prefix / "in_proj_ba" / "weight": Init(shape=(num_v_heads * 2, hidden_size), init_type="zeros"), + prefix / "out_proj" / "weight": Init(shape=(hidden_size, value_dim), init_type="kaiming"), + prefix + / "convolution" + / "weight": Init(shape=(conv_dim, 1, conv_kernel_size), init_type="scaled_identity_conv"), + prefix / "A_log": Init(shape=(num_v_heads,), init_type="slow_decay"), + prefix / "dt_bias": Init(shape=(num_v_heads,), init_type="zeros"), + prefix / "norm" / "weight": Init(shape=(head_v_dim,), init_type="ones"), + } + ) def _plan_kda_mixer( @@ -298,26 +338,28 @@ def _plan_kda_mixer( """ if source_prefix is not None: # Passthrough - return ExprPlan(mappings={ - prefix / name: Ref(key=source_prefix / name) - for name in [ - "q_proj.weight", - "k_proj.weight", - "v_proj.weight", - "o_proj.weight", - "q_conv.weight", - "k_conv.weight", - "v_conv.weight", - "f_a_proj.weight", - "f_b_proj.weight", - "g_a_proj.weight", - "g_b_proj.weight", - "beta_proj.weight", - "A_log", - "dt_bias", - "norm.weight", - ] - }) + return ExprPlan( + mappings={ + prefix / name: Ref(key=source_prefix / name) + for name in [ + "q_proj.weight", + "k_proj.weight", + "v_proj.weight", + "o_proj.weight", + "q_conv.weight", + "k_conv.weight", + "v_conv.weight", + "f_a_proj.weight", + "f_b_proj.weight", + "g_a_proj.weight", + "g_b_proj.weight", + "beta_proj.weight", + "A_log", + "dt_bias", + "norm.weight", + ] + } + ) # Random init num_heads = config["heads"] @@ -325,36 +367,38 @@ def _plan_kda_mixer( projection_size = num_heads * head_dim conv_kernel_size = config.get("convolution_layer", {}).get("kernel_size", 4) - return ExprPlan(mappings={ - # Main projections - prefix / "q_proj" / "weight": Init(shape=(projection_size, hidden_size), init_type="kaiming"), - prefix / "k_proj" / "weight": Init(shape=(projection_size, hidden_size), init_type="kaiming"), - prefix / "v_proj" / "weight": Init(shape=(projection_size, hidden_size), init_type="kaiming"), - prefix / "o_proj" / "weight": Init(shape=(hidden_size, projection_size), init_type="kaiming"), - # Convolutions - prefix / "q_conv" / "weight": Init( - shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv" - ), - 
prefix / "k_conv" / "weight": Init( - shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv" - ), - prefix / "v_conv" / "weight": Init( - shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv" - ), - # Gate kernels (low-rank factorization) - prefix / "f_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"), - prefix / "f_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"), - # Output gate (low-rank factorization) - prefix / "g_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"), - prefix / "g_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"), - # Beta projection - prefix / "beta_proj" / "weight": Init(shape=(num_heads, hidden_size), init_type="kaiming"), - # Learnable parameters - prefix / "A_log": Init(shape=(num_heads,), init_type="slow_decay"), - prefix / "dt_bias": Init(shape=(projection_size,), init_type="zeros"), - # Normalization - prefix / "norm" / "weight": Init(shape=(head_dim,), init_type="ones"), - }) + return ExprPlan( + mappings={ + # Main projections + prefix / "q_proj" / "weight": Init(shape=(projection_size, hidden_size), init_type="kaiming"), + prefix / "k_proj" / "weight": Init(shape=(projection_size, hidden_size), init_type="kaiming"), + prefix / "v_proj" / "weight": Init(shape=(projection_size, hidden_size), init_type="kaiming"), + prefix / "o_proj" / "weight": Init(shape=(hidden_size, projection_size), init_type="kaiming"), + # Convolutions + prefix + / "q_conv" + / "weight": Init(shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"), + prefix + / "k_conv" + / "weight": Init(shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"), + prefix + / "v_conv" + / "weight": Init(shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"), + # Gate kernels (low-rank factorization) + prefix / "f_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"), + prefix / "f_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"), + # Output gate (low-rank factorization) + prefix / "g_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"), + prefix / "g_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"), + # Beta projection + prefix / "beta_proj" / "weight": Init(shape=(num_heads, hidden_size), init_type="kaiming"), + # Learnable parameters + prefix / "A_log": Init(shape=(num_heads,), init_type="slow_decay"), + prefix / "dt_bias": Init(shape=(projection_size,), init_type="zeros"), + # Normalization + prefix / "norm" / "weight": Init(shape=(head_dim,), init_type="ones"), + } + ) # Dispatcher for per-mixer plan functions @@ -409,16 +453,13 @@ def plan_mil_attention_to_mamba( exprs=( Init(shape=(d_inner, hidden_size), init_type="kaiming"), # z: random Slice( - expr=Ref(key=source_prefix / "v_proj" / "weight"), - slices=((0, d_xb, None), (None, None, None)) + expr=Ref(key=source_prefix / "v_proj" / "weight"), slices=((0, d_xb, None), (None, None, None)) ), # x <- V Slice( - expr=Ref(key=source_prefix / "k_proj" / "weight"), - slices=((0, d_xb, None), (None, None, None)) + expr=Ref(key=source_prefix / "k_proj" / "weight"), slices=((0, d_xb, None), (None, None, None)) ), # B <- K Slice( - expr=Ref(key=source_prefix / "q_proj" / "weight"), - slices=((0, d_inner, None), (None, None, None)) + expr=Ref(key=source_prefix / "q_proj" / "weight"), slices=((0, d_inner, None), (None, 
None, None)) ), # C <- Q ), dim=0, @@ -532,19 +573,21 @@ def plan_dil_attention_to_gdn( dim=0, ) - return ExprPlan(mappings={ - target_prefix / "in_proj_qkvz" / "weight": in_proj_qkvz_expr, - target_prefix / "in_proj_ba" / "weight": Init( - shape=(2 * num_v_heads, hidden_size), init_type="zeros" - ), # b=a=0 → β=0.5 - target_prefix / "out_proj" / "weight": Ref(key=source_prefix / "o_proj" / "weight"), - target_prefix / "convolution" / "weight": Init( - shape=(conv_dim, 1, conv_kernel_size), init_type="scaled_identity_conv" - ), - target_prefix / "A_log": Init(shape=(num_v_heads,), init_type="slow_decay"), - target_prefix / "dt_bias": Init(shape=(num_v_heads,), init_type="zeros"), - target_prefix / "norm" / "weight": Init(shape=(head_v_dim,), init_type="ones"), - }) + return ExprPlan( + mappings={ + target_prefix / "in_proj_qkvz" / "weight": in_proj_qkvz_expr, + target_prefix + / "in_proj_ba" + / "weight": Init(shape=(2 * num_v_heads, hidden_size), init_type="zeros"), # b=a=0 → β=0.5 + target_prefix / "out_proj" / "weight": Ref(key=source_prefix / "o_proj" / "weight"), + target_prefix + / "convolution" + / "weight": Init(shape=(conv_dim, 1, conv_kernel_size), init_type="scaled_identity_conv"), + target_prefix / "A_log": Init(shape=(num_v_heads,), init_type="slow_decay"), + target_prefix / "dt_bias": Init(shape=(num_v_heads,), init_type="zeros"), + target_prefix / "norm" / "weight": Init(shape=(head_v_dim,), init_type="ones"), + } + ) def plan_kil_attention_to_kda( @@ -595,9 +638,7 @@ def plan_kil_attention_to_kda( for h in range(num_heads): src_h = h % source_num_q_heads row_start = src_h * source_head_dim - q_slices.append( - Slice(expr=q_ref, slices=((row_start, row_start + head_dim, None), (None, None, None))) - ) + q_slices.append(Slice(expr=q_ref, slices=((row_start, row_start + head_dim, None), (None, None, None)))) q_expr = Concat(exprs=tuple(q_slices), dim=0) # K: tile source KV heads to fill target projection_size @@ -608,9 +649,7 @@ def plan_kil_attention_to_kda( for h in range(num_heads): src_h = h % source_num_kv_heads row_start = src_h * source_head_dim - k_slices.append( - Slice(expr=k_ref, slices=((row_start, row_start + head_dim, None), (None, None, None))) - ) + k_slices.append(Slice(expr=k_ref, slices=((row_start, row_start + head_dim, None), (None, None, None)))) k_expr = Concat(exprs=tuple(k_slices), dim=0) # V: tile source KV heads to fill target projection_size @@ -621,41 +660,41 @@ def plan_kil_attention_to_kda( for h in range(num_heads): src_h = h % source_num_kv_heads row_start = src_h * source_head_dim - v_slices.append( - Slice(expr=v_ref, slices=((row_start, row_start + head_dim, None), (None, None, None))) - ) + v_slices.append(Slice(expr=v_ref, slices=((row_start, row_start + head_dim, None), (None, None, None)))) v_expr = Concat(exprs=tuple(v_slices), dim=0) - return ExprPlan(mappings={ - # Transfer main projections - target_prefix / "q_proj" / "weight": q_expr, - target_prefix / "k_proj" / "weight": k_expr, - target_prefix / "v_proj" / "weight": v_expr, - target_prefix / "o_proj" / "weight": Ref(key=source_prefix / "o_proj" / "weight"), - # Random init: convolutions (scaled identity for near-passthrough initially) - target_prefix / "q_conv" / "weight": Init( - shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv" - ), - target_prefix / "k_conv" / "weight": Init( - shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv" - ), - target_prefix / "v_conv" / "weight": Init( - shape=(projection_size, 1, 
conv_kernel_size), init_type="scaled_identity_conv" - ), - # Random init: gate kernels (low-rank factorization) - target_prefix / "f_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"), - target_prefix / "f_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"), - # Random init: output gate (low-rank factorization) - target_prefix / "g_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"), - target_prefix / "g_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"), - # Random init: beta projection - target_prefix / "beta_proj" / "weight": Init(shape=(num_heads, hidden_size), init_type="kaiming"), - # Random init: learnable parameters - target_prefix / "A_log": Init(shape=(num_heads,), init_type="slow_decay"), - target_prefix / "dt_bias": Init(shape=(projection_size,), init_type="zeros"), - # Random init: normalization - target_prefix / "norm" / "weight": Init(shape=(head_dim,), init_type="ones"), - }) + return ExprPlan( + mappings={ + # Transfer main projections + target_prefix / "q_proj" / "weight": q_expr, + target_prefix / "k_proj" / "weight": k_expr, + target_prefix / "v_proj" / "weight": v_expr, + target_prefix / "o_proj" / "weight": Ref(key=source_prefix / "o_proj" / "weight"), + # Random init: convolutions (scaled identity for near-passthrough initially) + target_prefix + / "q_conv" + / "weight": Init(shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"), + target_prefix + / "k_conv" + / "weight": Init(shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"), + target_prefix + / "v_conv" + / "weight": Init(shape=(projection_size, 1, conv_kernel_size), init_type="scaled_identity_conv"), + # Random init: gate kernels (low-rank factorization) + target_prefix / "f_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"), + target_prefix / "f_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"), + # Random init: output gate (low-rank factorization) + target_prefix / "g_a_proj" / "weight": Init(shape=(head_dim, hidden_size), init_type="kaiming"), + target_prefix / "g_b_proj" / "weight": Init(shape=(projection_size, head_dim), init_type="kaiming"), + # Random init: beta projection + target_prefix / "beta_proj" / "weight": Init(shape=(num_heads, hidden_size), init_type="kaiming"), + # Random init: learnable parameters + target_prefix / "A_log": Init(shape=(num_heads,), init_type="slow_decay"), + target_prefix / "dt_bias": Init(shape=(projection_size,), init_type="zeros"), + # Random init: normalization + target_prefix / "norm" / "weight": Init(shape=(head_dim,), init_type="ones"), + } + ) # ============================================================================= @@ -786,7 +825,70 @@ def plan_surgery( source_config: dict, target_config: dict, ) -> ExprPlan: - """Build plan for Apriel2→Apriel2 surgery (MIL, DIL, KIL, stochastic mixers, etc.).""" + """Build a weight conversion plan: S × T → Plan. + + Creates an ExprPlan mapping target weight keys to expressions over source weights. + Handles same-type passthrough, cross-type conversions (MIL, DIL, KIL), and + stochastic mixer routing. + + Type Signature:: + + plan_surgery : S × T → Plan + + Where S is a state (source) and T is a transition spec (target with ``init`` fields). 
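The next section spells out the ``init`` semantics; as a rough per-weight sketch first (illustrative only, not the actual implementation, which lives in the ``_plan_*_mixer`` helpers)::

    # Hypothetical decision for a single attention weight: "random" yields an
    # Init expression, "transfer" (or absent) yields a Ref to the source weight.
    if sub_config.get("init") == "random":
        expr = Init(shape=(q_size, hidden_size), init_type="kaiming")   # fresh weights
    else:
        expr = Ref(key=source_prefix / "q_proj" / "weight")             # reuse source weights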
+ + The ``init`` Field + ------------------ + + The ``init`` field in ``target_config`` controls weight initialization: + + - ``init: transfer`` (or absent) → create Ref expressions (transfer from source) + - ``init: random`` → create Init expressions (random initialization) + + This is why ``target_config`` should be a transition spec (T) from ``compose_configs``, + not a stripped state (S). If ``init`` fields are missing, all components default to + transfer mode. + + Args: + source_config: State (S) - complete config describing source architecture. + Must have hidden_size, decoder, etc. No ``init`` fields expected. + target_config: Transition spec (T) - complete config with ``init`` fields. + Use ``compose_configs(source, surgery)`` to produce this. + + Returns: + ExprPlan mapping target weight keys to expressions over source weights. + + Example:: + + # Apply a surgery that wraps attention in a stochastic mixer + surgery_spec = { + "decoder": {"block": {"mixer": { + "type": "stochastic", + "mixers": { + "attention": {"init": "transfer"}, + "gdn": {"type": "gdn", "init": "random"}, + } + }}} + } + + # S × P → T + transition = compose_configs(source_config, surgery_spec) + + # S × T → Plan + plan = plan_surgery(source_config, transition) + + # Execute + new_weights = execute(plan, source_weights, seed=42) + + # T → S for saving + target_state = strip_init_fields(transition) + + Note: + Both arguments must be complete (have hidden_size and decoder). + The target_config should retain ``init`` fields from the surgery spec. + Passing a stripped state as target will cause all components to use + transfer mode, which may not be intended. + """ hidden_size = target_config.get("hidden_size", source_config.get("hidden_size")) assert hidden_size is not None, "hidden_size must be specified in source or target config" @@ -804,18 +906,24 @@ def plan_surgery( target_block = _get_block_config(target_decoder, target_layer_idx) plan += _plan_mixer( - target_layer_idx, source_layer_idx, - source_block.get("mixer", {}), target_block.get("mixer", {}), + target_layer_idx, + source_layer_idx, + source_block.get("mixer", {}), + target_block.get("mixer", {}), hidden_size, ) plan += _plan_mlp( - target_layer_idx, source_layer_idx, - source_block.get("mlp", {}), target_block.get("mlp", {}), + target_layer_idx, + source_layer_idx, + source_block.get("mlp", {}), + target_block.get("mlp", {}), hidden_size, ) plan += _plan_norms( - target_layer_idx, source_layer_idx, - source_block, target_block, + target_layer_idx, + source_layer_idx, + source_block, + target_block, hidden_size, ) @@ -839,14 +947,16 @@ def _plan_non_decoder_weights(config: dict) -> ExprPlan: embed = W("model", "embed_tokens", "weight") mappings[embed] = Ref(key=embed) - head = W("lm_head", "weight") - mappings[head] = Ref(key=head) + # lm_head only if not tied to embeddings + if not config.get("tie_word_embeddings", False): + head = W("lm_head", "weight") + mappings[head] = Ref(key=head) norm = W("model", "norm", "weight") mappings[norm] = Ref(key=norm) - if "vision_encoder" in config: - vision_config = config["vision_encoder"] + vision_config = config.get("vision_encoder") + if vision_config: vision = W("model", "vision_encoder") patch_emb = vision / "embeddings" / "patch_embeddings" / "weight" @@ -950,9 +1060,13 @@ def _plan_mixer( source_prefix = source_mixer_base plan += _plan_mixer_transfer( - matched_source_type, sub_type, - matched_source, sub_config, - source_prefix, target_prefix, hidden_size, + matched_source_type, + sub_type, + 
matched_source, + sub_config, + source_prefix, + target_prefix, + hidden_size, ) # Passthrough source sub-mixers not in target spec @@ -963,8 +1077,13 @@ def _plan_mixer( source_prefix = source_layer / "mixer" / "mixers" / sub_name target_prefix = target_layer / "mixer" / "mixers" / sub_name plan += _plan_mixer_transfer( - sub_type, sub_type, sub_config, sub_config, - source_prefix, target_prefix, hidden_size, + sub_type, + sub_type, + sub_config, + sub_config, + source_prefix, + target_prefix, + hidden_size, ) return plan @@ -980,12 +1099,34 @@ def _plan_mixer( source_prefix = source_layer / "mixer" return _plan_mixer_transfer( - main_source_type, target_type, - main_source, target_mixer, - source_prefix, target_prefix, hidden_size, + main_source_type, + target_type, + main_source, + target_mixer, + source_prefix, + target_prefix, + hidden_size, ) +def _get_mlp_bias_enabled(config: dict, layer_name: str) -> bool: + """Get whether bias is enabled for an MLP layer. + + Checks per-layer bias config (e.g., layer_1.bias.enabled, layer_2.bias.enabled). + Falls back to add_linear_biases if not set. + + Note: layer_1 corresponds to gate_proj and up_proj (gated MLP) or just up_proj (non-gated) + layer_2 corresponds to down_proj + """ + layer_cfg = config.get(layer_name, {}) + bias_cfg = layer_cfg.get("bias", {}) + enabled = bias_cfg.get("enabled") + if enabled is not None: + return enabled + # Fall back to add_linear_biases + return config.get("add_linear_biases", False) + + def _plan_mlp( target_layer_idx: int, source_layer_idx: int, @@ -1006,7 +1147,7 @@ def _plan_mlp_transfer( target_mlp: dict, hidden_size: int, ) -> ExprPlan: - """Passthrough for MLP weights.""" + """Passthrough for MLP weights and biases.""" source_mlp_path = W("model", "decoder", "blocks", source_layer_idx, "mlp") target_mlp_path = W("model", "decoder", "blocks", target_layer_idx, "mlp") @@ -1019,10 +1160,36 @@ def _plan_mlp_transfer( f"Use 'init: random' to initialize randomly." 
) - return ExprPlan(mappings={ - target_mlp_path / proj / "weight": Ref(key=source_mlp_path / proj / "weight") - for proj in ["gate_proj", "up_proj", "down_proj"] - }) + # Check per-layer bias configuration + layer_1_bias = _get_mlp_bias_enabled(target_mlp, "layer_1") + layer_2_bias = _get_mlp_bias_enabled(target_mlp, "layer_2") + + # Check if gated MLP (has gate_proj) or non-gated (only up_proj) + gated = target_mlp.get("gated", True) # Default to gated for backwards compatibility + + # Passthrough weights + # layer_1 = gate_proj + up_proj (gated) or just up_proj (non-gated) + # layer_2 = down_proj + if gated: + weight_projs = ["gate_proj", "up_proj", "down_proj"] + else: + weight_projs = ["up_proj", "down_proj"] + + mappings: dict[W, Expr] = { + target_mlp_path / proj / "weight": Ref(key=source_mlp_path / proj / "weight") for proj in weight_projs + } + + # Passthrough biases if enabled + if layer_1_bias: + if gated: + mappings[target_mlp_path / "gate_proj" / "bias"] = Ref(key=source_mlp_path / "gate_proj" / "bias") + mappings[target_mlp_path / "up_proj" / "bias"] = Ref(key=source_mlp_path / "up_proj" / "bias") + + # layer_2 = down_proj + if layer_2_bias: + mappings[target_mlp_path / "down_proj" / "bias"] = Ref(key=source_mlp_path / "down_proj" / "bias") + + return ExprPlan(mappings=mappings) def _plan_random_mlp( @@ -1030,20 +1197,41 @@ def _plan_random_mlp( target_mlp: dict, hidden_size: int, ) -> ExprPlan: - """Random initialization for MLP.""" + """Random initialization for MLP weights and biases.""" target_mlp_path = W("model", "decoder", "blocks", target_layer_idx, "mlp") intermediate_size = target_mlp["intermediate_size"] - return ExprPlan(mappings={ - target_mlp_path / "gate_proj" / "weight": Init( - shape=(intermediate_size, hidden_size), init_type="kaiming" - ), - target_mlp_path / "up_proj" / "weight": Init( + + # Check per-layer bias configuration + layer_1_bias = _get_mlp_bias_enabled(target_mlp, "layer_1") + layer_2_bias = _get_mlp_bias_enabled(target_mlp, "layer_2") + + # Check if gated MLP (has gate_proj) or non-gated (only up_proj) + gated = target_mlp.get("gated", True) # Default to gated for backwards compatibility + + # Random init weights + mappings: dict[W, Expr] = {} + if gated: + mappings[target_mlp_path / "gate_proj" / "weight"] = Init( shape=(intermediate_size, hidden_size), init_type="kaiming" - ), - target_mlp_path / "down_proj" / "weight": Init( - shape=(hidden_size, intermediate_size), init_type="kaiming" - ), - }) + ) + mappings[target_mlp_path / "up_proj" / "weight"] = Init( + shape=(intermediate_size, hidden_size), init_type="kaiming" + ) + mappings[target_mlp_path / "down_proj" / "weight"] = Init( + shape=(hidden_size, intermediate_size), init_type="kaiming" + ) + + # Random init biases if enabled + if layer_1_bias: + if gated: + mappings[target_mlp_path / "gate_proj" / "bias"] = Init(shape=(intermediate_size,), init_type="zeros") + mappings[target_mlp_path / "up_proj" / "bias"] = Init(shape=(intermediate_size,), init_type="zeros") + + # layer_2 = down_proj + if layer_2_bias: + mappings[target_mlp_path / "down_proj" / "bias"] = Init(shape=(hidden_size,), init_type="zeros") + + return ExprPlan(mappings=mappings) def _plan_norms( @@ -1083,10 +1271,12 @@ def _plan_norms_transfer( f"Use 'init: random' to initialize randomly." 
) - return ExprPlan(mappings={ - target_layer / norm_name / "weight": Ref(key=source_layer / norm_name / "weight") - for norm_name in ["input_layernorm", "post_attention_layernorm"] - }) + return ExprPlan( + mappings={ + target_layer / norm_name / "weight": Ref(key=source_layer / norm_name / "weight") + for norm_name in ["input_layernorm", "post_attention_layernorm"] + } + ) def _plan_random_norms( @@ -1095,7 +1285,9 @@ def _plan_random_norms( ) -> ExprPlan: """Random initialization for normalization layers.""" target_layer = W("model", "decoder", "blocks", target_layer_idx) - return ExprPlan(mappings={ - target_layer / norm_name / "weight": Init(shape=(hidden_size,), init_type="ones") - for norm_name in ["input_layernorm", "post_attention_layernorm"] - }) + return ExprPlan( + mappings={ + target_layer / norm_name / "weight": Init(shape=(hidden_size,), init_type="ones") + for norm_name in ["input_layernorm", "post_attention_layernorm"] + } + ) diff --git a/fast_llm_external_models/apriel2/conversion/executor.py b/fast_llm_external_models/apriel2/conversion/executor.py index a6c5672f0..b0779c97f 100644 --- a/fast_llm_external_models/apriel2/conversion/executor.py +++ b/fast_llm_external_models/apriel2/conversion/executor.py @@ -29,7 +29,8 @@ from __future__ import annotations import hashlib -from typing import Callable, Iterator +from collections.abc import Iterator +from typing import Callable import torch from torch import Tensor @@ -81,8 +82,7 @@ def execute( break else: raise ValueError( - "Cannot infer device/dtype: plan has no source references. " - "Provide device and dtype explicitly." + "Cannot infer device/dtype: plan has no source references. " "Provide device and dtype explicitly." ) generator = torch.Generator(device=device) @@ -94,10 +94,7 @@ def execute( # Verify device/dtype consistency for key, tensor in sources.items(): if tensor.device != device or tensor.dtype != dtype: - raise ValueError( - f"Source {key} has {tensor.device}/{tensor.dtype}, " - f"expected {device}/{dtype}" - ) + raise ValueError(f"Source {key} has {tensor.device}/{tensor.dtype}, " f"expected {device}/{dtype}") # Deterministic per-target seed key_offset = int(hashlib.md5(str(target_key).encode()).hexdigest()[:8], 16) diff --git a/fast_llm_external_models/apriel2/conversion/expr.py b/fast_llm_external_models/apriel2/conversion/expr.py index 4867a27ae..34ea106fc 100644 --- a/fast_llm_external_models/apriel2/conversion/expr.py +++ b/fast_llm_external_models/apriel2/conversion/expr.py @@ -52,7 +52,8 @@ import math from collections import defaultdict -from typing import Annotated, Any, Callable, Iterator, Literal, TypedDict, Union, Unpack +from collections.abc import Iterator +from typing import Annotated, Any, Callable, Literal, TypedDict, Union, Unpack import torch from pydantic import BaseModel, ConfigDict, Field, GetCoreSchemaHandler, TypeAdapter @@ -60,7 +61,6 @@ from pydantic_core import CoreSchema, core_schema from torch import Tensor - # ============================================================================= # Weight Path Builder # ============================================================================= @@ -78,7 +78,7 @@ class W(str): mappings[q] = Ref(key=source_q) """ - def __new__(cls, *parts) -> "W": + def __new__(cls, *parts) -> W: # Join parts, stripping any leading/trailing dots from each cleaned = [] for p in parts: @@ -89,12 +89,12 @@ def __new__(cls, *parts) -> "W": cleaned.append(s) return super().__new__(cls, ".".join(cleaned)) - def __truediv__(self, other) -> "W": + def 
__truediv__(self, other) -> W: if isinstance(other, (list, tuple)): return W(self, *other) return W(self, other) - def __rtruediv__(self, other) -> "W": + def __rtruediv__(self, other) -> W: return W(other, self) @classmethod @@ -156,7 +156,7 @@ class Slice(BaseModel): model_config = ConfigDict(frozen=True) type: Literal["slice"] = "slice" - expr: "Expr" + expr: Expr slices: tuple[tuple[int | None, int | None, int | None], ...] def find_refs(self) -> set[W]: @@ -184,7 +184,7 @@ class Concat(BaseModel): model_config = ConfigDict(frozen=True) type: Literal["concat"] = "concat" - exprs: tuple["Expr", ...] + exprs: tuple[Expr, ...] dim: int = 0 def find_refs(self) -> set[W]: @@ -303,7 +303,7 @@ class Reshape(BaseModel): model_config = ConfigDict(frozen=True) type: Literal["reshape"] = "reshape" - expr: "Expr" + expr: Expr shape: tuple[int, ...] def find_refs(self) -> set[W]: @@ -442,10 +442,10 @@ def __getitem__(self, key: W) -> Expr: def __contains__(self, key: W) -> bool: return key in self.mappings - def __or__(self, other: "ExprPlan") -> "ExprPlan": + def __or__(self, other: ExprPlan) -> ExprPlan: return compose(self, other) - def __add__(self, other: "ExprPlan") -> "ExprPlan": + def __add__(self, other: ExprPlan) -> ExprPlan: return merge(self, other) def source_keys(self) -> set[str]: @@ -471,7 +471,7 @@ def summary(self) -> dict[str, Any]: "metadata": self.metadata, } - def fuse(self) -> "ExprPlan": + def fuse(self) -> ExprPlan: return ExprPlan( mappings={k: fuse(v) for k, v in self.mappings.items()}, source_format=self.source_format, diff --git a/fast_llm_external_models/apriel2/conversion/io.py b/fast_llm_external_models/apriel2/conversion/io.py index e1a261d7e..1f64df0b9 100644 --- a/fast_llm_external_models/apriel2/conversion/io.py +++ b/fast_llm_external_models/apriel2/conversion/io.py @@ -62,7 +62,7 @@ def __init__(self, files: list[Path], device: str = "cpu"): self._handles: dict[Path, Any] = {} self._key_index: dict[str, Path] = {} - def __enter__(self) -> "SafetensorLoader": + def __enter__(self) -> SafetensorLoader: # Pre-build index: key -> file (one-time O(n×m), then O(1) lookups) for f in self.files: handle = safe_open(f, framework="pt", device=self.device) @@ -128,7 +128,7 @@ def __init__( self._finalized: bool = False self._result_path: Path | None = None - def __enter__(self) -> "ShardedSafetensorWriter": + def __enter__(self) -> ShardedSafetensorWriter: return self def __exit__(self, exc_type, exc_val, exc_tb) -> None: @@ -180,8 +180,7 @@ def _flush(self) -> None: shard_file = self.output_dir / f"{self.base_name}-{self._shard_index:05d}.safetensors.tmp" logger.debug( - f"Writing shard {self._shard_index}: {len(self._buffer)} tensors, " - f"{self._buffer_bytes / 1e9:.2f} GB" + f"Writing shard {self._shard_index}: {len(self._buffer)} tensors, " f"{self._buffer_bytes / 1e9:.2f} GB" ) save_file(self._buffer, shard_file) self._shard_files.append(shard_file) diff --git a/fast_llm_external_models/apriel2/conversion/llava/plan.py b/fast_llm_external_models/apriel2/conversion/llava/plan.py index df485efbd..a97e46c1a 100644 --- a/fast_llm_external_models/apriel2/conversion/llava/plan.py +++ b/fast_llm_external_models/apriel2/conversion/llava/plan.py @@ -1,11 +1,6 @@ """Llava to Apriel2 weight conversion plan.""" -from fast_llm_external_models.apriel2.conversion.expr import ( - Expr, - ExprPlan, - Ref, - W, -) +from fast_llm_external_models.apriel2.conversion.expr import Expr, ExprPlan, Ref, W def plan_llava_to_apriel2(llava_config: dict) -> ExprPlan: diff --git 
a/fast_llm_external_models/apriel2/conversion/qwen2/__init__.py b/fast_llm_external_models/apriel2/conversion/qwen2/__init__.py new file mode 100644 index 000000000..d0a0b8e6e --- /dev/null +++ b/fast_llm_external_models/apriel2/conversion/qwen2/__init__.py @@ -0,0 +1,6 @@ +"""Qwen2/Qwen2.5 to Apriel2 conversion module.""" + +from fast_llm_external_models.apriel2.conversion.qwen2.config import convert_config +from fast_llm_external_models.apriel2.conversion.qwen2.plan import plan_qwen2_to_apriel2 + +__all__ = ["convert_config", "plan_qwen2_to_apriel2"] diff --git a/fast_llm_external_models/apriel2/conversion/qwen2/config.py b/fast_llm_external_models/apriel2/conversion/qwen2/config.py new file mode 100644 index 000000000..70629fe0e --- /dev/null +++ b/fast_llm_external_models/apriel2/conversion/qwen2/config.py @@ -0,0 +1,79 @@ +"""Qwen2/Qwen2.5 to Apriel2 config conversion.""" + + +def convert_config(qwen2_config: dict) -> dict: + """Convert Qwen2/Qwen2.5 config to Apriel2TextConfig format. + + Qwen2.5 architecture: + - Standard transformer with GQA (grouped query attention) + - QKV bias enabled, O bias disabled + - MLP bias disabled + - Gated SwiGLU MLP + - RMSNorm + - RoPE embeddings + + Args: + qwen2_config: HuggingFace Qwen2Config as dict + + Returns: + Apriel2TextConfig-compatible dict + """ + hidden_size = qwen2_config["hidden_size"] + num_attention_heads = qwen2_config["num_attention_heads"] + num_key_value_heads = qwen2_config.get("num_key_value_heads", num_attention_heads) + head_dim = hidden_size // num_attention_heads + + # Qwen2 uses QKV bias but not O bias - mirror Fast-LLM's per-layer config + return { + "model_type": "apriel2_text", + "architectures": ["Apriel2ForCausalLM"], + "auto_map": { + "AutoConfig": "configuration_apriel2.Apriel2TextConfig", + "AutoModel": "modeling_apriel2.Apriel2TextModel", + "AutoModelForCausalLM": "modeling_apriel2.Apriel2ForCausalLM", + }, + "hidden_size": hidden_size, + "vocab_size": qwen2_config["vocab_size"], + "tie_word_embeddings": qwen2_config.get("tie_word_embeddings", False), + "decoder": { + "type": "fixed", + "num_blocks": qwen2_config["num_hidden_layers"], + "block": { + "mixer": { + "type": "attention", + "heads": num_attention_heads, + "head_groups": num_key_value_heads, + "head_size": head_dim, + # Per-layer bias config matching Fast-LLM structure + "query_layer": {"bias": {"enabled": True}}, + "key_layer": {"bias": {"enabled": True}}, + "value_layer": {"bias": {"enabled": True}}, + "dense_layer": {"bias": {"enabled": False}}, + "rotary": { + "type": "mistral_1d", + "theta": qwen2_config.get("rope_theta", 1000000.0), + }, + }, + "mlp": { + "type": "mlp", + "intermediate_size": qwen2_config["intermediate_size"], + "activation": qwen2_config.get("hidden_act", "silu"), + "gated": True, + "add_linear_biases": False, + }, + "normalization": { + "type": "rms_norm", + "epsilon": qwen2_config.get("rms_norm_eps", 1e-6), + }, + }, + }, + "head": { + "normalization": { + "type": "rms_norm", + "epsilon": qwen2_config.get("rms_norm_eps", 1e-6), + } + }, + "embeddings": { + "max_position_embeddings": qwen2_config.get("max_position_embeddings", 32768), + }, + } diff --git a/fast_llm_external_models/apriel2/conversion/qwen2/plan.py b/fast_llm_external_models/apriel2/conversion/qwen2/plan.py new file mode 100644 index 000000000..c1ec4af8b --- /dev/null +++ b/fast_llm_external_models/apriel2/conversion/qwen2/plan.py @@ -0,0 +1,100 @@ +"""Qwen2/Qwen2.5 to Apriel2 weight conversion plan.""" + +from fast_llm_external_models.apriel2.conversion.expr 
import Expr, ExprPlan, Ref, W + + +def plan_qwen2_to_apriel2(qwen2_config: dict) -> ExprPlan: + """Build an expression plan for Qwen2/Qwen2.5 to Apriel2 conversion. + + This is a pure mapping (all Ref expressions) since Qwen2→Apriel2 + is just renaming keys. The weight tensors are identical. + + Key mapping (source keys have "model." prefix in safetensors): + Qwen2 (safetensor key) Apriel2 + ---------------------- ------- + model.embed_tokens.weight -> model.embed_tokens.weight + model.norm.weight -> model.norm.weight + model.layers.{i}.input_layernorm.weight -> model.decoder.blocks.{i}.input_layernorm.weight + model.layers.{i}.post_attention_layernorm.weight -> model.decoder.blocks.{i}.post_attention_layernorm.weight + model.layers.{i}.self_attn.q_proj.weight -> model.decoder.blocks.{i}.mixer.q_proj.weight + model.layers.{i}.self_attn.q_proj.bias -> model.decoder.blocks.{i}.mixer.q_proj.bias + model.layers.{i}.self_attn.k_proj.weight -> model.decoder.blocks.{i}.mixer.k_proj.weight + model.layers.{i}.self_attn.k_proj.bias -> model.decoder.blocks.{i}.mixer.k_proj.bias + model.layers.{i}.self_attn.v_proj.weight -> model.decoder.blocks.{i}.mixer.v_proj.weight + model.layers.{i}.self_attn.v_proj.bias -> model.decoder.blocks.{i}.mixer.v_proj.bias + model.layers.{i}.self_attn.o_proj.weight -> model.decoder.blocks.{i}.mixer.o_proj.weight + model.layers.{i}.mlp.gate_proj.weight -> model.decoder.blocks.{i}.mlp.gate_proj.weight + model.layers.{i}.mlp.up_proj.weight -> model.decoder.blocks.{i}.mlp.up_proj.weight + model.layers.{i}.mlp.down_proj.weight -> model.decoder.blocks.{i}.mlp.down_proj.weight + + Note: Qwen2 has QKV biases but no O bias. The Apriel2 config uses per-layer + bias settings (query_layer.bias.enabled=True, dense_layer.bias.enabled=False) + to match this exactly - no workarounds needed. + + Args: + qwen2_config: HuggingFace Qwen2Config as dict + + Returns: + ExprPlan with Ref mappings + """ + mappings: dict[str, Expr] = {} + + num_layers = qwen2_config["num_hidden_layers"] + + # Static mappings (embeddings and final norm) + # Note: Qwen2 safetensor keys have "model." 
prefix + static_mappings = [ + (W("model", "embed_tokens", "weight"), W("model", "embed_tokens", "weight")), + (W("model", "norm", "weight"), W("model", "norm", "weight")), + ] + + # lm_head - only if not tied + if not qwen2_config.get("tie_word_embeddings", False): + static_mappings.append((W("lm_head", "weight"), W("lm_head", "weight"))) + + for src, tgt in static_mappings: + mappings[tgt] = Ref(key=src) + + # Layer mappings + for layer in range(num_layers): + # Source has "model.layers.{i}" prefix + qwen_layer = W("model", "layers", layer) + apriel_layer = W("model", "decoder", "blocks", layer) + + # Attention projection weights + for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]: + src = qwen_layer / "self_attn" / proj / "weight" + tgt = apriel_layer / "mixer" / proj / "weight" + mappings[tgt] = Ref(key=src) + + # QKV biases (Qwen2 has these, but not O bias) + for proj in ["q_proj", "k_proj", "v_proj"]: + src = qwen_layer / "self_attn" / proj / "bias" + tgt = apriel_layer / "mixer" / proj / "bias" + mappings[tgt] = Ref(key=src) + + # Note: o_proj has no bias in Qwen2, and Apriel2 config has dense_layer.bias.enabled=False + + # MLP projections + for proj in ["gate_proj", "up_proj", "down_proj"]: + src = qwen_layer / "mlp" / proj / "weight" + tgt = apriel_layer / "mlp" / proj / "weight" + mappings[tgt] = Ref(key=src) + + # Layer norms + mappings[apriel_layer / "input_layernorm" / "weight"] = Ref(key=qwen_layer / "input_layernorm" / "weight") + mappings[apriel_layer / "post_attention_layernorm" / "weight"] = Ref( + key=qwen_layer / "post_attention_layernorm" / "weight" + ) + + return ExprPlan( + mappings=mappings, + source_format="qwen2", + target_format="apriel2", + metadata={ + "num_layers": num_layers, + "hidden_size": qwen2_config["hidden_size"], + "num_attention_heads": qwen2_config["num_attention_heads"], + "num_key_value_heads": qwen2_config.get("num_key_value_heads", qwen2_config["num_attention_heads"]), + }, + ) diff --git a/fast_llm_external_models/apriel2/conversion/render.py b/fast_llm_external_models/apriel2/conversion/render.py index d71fa03e1..f9a0c8ac1 100644 --- a/fast_llm_external_models/apriel2/conversion/render.py +++ b/fast_llm_external_models/apriel2/conversion/render.py @@ -8,17 +8,11 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING +from fast_llm_external_models.apriel2.conversion.expr import Concat, Init, Ref, Reshape, Slice + if TYPE_CHECKING: from fast_llm_external_models.apriel2.conversion.expr import Expr, ExprPlan -from fast_llm_external_models.apriel2.conversion.expr import ( - Concat, - Init, - Ref, - Reshape, - Slice, -) - @dataclass class PlanTreeNode: @@ -28,10 +22,10 @@ class PlanTreeNode: After merging, leaf nodes contain aggregated values from multiple siblings. """ - children: dict[str, "PlanTreeNode"] = field(default_factory=dict) + children: dict[str, PlanTreeNode] = field(default_factory=dict) # For leaf nodes: list of (sibling_key, expr) pairs # Before merge: single item, after merge: multiple items from merged siblings - values: list[tuple[str, "Expr"]] = field(default_factory=list) + values: list[tuple[str, Expr]] = field(default_factory=list) def is_leaf(self) -> bool: return len(self.children) == 0 @@ -61,7 +55,7 @@ def _build_plan_tree(plan: ExprPlan) -> PlanTreeNode: return root -def _expr_signature(expr: "Expr") -> tuple: +def _expr_signature(expr: Expr) -> tuple: """Get a signature for an expression that determines merge compatibility. Expressions with different signatures should not be merged together. 
@@ -453,7 +447,7 @@ def _render_plan_tree( ) -def _format_aggregated_leaf(values: list[tuple[str, "Expr"]]) -> str: +def _format_aggregated_leaf(values: list[tuple[str, Expr]]) -> str: """Format a leaf with aggregated values using pattern discovery. Args: @@ -494,7 +488,7 @@ def _format_aggregated_leaf(values: list[tuple[str, "Expr"]]) -> str: return _format_single_expr(first_expr) -def _format_single_expr(expr: "Expr") -> str: +def _format_single_expr(expr: Expr) -> str: """Format a single expression using ML notation.""" match expr: case Ref(key=key): @@ -531,7 +525,7 @@ def _format_single_expr(expr: "Expr") -> str: return f"= {type(expr).__name__}" -def _format_concat_part(expr: "Expr") -> str: +def _format_concat_part(expr: Expr) -> str: """Format a single part of a concat (for short display).""" match expr: case Ref(key=key): @@ -570,7 +564,7 @@ def _format_slice_notation(slices: tuple) -> str: return f"[{', '.join(slice_strs)}]" -def _format_aggregated_concat(values: list[tuple[str, "Expr"]]) -> str: +def _format_aggregated_concat(values: list[tuple[str, Expr]]) -> str: """Format aggregated Concat expressions with pattern discovery.""" # Get the first concat to understand structure first_concat = values[0][1] @@ -590,7 +584,7 @@ def _format_aggregated_concat(values: list[tuple[str, "Expr"]]) -> str: return f"= [{sep.join(formatted_parts)}]" -def _format_aggregated_concat_part(values: list[tuple[str, "Expr"]]) -> str: +def _format_aggregated_concat_part(values: list[tuple[str, Expr]]) -> str: """Format a single part of an aggregated concat.""" if len(values) == 1: return _format_concat_part(values[0][1]) @@ -619,7 +613,7 @@ def _format_aggregated_concat_part(values: list[tuple[str, "Expr"]]) -> str: return _format_concat_part(first_expr) -def _format_aggregated_slice(values: list[tuple[str, "Expr"]]) -> str: +def _format_aggregated_slice(values: list[tuple[str, Expr]]) -> str: """Format aggregated Slice expressions with pattern discovery.""" first_slice = values[0][1] if not isinstance(first_slice, Slice): diff --git a/fast_llm_external_models/apriel2/convert.py b/fast_llm_external_models/apriel2/convert.py index cbf921b31..66c419dfd 100644 --- a/fast_llm_external_models/apriel2/convert.py +++ b/fast_llm_external_models/apriel2/convert.py @@ -15,6 +15,7 @@ Supported source formats: - llava: Llava/Pixtral models +- qwen2: Qwen2/Qwen2.5 models - apriel2: Apriel2 models (surgery-only mode - no conversion, just apply surgeries) """ @@ -29,10 +30,7 @@ import yaml from tqdm import tqdm -# Allow running as script or module -if __name__ == "__main__": - sys.path.insert(0, str(Path(__file__).parent.parent.parent)) - +# Import source-specific converters from fast_llm_external_models.apriel2.conversion import ( DEFAULT_MAX_SHARD_SIZE, ExprPlan, @@ -41,11 +39,16 @@ StreamingExecutor, compose, compose_configs, - plan_surgery, ) - -# Import source-specific converters from fast_llm_external_models.apriel2.conversion import llava as llava_converter +from fast_llm_external_models.apriel2.conversion import plan_surgery +from fast_llm_external_models.apriel2.conversion import qwen2 as qwen2_converter +from fast_llm_external_models.apriel2.conversion import strip_init_fields + +# Allow running as script or module +if __name__ == "__main__": + sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + logger = logging.getLogger(__name__) @@ -73,6 +76,7 @@ def _identity_plan(config: dict) -> ExprPlan: # Each entry maps format name to (config_converter, plan_builder) SOURCE_FORMATS: dict[str, 
tuple[Callable[[dict], dict], Callable[[dict], ExprPlan]]] = { "llava": (llava_converter.convert_config, llava_converter.plan_llava_to_apriel2), + "qwen2": (qwen2_converter.convert_config, qwen2_converter.plan_qwen2_to_apriel2), "apriel2": (_identity_config, _identity_plan), } @@ -88,8 +92,12 @@ def detect_source_format(config: dict) -> str | None: if model_type in ("llava", "pixtral") or "text_config" in config: return "llava" + # Qwen2/Qwen2.5 detection + if model_type == "qwen2": + return "qwen2" + # Apriel2 detection - check for Apriel2-specific structure - if model_type == "apriel2" or "decoder" in config: + if model_type in ("apriel2", "apriel2_text") or "decoder" in config: return "apriel2" return None @@ -142,15 +150,21 @@ def build_plan( # Apply surgery chain if requested if surgery_configs: for i, surgery_config in enumerate(surgery_configs, 1): - surgery_plan = plan_surgery(current_config, surgery_config) - logger.info(f"Built surgery plan [{i}/{len(surgery_configs)}]: {surgery_plan.summary()['num_targets']} targets") + # S × P → T: compose state with surgery to get transition spec + target_config = compose_configs(current_config, surgery_config) + + # S × T → Plan: build plan from source state and transition spec + surgery_plan = plan_surgery(current_config, target_config) + logger.info( + f"Built surgery plan [{i}/{len(surgery_configs)}]: {surgery_plan.summary()['num_targets']} targets" + ) - # Compose: current -> surgery + # Compose plans current_plan = compose(current_plan, surgery_plan) logger.info(f"Composed plan [{i}/{len(surgery_configs)}]: {current_plan.summary()['num_targets']} targets") - # Compose configs: merge surgery spec into current config - current_config = compose_configs(current_config, surgery_config) + # T → S: strip init for next iteration (init is consumed by plan_surgery) + current_config = strip_init_fields(target_config) return current_plan, current_config @@ -211,9 +225,7 @@ def convert( executor = StreamingExecutor(full_plan, loader) with ShardedSafetensorWriter(output_dir, max_shard_size=max_shard_size) as writer: - for target_key, tensor in tqdm( - executor.execute(seed), desc="Converting", total=len(full_plan) - ): + for target_key, tensor in tqdm(executor.execute(seed), desc="Converting", total=len(full_plan)): writer.add(target_key, tensor) return final_config @@ -282,9 +294,7 @@ def resolve_input(input_path: str) -> Path: def main(): - parser = argparse.ArgumentParser( - description="Convert HuggingFace checkpoint to Apriel2 HF format" - ) + parser = argparse.ArgumentParser(description="Convert HuggingFace checkpoint to Apriel2 HF format") parser.add_argument( "input", type=str, @@ -384,8 +394,7 @@ def main(): safetensor_files = sorted(input_dir.glob("*.safetensors")) if not safetensor_files: raise ValueError( - f"No safetensor files found in {input_dir}. " - "Plan-based conversion requires safetensor files." + f"No safetensor files found in {input_dir}. " "Plan-based conversion requires safetensor files." 
) # Convert using plan-based approach with streaming sharded output @@ -400,11 +409,11 @@ def main(): show_plan=args.show_plan or args.verbose, ) - # Save config + # Save config (build_plan returns S which has no init, but strip defensively) output_config_file = args.output_dir / "config.json" logger.info(f"Saving config to {output_config_file}") with open(output_config_file, "w") as f: - json.dump(apriel2_config, f, indent=2) + json.dump(strip_init_fields(apriel2_config), f, indent=2) # Copy tokenizer files copy_tokenizer_files(input_dir, args.output_dir) diff --git a/fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml b/fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml new file mode 100644 index 000000000..34672916c --- /dev/null +++ b/fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml @@ -0,0 +1,103 @@ +# Dataset preparation config for Tulu 3 SFT mixture with Qwen2 tokenizer +# +# This config converts the Tulu 3 SFT dataset (conversation format) to +# Fast-LLM's memmap format, with automatic loss masking span computation +# to train only on assistant responses. +# +# ============================================================================= +# TOKENIZER SETUP (one-time) +# ============================================================================= +# +# The tokenizer must have a chat template with {% generation %} markers. +# Qwen2's default template doesn't have these, so we need to patch it. +# +# IMPORTANT: The entire assistant turn (opening tag + content + closing tag) +# must be inside the {% generation %} block. This ensures the model learns to +# produce the full assistant response including special tokens like <|im_end|>. +# Reference: https://huggingface.co/HuggingFaceTB/SmolLM3-3B/blob/main/chat_template.jinja +# +# Run this Python script to create a patched tokenizer: +# +# from transformers import AutoTokenizer +# +# tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") +# +# # Patch chat template: wrap ENTIRE assistant turn in generation markers +# tokenizer.chat_template = '''{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system +# You are a helpful assistant.<|im_end|> +# ' }}{% endif %}{% if message['role'] == 'assistant' %}{% generation %}{{ '<|im_start|>assistant +# ' + message['content'] + '<|im_end|> +# ' }}{% endgeneration %}{% else %}{{ '<|im_start|>' + message['role'] + ' +# ' + message['content'] + '<|im_end|> +# ' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant +# ' }}{% endif %}''' +# +# tokenizer.save_pretrained("/path/to/qwen2-instruct-with-markers") +# +# ============================================================================= +# DATA PREPARATION +# ============================================================================= +# +# Small dataset (for testing): +# +# fast-llm prepare gpt_memmap \ +# -c fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml \ +# dataset.split=train[:1000] \ +# output_path=/path/to/tulu3-prepared-small +# +# Full dataset (~939K samples, ~6 minutes): +# +# fast-llm prepare gpt_memmap \ +# -c fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml +# +# ============================================================================= +# VERIFICATION +# ============================================================================= +# +# To verify the prepared dataset has loss masking spans: +# +# import pathlib +# from fast_llm.data.dataset.memmap import MemmapDataset +# from 
fast_llm.data.sample.language_model import LanguageModelSample +# from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig +# +# dataset = MemmapDataset[LanguageModelSample]( +# 'tulu3', +# pathlib.Path('/path/to/tulu3-prepared/shard_0_0.fast_llm_dataset'), +# LanguageModelPreprocessingConfig(use_loss_masking_spans=True) +# ) +# +# doc = dataset.get_document(0) +# print(f'Tokens: {len(doc.tokens.tokens)}') +# print(f'Loss masking spans: {doc.loss_masking_spans.ranges}') +# +# ============================================================================= + +# Dataset configuration +dataset: + # Tulu 3 SFT mixture from AllenAI + path: allenai/tulu-3-sft-mixture + split: train + + # Source schema for conversation format + source_schema: + # Use conversation type (vs default "document" type) + type: conversation + + # Column containing the messages list + messages: messages + +# Tokenizer configuration +# IMPORTANT: Must use a tokenizer with {% generation %} markers in its chat template. +# See instructions above to create a patched Qwen2 tokenizer. +tokenizer: + path: /path/to/qwen2-instruct-with-markers + # Qwen2 doesn't have a BOS token by default, use <|endoftext|> as BOS + bos_token: "<|endoftext|>" + +# Output configuration +output_path: /path/to/tulu3-prepared + +# Processing configuration +num_workers: 8 +documents_per_shard: 100000 diff --git a/fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml b/fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml new file mode 100644 index 000000000..5b190955f --- /dev/null +++ b/fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml @@ -0,0 +1,193 @@ +# Training config for Qwen2-based Apriel2 stochastic supernet on Tulu 3 SFT data +# +# This config trains a stochastic supernet where each layer can sample from +# multiple mixer types (attention, sliding window, gated delta net, KDA). +# Only the mixer weights are trained; all other weights are frozen. +# Activation-level distillation from a teacher model guides the training. +# +# ============================================================================= +# PREREQUISITES +# ============================================================================= +# +# 1. TOKENIZER SETUP +# +# Qwen2's default chat template doesn't have generation markers needed for +# loss masking. Create a patched tokenizer following the SmolLM3 pattern: +# https://huggingface.co/HuggingFaceTB/SmolLM3-3B/blob/main/chat_template.jinja +# +# IMPORTANT: The ENTIRE assistant turn (opening tag + content + closing tag) +# must be inside {% generation %}...{% endgeneration %} markers. +# +# from transformers import AutoTokenizer +# tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") +# # Wrap entire assistant turn in generation markers (NOT just content!) +# tokenizer.chat_template = '''{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system +# You are a helpful assistant.<|im_end|> +# ' }}{% endif %}{% if message['role'] == 'assistant' %}{% generation %}{{ '<|im_start|>assistant +# ' + message['content'] + '<|im_end|> +# ' }}{% endgeneration %}{% else %}{{ '<|im_start|>' + message['role'] + ' +# ' + message['content'] + '<|im_end|> +# ' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant +# ' }}{% endif %}''' +# tokenizer.save_pretrained("/path/to/qwen2-instruct-with-markers") +# +# 2. 
PREPARE TULU 3 DATASET +# +# Small dataset (for testing): +# +# fast-llm prepare gpt_memmap \ +# -c fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml \ +# tokenizer.path=/path/to/qwen2-instruct-with-markers \ +# dataset.split=train[:1000] \ +# output_path=/path/to/tulu3-prepared-small +# +# Full dataset (~939K samples, ~6 minutes): +# +# fast-llm prepare gpt_memmap \ +# -c fast_llm_external_models/apriel2/examples/prepare_tulu3.yaml \ +# tokenizer.path=/path/to/qwen2-instruct-with-markers \ +# output_path=/path/to/tulu3-prepared +# +# 3. CONVERT QWEN2 TO APRIEL2 SUPERNET (student model) +# +# This creates a stochastic supernet with multiple mixer types per layer: +# +# python fast_llm_external_models/apriel2/convert.py \ +# Qwen/Qwen2.5-0.5B-Instruct \ +# /path/to/qwen2-supernet \ +# --surgery fast_llm_external_models/apriel2/examples/stochastic_supernet.yaml +# +# 4. CONVERT QWEN2 TO APRIEL2 (teacher model) +# +# The teacher is the original model without surgery, used for distillation: +# +# python fast_llm_external_models/apriel2/convert.py \ +# Qwen/Qwen2.5-0.5B-Instruct \ +# /path/to/qwen2-teacher +# +# 5. RUN TRAINING +# +# Update paths below and run: +# +# fast-llm train gpt \ +# -c fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml +# +# For long runs, use nohup: +# +# nohup fast-llm train gpt \ +# -c fast_llm_external_models/apriel2/examples/train_supernet_qwen2.yaml \ +# > training.log 2>&1 & +# tail -f training.log +# +# ============================================================================= +# PERFORMANCE TUNING +# ============================================================================= +# +# Default config uses seq=4096, micro_batch=2, batch=16 which gives: +# - ~8k tokens/s/gpu throughput +# - ~61GB GPU memory usage +# - ~25 hours for 1B tokens on single GPU +# +# Adjust batch settings based on your GPU memory: +# - Reduce micro_batch_size if OOM +# - Increase micro_batch_size/batch_size if memory available +# +# ============================================================================= +# OUTPUT +# ============================================================================= +# +# Checkpoints: /path/to/qwen2-supernet-trained/checkpoints/{iteration}/ +# Exports: /path/to/qwen2-supernet-trained/export/apriel2_text/{iteration}/ +# +# ============================================================================= + +# Load pretrained model (Qwen2 converted to Apriel2 supernet) +pretrained: + path: /path/to/qwen2-supernet + format: apriel2_text + model_weights: true + load_config: model + +# Model config +model: + base_model: + # Freeze all components except the mixer + decoder: + block: + mlp: + lr_scale: 0.0 # Freeze MLP + normalization: + lr_scale: 0.0 # Freeze layer norms + # Activation-level distillation from teacher + distillation_model: teacher + activation_distillation_factor: 0.8 + embeddings: + lr_scale: 0.0 # Freeze word embeddings + head: + lr_scale: 0.0 # Freeze output head + cross_entropy_implementation: torch + multi_stage: + zero_stage: 2 + distributed: + compute_dtype: bf16 + seed: 42 + +# Teacher model for activation-level distillation +reference_models: + teacher: + model: + type: gpt + pretrained: + path: /path/to/qwen2-teacher + format: apriel2_text + model_weights: true + load_config: model + +# Batch configuration (tuned for ~61GB GPU memory, ~8k tokens/s) +batch: + sequence_length: 4096 + micro_batch_size: 2 + batch_size: 16 + +# Data configuration (prepared Tulu 3 dataset) +data: + datasets: + training: + type: file 
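+      # Points at the dataset prepared in step 2 above (output of `fast-llm prepare`)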
+ path: /path/to/tulu3-prepared/fast_llm_config.yaml + +# Optimizer configuration +optimizer: + learning_rate: + base: 1.0e-05 + decay_style: cosine + warmup_iterations: 100 + decay_iterations: 10000 + minimum: 1.0e-06 + weight_decay: 0.1 + beta_1: 0.9 + beta_2: 0.95 + +# Training configuration +# At seq=4096, batch=16: ~65k tokens/iter, ~280 iters/hour +# 10000 iters ≈ 650M tokens ≈ 35 hours +training: + train_iters: 10000 + num_workers: 4 + logs: + interval: 10 + checkpoint: + interval: 280 # ~hourly + export: + interval: 280 # ~hourly (useful for development/testing during training) + format: apriel2_text + test_iters: 0 + evaluators: {} + # Weights & Biases configuration (optional, uncomment to enable) + # wandb: + # entity_name: your-entity + # project_name: your-project + +# Experiment directory +run: + experiment_dir: /path/to/qwen2-supernet-trained diff --git a/fast_llm_external_models/apriel2/examples/train_supernet_small.yaml b/fast_llm_external_models/apriel2/examples/train_supernet_small.yaml index 78c22e57f..be4d06e0a 100644 --- a/fast_llm_external_models/apriel2/examples/train_supernet_small.yaml +++ b/fast_llm_external_models/apriel2/examples/train_supernet_small.yaml @@ -107,7 +107,7 @@ model: lr_scale: 0.0 # Freeze layer norms (norm_1 and norm_2 in each block) # Activation-level distillation: teach mixers to mimic teacher's attention outputs distillation_model: teacher - activation_distillation_factor: 0.1 + activation_distillation_factor: 0.8 embeddings: lr_scale: 0.0 # Freeze word embeddings head: diff --git a/fast_llm_external_models/apriel2/modeling_apriel2.py b/fast_llm_external_models/apriel2/modeling_apriel2.py index 4c263b4e2..240240cd6 100644 --- a/fast_llm_external_models/apriel2/modeling_apriel2.py +++ b/fast_llm_external_models/apriel2/modeling_apriel2.py @@ -24,8 +24,8 @@ is_torch_flex_attn_available, ) -from fast_llm_external_models.apriel2.cache import Apriel2Cache -from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config, Apriel2TextConfig +from .cache import Apriel2Cache +from .configuration_apriel2 import Apriel2Config, Apriel2TextConfig # GDN implementation - matches Fast-LLM's gdn.py exactly try: @@ -395,14 +395,30 @@ def __init__(self, d_model: int, mixer_config: dict, layer_idx: int, config): # cross_document_attention: if False, use cu_seqlens to isolate sequences (e.g., images) self.cross_document_attention = mixer_config.get("cross_document_attention", True) - # Whether to add biases to linear projections - add_bias = mixer_config.get("add_linear_biases", False) - - # Projections (Fast-LLM weight names: q_proj, k_proj, v_proj, o_proj) - self.q_proj = nn.Linear(d_model, self.num_heads * self.head_dim, bias=add_bias) - self.k_proj = nn.Linear(d_model, self.num_key_value_heads * self.head_dim, bias=add_bias) - self.v_proj = nn.Linear(d_model, self.num_key_value_heads * self.head_dim, bias=add_bias) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, d_model, bias=add_bias) + # Bias configuration mirroring Fast-LLM's structure: + # - add_linear_biases: bool (default for all projections) + # - query_layer: {"bias": {"enabled": bool}} (per-layer override) + # - key_layer: {"bias": {"enabled": bool}} + # - value_layer: {"bias": {"enabled": bool}} + # - dense_layer: {"bias": {"enabled": bool}} + default_bias = mixer_config.get("add_linear_biases", False) + + def get_layer_bias(layer_name: str) -> bool: + layer_cfg = mixer_config.get(layer_name, {}) + bias_cfg = layer_cfg.get("bias", {}) + enabled = bias_cfg.get("enabled") + 
return default_bias if enabled is None else enabled + + q_bias = get_layer_bias("query_layer") + k_bias = get_layer_bias("key_layer") + v_bias = get_layer_bias("value_layer") + o_bias = get_layer_bias("dense_layer") + + # Projections + self.q_proj = nn.Linear(d_model, self.num_heads * self.head_dim, bias=q_bias) + self.k_proj = nn.Linear(d_model, self.num_key_value_heads * self.head_dim, bias=k_bias) + self.v_proj = nn.Linear(d_model, self.num_key_value_heads * self.head_dim, bias=v_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, d_model, bias=o_bias) @classmethod def setup( @@ -1017,6 +1033,8 @@ def torch_chunk_gated_delta_rule( if not output_final_state: last_recurrent_state = None + elif last_recurrent_state is not None: + last_recurrent_state = last_recurrent_state.to(initial_dtype) core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1]) core_attn_out = core_attn_out[:, :, :sequence_length] core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) @@ -1225,7 +1243,9 @@ def forward(self, hidden_states: torch.Tensor, past_key_values=None, attention_m mixed_qkv = self.convolution.update( mixed_qkv.squeeze(2), # [batch, conv_dim, 1] -> [batch, conv_dim] conv_state, - ).unsqueeze(2) # [batch, conv_dim] -> [batch, conv_dim, 1] + ).unsqueeze( + 2 + ) # [batch, conv_dim] -> [batch, conv_dim, 1] else: # Prefill mode use_cache = past_key_values is not None @@ -1270,8 +1290,14 @@ def forward(self, hidden_states: torch.Tensor, past_key_values=None, attention_m output_final_state=past_key_values is not None, use_qk_l2norm_in_kernel=True, ) + # Ensure state is in same dtype as hidden_states (fla kernel may return float32) + if last_recurrent_state is not None: + last_recurrent_state = last_recurrent_state.to(hidden_states.dtype) else: # Recurrent mode for single token decode + # Convert recurrent_state to match hidden_states dtype if needed + if recurrent_state is not None and recurrent_state.dtype != hidden_states.dtype: + recurrent_state = recurrent_state.to(hidden_states.dtype) output, last_recurrent_state = self._recurrent_gated_delta_rule( query, key, value, g, beta_gate, recurrent_state ) @@ -1294,7 +1320,16 @@ def forward(self, hidden_states: torch.Tensor, past_key_values=None, attention_m return (output,) def _recurrent_gated_delta_rule(self, query, key, value, g, beta, state): - """Single-step recurrent update for cached inference.""" + """Single-step recurrent update for cached inference. 
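+        Implements one step of the gated delta rule: S ← exp(g)·S + β·(k ⊗ v), then o = q·S.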
+ + Input shapes: [batch, seq=1, heads, dim] + Need shapes: [batch, heads, dim] for einsum operations + """ + # Transpose from [batch, seq, heads, dim] to [batch, heads, seq, dim] + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + # L2 normalize query and key query = _l2norm(query, dim=-1, eps=1e-6) key = _l2norm(key, dim=-1, eps=1e-6) @@ -1307,7 +1342,9 @@ def _recurrent_gated_delta_rule(self, query, key, value, g, beta, state): beta = beta.squeeze(1) # Update state: S = exp(g) * S + beta * k^T @ v - decay = g.exp().unsqueeze(-1).unsqueeze(-1) # [batch, heads, 1, 1] + # Keep everything in the same dtype as input (exp() returns float32, need to convert back) + input_dtype = query.dtype + decay = g.exp().to(input_dtype).unsqueeze(-1).unsqueeze(-1) # [batch, heads, 1, 1] k_outer_v = torch.einsum("bhk,bhv->bhkv", key * beta.unsqueeze(-1), value) state = decay * state + k_outer_v @@ -1315,6 +1352,12 @@ def _recurrent_gated_delta_rule(self, query, key, value, g, beta, state): output = torch.einsum("bhk,bhkv->bhv", query, state) output = output.unsqueeze(2) # [batch, heads, 1, v_dim] + # Transpose back to [batch, seq=1, heads, v_dim] + output = output.transpose(1, 2) + + # Ensure state matches output dtype + state = state.to(output.dtype) + return output, state @classmethod @@ -1447,9 +1490,7 @@ def __init__( # Normalization - use GatedRMSNormalization (same wrapper as GDN, with sigmoid activation) self.norm = GatedRMSNormalization(self.head_dim, eps=self.norm_eps, activation=self.norm_activation) - def _apply_conv( - self, x: torch.Tensor, conv: CausalConv1d, conv_state: torch.Tensor | None, use_cache: bool - ): + def _apply_conv(self, x: torch.Tensor, conv: CausalConv1d, conv_state: torch.Tensor | None, use_cache: bool): """ Apply causal convolution with cache support. @@ -1828,16 +1869,36 @@ def __init__( self.post_attention_layernorm = self._create_norm(norm_config, hidden_size, rms_norm_eps) def _create_mlp(self, mlp_config: dict, hidden_size: int): - """Create MLP based on config.""" + """Create MLP based on config. 
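+        Returns a MistralMLP when ``gated`` is true, otherwise a SimpleMLP.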
+ + Supports per-layer bias configuration mirroring Fast-LLM: + - add_linear_biases: default bias setting for all layers + - layer_1.bias.enabled: override for up_proj/gate_proj + - layer_2.bias.enabled: override for down_proj + """ mlp_type = mlp_config.get("type", "mlp") if mlp_type == "mlp": intermediate_size = mlp_config["intermediate_size"] activation = mlp_config.get("activation", "silu") - gated = mlp_config["gated"] - bias = mlp_config.get("add_linear_biases", False) + gated = mlp_config.get("gated", False) + + # Per-layer bias configuration (mirrors Fast-LLM structure) + default_bias = mlp_config.get("add_linear_biases", False) + + def get_layer_bias(layer_name: str) -> bool: + layer_cfg = mlp_config.get(layer_name, {}) + bias_cfg = layer_cfg.get("bias", {}) + enabled = bias_cfg.get("enabled") + return default_bias if enabled is None else enabled + + layer_1_bias = get_layer_bias("layer_1") + layer_2_bias = get_layer_bias("layer_2") if gated: + # MistralMLP uses gate_proj, up_proj, down_proj (all bias controlled together) + # For now, we use the default bias setting for gated MLPs + # TODO: Add per-layer bias support to gated MLP mlp_cfg = SimpleNamespace( hidden_size=hidden_size, intermediate_size=intermediate_size, @@ -1845,7 +1906,13 @@ def _create_mlp(self, mlp_config: dict, hidden_size: int): ) return MistralMLP(mlp_cfg) else: - return SimpleMLP(hidden_size, intermediate_size, activation, bias) + return SimpleMLP( + hidden_size, + intermediate_size, + activation, + layer_1_bias=layer_1_bias, + layer_2_bias=layer_2_bias, + ) else: raise ValueError(f"Unknown MLP type: {mlp_type}") @@ -2179,6 +2246,8 @@ def forward( class Apriel2ForCausalLM(Apriel2PreTrainedModel, GenerationMixin): """Apriel2 model with a language modeling head (text-only).""" + _tied_weights_keys = ["lm_head.weight"] + def __init__(self, config: Apriel2TextConfig): super().__init__(config) self.model = Apriel2TextModel(config) @@ -2186,6 +2255,7 @@ def __init__(self, config: Apriel2TextConfig): self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing + # post_init() calls init_weights() which calls tie_weights() if config.tie_word_embeddings self.post_init() def get_input_embeddings(self): @@ -2583,14 +2653,26 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: class SimpleMLP(nn.Module): - """Non-gated MLP: up_proj -> activation -> down_proj.""" + """Non-gated MLP: up_proj -> activation -> down_proj. 
+ + Supports per-layer bias configuration mirroring Fast-LLM: + - layer_1_bias: bias for up_proj (layer_1 in Fast-LLM naming) + - layer_2_bias: bias for down_proj (layer_2 in Fast-LLM naming) + """ - def __init__(self, hidden_size: int, intermediate_size: int, activation: str = "silu", bias: bool = False): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + activation: str = "silu", + layer_1_bias: bool = False, + layer_2_bias: bool = False, + ): super().__init__() from transformers.activations import ACT2FN - self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=bias) - self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=bias) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=layer_1_bias) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=layer_2_bias) self.act_fn = ACT2FN[activation] def forward(self, x): diff --git a/fast_llm_external_models/tests/test_apriel2/conftest.py b/fast_llm_external_models/tests/test_apriel2/conftest.py index 8585aec65..21b90b097 100644 --- a/fast_llm_external_models/tests/test_apriel2/conftest.py +++ b/fast_llm_external_models/tests/test_apriel2/conftest.py @@ -1,23 +1,44 @@ """Test fixtures for Apriel2 model tests.""" +from collections.abc import Generator from pathlib import Path -from typing import Generator import pytest import torch from transformers import LlavaConfig, LlavaForConditionalGeneration, MistralConfig +from fast_llm_external_models.apriel2.cache import _AttentionCache, _SSMCache + + +# Register custom marks +def pytest_configure(config): + config.addinivalue_line("markers", "slow: marks tests as slow (deselect with '-m \"not slow\"')") + + +def _can_import_fast_llm(): + """Check if Fast-LLM is available.""" + try: + return True + except ImportError: + return False + # Skip marker for tests that require CUDA for Mamba forward pass requires_cuda = pytest.mark.skipif( - not torch.cuda.is_available(), - reason="SSM mixers (Mamba) require CUDA for forward pass" + not torch.cuda.is_available(), reason="SSM mixers (Mamba) require CUDA for forward pass" ) +# Skip marker for tests that require Fast-LLM +requires_fastllm = pytest.mark.skipif(not _can_import_fast_llm(), reason="Fast-LLM not available") -@pytest.fixture(autouse=True) + +@pytest.fixture(scope="module", autouse=True) def set_default_device(): - """Set default device to CUDA for all tests (Mamba requires CUDA).""" + """Set default device to CUDA for all tests (Mamba requires CUDA). + + Module-scoped to ensure it runs before any module-scoped fixtures + that load models (e.g., qwen2_model_and_tokenizer). + """ if torch.cuda.is_available(): old_device = torch.get_default_device() torch.set_default_device("cuda") @@ -27,9 +48,12 @@ def set_default_device(): yield -@pytest.fixture(autouse=True) +@pytest.fixture(scope="module", autouse=True) def set_default_dtype(): - """Set default dtype to float32 for numerical comparison tests.""" + """Set default dtype to float32 for numerical comparison tests. + + Module-scoped to ensure it runs before any module-scoped fixtures. 
+ """ old_dtype = torch.get_default_dtype() torch.set_default_dtype(torch.float32) yield @@ -135,14 +159,11 @@ def model_pair(request, small_pixtral_model, tmp_path): tuple: (source_model, target_model, expected_atol, variant_name) """ import json + from safetensors import safe_open from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config - from fast_llm_external_models.apriel2.conversion import ( - convert_llava_config, - execute, - plan_llava_to_apriel2, - ) + from fast_llm_external_models.apriel2.conversion import convert_llava_config, execute, plan_llava_to_apriel2 from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForConditionalGeneration source = small_pixtral_model @@ -638,12 +659,12 @@ def apriel2_config_comprehensive(): "type": "pattern", "num_blocks": 6, "pattern": [ - "attn", # 0: pure full attention - "swa", # 1: pure sliding window attention - "mamba", # 2: pure mamba - "gdn", # 3: pure gated delta net - "stoch_attn_mamba", # 4: stochastic attention + mamba - "stoch_swa_gdn", # 5: stochastic swa + gated delta net + "attn", # 0: pure full attention + "swa", # 1: pure sliding window attention + "mamba", # 2: pure mamba + "gdn", # 3: pure gated delta net + "stoch_attn_mamba", # 4: stochastic attention + mamba + "stoch_swa_gdn", # 5: stochastic swa + gated delta net ], "blocks": { "attn": { @@ -761,6 +782,52 @@ def apriel2_config_comprehensive(): ) +@pytest.fixture +def apriel2_config_with_bias(): + """Apriel2 config with Qwen-style per-layer bias and non-gated MLP. + + This config exercises: + - Per-layer attention bias (QKV bias enabled, O bias disabled) + - Non-gated MLP with per-layer bias (layer_1 enabled, layer_2 disabled) + - Config structure parity with Fast-LLM's AffineLinearConfig + + Critical for testing bias inheritance through surgery operations. + """ + from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config + + return Apriel2Config( + vocab_size=100, + hidden_size=64, + decoder={ + "type": "fixed", + "num_blocks": 2, + "block": { + "mixer": { + "type": "attention", + "heads": 4, + "head_groups": 2, + "head_size": 16, + "rotary": {"type": "mistral_1d", "theta": 10000.0}, + # Qwen-style: QKV bias enabled, O bias disabled + "query_layer": {"bias": {"enabled": True}}, + "key_layer": {"bias": {"enabled": True}}, + "value_layer": {"bias": {"enabled": True}}, + "dense_layer": {"bias": {"enabled": False}}, + }, + "mlp": { + "type": "mlp", + "intermediate_size": 256, + "gated": False, # Non-gated MLP (SimpleMLP) + # Per-layer MLP bias + "layer_1": {"bias": {"enabled": True}}, + "layer_2": {"bias": {"enabled": False}}, + }, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + ) + + @pytest.fixture def apriel2_cache(apriel2_config_tiny): """Create empty Apriel2Cache from tiny config.""" @@ -863,6 +930,77 @@ def additive_surgery_chain(): ] +@pytest.fixture +def bias_surgery_chain(): + """Surgery chain that exercises bias inheritance through surgery operations. + + Designed to be used with apriel2_config_with_bias as the source config. 
+ Tests that per-layer bias settings (Qwen-style QKV bias, non-gated MLP bias) + are correctly inherited through: + - Stochastic wrapper creation + - Adding new sub-mixers that inherit from source + - Cross-type derivation (attention → sliding_window) + + Source config has: + - Attention: query/key/value bias enabled, dense bias disabled + - MLP: layer_1 bias enabled, layer_2 bias disabled (non-gated) + """ + return [ + # S1: Wrap in stochastic - bias should transfer to attention sub-mixer + { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + }, + }, + "mlp": {"init": "transfer"}, + "normalization": {"init": "transfer"}, + }, + }, + }, + # S2: Add sliding_window - should inherit bias from source attention + { + "decoder": { + "block": { + "mixer": { + "mixers": { + "sliding_window": { + "type": "attention", + "init": "transfer", + "window_size": 512, + }, + }, + }, + }, + }, + }, + # S3: Add new attention with DIFFERENT bias config (random init) + { + "decoder": { + "block": { + "mixer": { + "mixers": { + "full_bias_attn": { + "type": "attention", + "init": "random", + "heads": 4, + "head_groups": 2, + "head_size": 16, + "rotary": {"type": "mistral_1d", "theta": 10000.0}, + "add_linear_biases": True, # All biases enabled + }, + }, + }, + }, + }, + }, + ] + + @pytest.fixture def comprehensive_torture_chain(): """Comprehensive torture chain exercising ALL conversion paths. @@ -885,7 +1023,7 @@ def comprehensive_torture_chain(): # MIL requires: d_inner <= Q rows (256), d_xb <= K/V rows (128) mamba_params = { "d_inner": 256, # Must be <= heads*head_size = 256 - "d_xb": 64, # Must be <= head_groups*head_size = 128 + "d_xb": 64, # Must be <= head_groups*head_size = 128 "dt_rank": 16, "d_state": 16, "d_conv": 4, @@ -1532,3 +1670,330 @@ def torture_surgery_chain(): }, }, ] + + +# ============================================================================= +# Shared Config Dict Fixtures (for compose_configs / plan_surgery tests) +# ============================================================================= + + +@pytest.fixture +def base_config_dict(): + """Complete Apriel2 config dict without biases (Llama-style). + + Use this as the base config for testing compose_configs and plan_surgery. + Returns a dict (not Apriel2Config) for direct use with compose_configs. + """ + return { + "model_type": "apriel2", + "hidden_size": 256, + "vocab_size": 1000, + "bos_token_id": 1, + "eos_token_id": 2, + "tie_word_embeddings": False, + "decoder": { + "type": "fixed", + "num_blocks": 2, + "block": { + "mixer": { + "type": "attention", + "heads": 8, + "head_groups": 4, + "head_size": 32, + "rotary": {"type": "mistral_1d", "theta": 10000.0}, + }, + "mlp": {"type": "mlp", "intermediate_size": 512, "gated": True}, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + } + + +@pytest.fixture +def base_config_with_bias_dict(): + """Complete Apriel2 config dict with Qwen-style biases. + + - QKV bias enabled, O bias disabled + - Gated MLP (no per-layer bias control in this style) + + Use this for testing bias inheritance through surgery operations. + Returns a dict (not Apriel2Config) for direct use with compose_configs. 
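+    Pair with ``make_weights_for_config`` below to generate matching random test weights.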
+ """ + return { + "model_type": "apriel2", + "hidden_size": 256, + "vocab_size": 1000, + "bos_token_id": 1, + "eos_token_id": 2, + "tie_word_embeddings": False, + "decoder": { + "type": "fixed", + "num_blocks": 2, + "block": { + "mixer": { + "type": "attention", + "heads": 8, + "head_groups": 4, + "head_size": 32, + "rotary": {"type": "mistral_1d", "theta": 10000.0}, + "query_layer": {"bias": {"enabled": True}}, + "key_layer": {"bias": {"enabled": True}}, + "value_layer": {"bias": {"enabled": True}}, + "dense_layer": {"bias": {"enabled": False}}, + }, + "mlp": {"type": "mlp", "intermediate_size": 512, "gated": True}, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + } + + +def make_weights_for_config(config: dict) -> dict: + """Create random weights matching a config's expected schema. + + This is a helper function (not a fixture) for creating test weights. + Use it in tests that need weights for plan execution. + + Args: + config: Complete Apriel2 config dict + + Returns: + Dict mapping weight key strings to torch tensors + """ + from fast_llm_external_models.apriel2.conversion import W + + hidden = config["hidden_size"] + vocab = config["vocab_size"] + decoder = config["decoder"] + num_blocks = decoder["num_blocks"] + block = decoder["block"] + mixer = block["mixer"] + mlp = block["mlp"] + + heads = mixer["heads"] + head_groups = mixer["head_groups"] + head_size = mixer["head_size"] + inter = mlp["intermediate_size"] + + # Check bias settings + has_q_bias = mixer.get("query_layer", {}).get("bias", {}).get("enabled", False) + has_k_bias = mixer.get("key_layer", {}).get("bias", {}).get("enabled", False) + has_v_bias = mixer.get("value_layer", {}).get("bias", {}).get("enabled", False) + + weights = {} + weights["model.embed_tokens.weight"] = torch.randn(vocab, hidden) + + for i in range(num_blocks): + p = f"model.decoder.blocks.{i}" + + # Attention + weights[f"{p}.mixer.q_proj.weight"] = torch.randn(heads * head_size, hidden) + weights[f"{p}.mixer.k_proj.weight"] = torch.randn(head_groups * head_size, hidden) + weights[f"{p}.mixer.v_proj.weight"] = torch.randn(head_groups * head_size, hidden) + weights[f"{p}.mixer.o_proj.weight"] = torch.randn(hidden, heads * head_size) + + if has_q_bias: + weights[f"{p}.mixer.q_proj.bias"] = torch.randn(heads * head_size) + if has_k_bias: + weights[f"{p}.mixer.k_proj.bias"] = torch.randn(head_groups * head_size) + if has_v_bias: + weights[f"{p}.mixer.v_proj.bias"] = torch.randn(head_groups * head_size) + + # MLP + weights[f"{p}.mlp.up_proj.weight"] = torch.randn(inter, hidden) + weights[f"{p}.mlp.gate_proj.weight"] = torch.randn(inter, hidden) + weights[f"{p}.mlp.down_proj.weight"] = torch.randn(hidden, inter) + + # Norms + weights[f"{p}.input_layernorm.weight"] = torch.randn(hidden) + weights[f"{p}.post_attention_layernorm.weight"] = torch.randn(hidden) + + weights["model.norm.weight"] = torch.randn(hidden) + weights["lm_head.weight"] = torch.randn(vocab, hidden) + + return {W(k): v for k, v in weights.items()} + + +# ============================================================================= +# Cache Test Fixtures - Tensor Dimensions +# ============================================================================= + + +@pytest.fixture +def batch_size(): + """Default batch size for cache tests.""" + return 2 + + +@pytest.fixture +def num_heads(): + """Default number of attention heads for cache tests.""" + return 4 + + +@pytest.fixture +def head_dim(): + """Default head dimension for cache tests.""" + return 16 + + +@pytest.fixture 
+def make_kv(batch_size, num_heads, head_dim): + """Factory fixture for creating KV tensors.""" + + def _make_kv(seq_len): + return ( + torch.randn(batch_size, num_heads, seq_len, head_dim), + torch.randn(batch_size, num_heads, seq_len, head_dim), + ) + + return _make_kv + + +# ============================================================================= +# Cache Test Fixtures - HuggingFace Cache Layers +# ============================================================================= + + +@pytest.fixture +def hf_dynamic_layer(): + """HuggingFace DynamicLayer for full attention contract testing.""" + from transformers.cache_utils import DynamicLayer + + return DynamicLayer() + + +@pytest.fixture +def hf_sliding_layer(window_size): + """HuggingFace DynamicSlidingWindowLayer for sliding window contract testing.""" + from transformers.cache_utils import DynamicSlidingWindowLayer + + return DynamicSlidingWindowLayer(sliding_window=window_size) + + +# ============================================================================= +# Cache Test Fixtures - Apriel2 Low-level Caches +# ============================================================================= + + +@pytest.fixture +def apriel_attention_cache(): + """Apriel2 attention cache without window (full attention).""" + return _AttentionCache(window=None) + + +@pytest.fixture +def apriel_sliding_cache(window_size): + """Apriel2 attention cache with sliding window.""" + return _AttentionCache(window=window_size) + + +@pytest.fixture +def ssm_cache(): + """Apriel2 SSM cache for Mamba/GDN/KDA layers.""" + return _SSMCache() + + +# ============================================================================= +# Cache Test Fixtures - Apriel2 Configs (Simple Versions) +# ============================================================================= + + +@pytest.fixture +def attention_config(): + """Pure attention config (2 layers, no sliding window).""" + from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config + + return Apriel2Config( + vocab_size=100, + hidden_size=64, + decoder={ + "type": "fixed", + "num_blocks": 2, + "block": { + "mixer": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16}, + "mlp": {"type": "mlp", "intermediate_size": 256}, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + ) + + +@pytest.fixture +def swa_config(): + """Sliding window attention config (2 layers, window=8).""" + from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config + + return Apriel2Config( + vocab_size=100, + hidden_size=64, + decoder={ + "type": "fixed", + "num_blocks": 2, + "block": { + "mixer": { + "type": "attention", + "heads": 4, + "head_groups": 2, + "head_size": 16, + "window_size": 8, + }, + "mlp": {"type": "mlp", "intermediate_size": 256}, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + ) + + +@pytest.fixture +def ssm_config(): + """Pure SSM config (2 layers).""" + from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config + + return Apriel2Config( + vocab_size=100, + hidden_size=64, + decoder={ + "type": "fixed", + "num_blocks": 2, + "block": { + "mixer": {"type": "mamba", "state_size": 16}, + "mlp": {"type": "mlp", "intermediate_size": 256}, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + ) + + +@pytest.fixture +def stochastic_config(): + """Stochastic mixer config with attention and mamba (2 layers).""" + from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config + + return 
Apriel2Config( + vocab_size=100, + hidden_size=64, + decoder={ + "type": "fixed", + "num_blocks": 2, + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16}, + "mamba": {"type": "mamba", "state_size": 16}, + }, + }, + "mlp": {"type": "mlp", "intermediate_size": 256}, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + ) + + +# Parameterized window size fixture (used by hf_sliding_layer and apriel_sliding_cache) +@pytest.fixture(params=[4, 8, 16, 32]) +def window_size(request): + """Parameterized window sizes for sliding window tests.""" + return request.param diff --git a/fast_llm_external_models/tests/test_apriel2/test_cache.py b/fast_llm_external_models/tests/test_apriel2/test_cache.py deleted file mode 100644 index ca8158b4f..000000000 --- a/fast_llm_external_models/tests/test_apriel2/test_cache.py +++ /dev/null @@ -1,1258 +0,0 @@ -"""Comprehensive tests for Apriel2Cache. - -Architecture Overview -===================== -Apriel2Cache manages state for autoregressive generation across different mixer types: - -1. **Attention Cache** (_AttentionCache): Stores key/value states - - Supports sliding window (window_size) for SWA - - Efficient roll optimization for single-token decode - -2. **SSM Cache** (_SSMCache): Stores conv and recurrent states - - Used by Mamba, GDN, KDA - - KDA uses tuple conv states (q, k, v), others use single tensor - -3. **Stochastic Mixer Routing**: For layers with multiple mixer options - - Each mixer has independent cache (no sharing) - - active_mixer pointer routes operations to correct sub-cache - - Switching mixers preserves each mixer's independent state - -Cache Invalidation Semantics -============================ -When switching between mixers in a stochastic layer: -- Each mixer maintains its OWN independent history -- Switching does NOT invalidate the previous mixer's cache -- Switching does NOT copy state between mixers -- To invalidate: call reset() explicitly - -This is intentional for training with stochastic sampling where each mixer -should learn from its own history. For inference, main_mixer_name is fixed. - -Test Organization -================= -1. CREATION & PROPERTIES - Cache initialization, config parsing -2. ATTENTION CACHE - Updates, sliding window, concatenation -3. SSM CACHE - Conv states, recurrent states, KDA tuples -4. STOCHASTIC ROUTING - Active mixer, isolation, switching -5. CACHE INVALIDATION - Reset, per-mixer reset, coherence -6. BEAM SEARCH - batch_repeat, reorder, select -7. HF INTEGRATION - get_mask_sizes, indexing, properties -8. GENERATION PATTERNS - Prefill→decode, crop→continue -9. 
ERROR HANDLING - Guards, bounds, invalid operations -""" - -import pytest -import torch - -from fast_llm_external_models.apriel2.cache import ( - Apriel2Cache, - _AttentionCache, - _SSMCache, -) - - -# ============================================================================= -# FIXTURES - Configs and Sample Data -# ============================================================================= - - -@pytest.fixture -def tiny_attention_config(): - """Minimal config with pure attention layers.""" - from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config - - return Apriel2Config( - vocab_size=100, - hidden_size=64, - decoder={ - "type": "fixed", - "num_blocks": 2, - "block": { - "mixer": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16}, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - ) - - -@pytest.fixture -def swa_config(): - """Config with sliding window attention.""" - from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config - - return Apriel2Config( - vocab_size=100, - hidden_size=64, - decoder={ - "type": "fixed", - "num_blocks": 2, - "block": { - "mixer": { - "type": "attention", - "heads": 4, - "head_groups": 2, - "head_size": 16, - "window_size": 8, # Small for testing - }, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - ) - - -@pytest.fixture -def ssm_config(): - """Config with pure SSM layers (mamba).""" - from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config - - return Apriel2Config( - vocab_size=100, - hidden_size=64, - decoder={ - "type": "fixed", - "num_blocks": 2, - "block": { - "mixer": { - "type": "mamba", - "d_inner": 128, - "d_state": 16, - "dt_rank": 4, - "d_conv": 4, - }, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - ) - - -@pytest.fixture -def kda_config(): - """Config with pure KDA layers.""" - from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config - - return Apriel2Config( - vocab_size=100, - hidden_size=64, - decoder={ - "type": "fixed", - "num_blocks": 2, - "block": { - "mixer": { - "type": "kda", - "heads": 4, - "head_dim": 16, - "convolution_layer": {"kernel_size": 4}, - "normalization": {"epsilon": 1e-5}, - }, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - ) - - -@pytest.fixture -def stochastic_config(): - """Config with stochastic mixer (attention + mamba).""" - from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config - - return Apriel2Config( - vocab_size=100, - hidden_size=64, - decoder={ - "type": "pattern", - "num_blocks": 2, - "pattern": ["attn", "stochastic"], - "blocks": { - "attn": { - "mixer": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16}, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - "stochastic": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": { - "attention": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16}, - "mamba": {"type": "mamba", "d_inner": 128, "d_state": 16, "dt_rank": 4, "d_conv": 4}, - }, - }, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - }, - ) - - -@pytest.fixture -def 
all_mixers_config(): - """Config with stochastic mixer containing all 5 mixer types.""" - from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config - - return Apriel2Config( - vocab_size=100, - hidden_size=64, - decoder={ - "type": "pattern", - "num_blocks": 2, - "pattern": ["attn", "all_mixers"], - "blocks": { - "attn": { - "mixer": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16}, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - "all_mixers": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": { - "attention": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16}, - "swa": { - "type": "attention", - "heads": 4, - "head_groups": 2, - "head_size": 16, - "window_size": 1024, - }, - "mamba": {"type": "mamba", "d_inner": 128, "d_state": 16, "dt_rank": 4, "d_conv": 4}, - "gdn": { - "type": "gdn", - "value_heads": 4, - "key_heads": 2, - "key_head_dim": 16, - "value_head_dim": 16, - "convolution_layer": {"kernel_size": 4}, - }, - "kda": { - "type": "kda", - "heads": 4, - "head_dim": 16, - "convolution_layer": {"kernel_size": 4}, - "normalization": {"epsilon": 1e-5}, - }, - }, - }, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - }, - ) - - -@pytest.fixture -def multi_window_config(): - """Config with multiple different window sizes.""" - from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config - - return Apriel2Config( - vocab_size=100, - hidden_size=64, - decoder={ - "type": "pattern", - "num_blocks": 3, - "pattern": ["full", "small_window", "large_window"], - "blocks": { - "full": { - "mixer": {"type": "attention", "heads": 4, "head_groups": 2, "head_size": 16}, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - "small_window": { - "mixer": { - "type": "attention", - "heads": 4, - "head_groups": 2, - "head_size": 16, - "window_size": 512, - }, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - "large_window": { - "mixer": { - "type": "attention", - "heads": 4, - "head_groups": 2, - "head_size": 16, - "window_size": 2048, - }, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - }, - ) - - -@pytest.fixture -def sample_kv(): - """Sample key/value tensors: [batch=2, heads=4, seq=10, head_dim=16].""" - return torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16) - - -@pytest.fixture -def sample_conv_single(): - """Sample single-tensor conv state: [batch=2, d_inner=128, kernel=4].""" - return torch.randn(2, 128, 4) - - -@pytest.fixture -def sample_conv_tuple(): - """Sample tuple conv state for KDA: (q, k, v) each [batch=2, d=64, kernel=3].""" - return (torch.randn(2, 64, 3), torch.randn(2, 64, 3), torch.randn(2, 64, 3)) - - -@pytest.fixture -def sample_recurrent(): - """Sample recurrent state: [batch=2, heads=4, head_dim=16, d_state=16].""" - return torch.randn(2, 4, 16, 16) - - -# ============================================================================= -# SECTION 1: CACHE CREATION & PROPERTIES -# ============================================================================= - - -class TestCacheCreation: - """Test cache initialization from config.""" - - def test_attention_cache_creation(self, tiny_attention_config): - """Create cache for pure 
attention config.""" - cache = Apriel2Cache(tiny_attention_config) - - assert len(cache) == 2 - assert cache.mixer_types == ["attention", "attention"] - assert all(isinstance(l, _AttentionCache) for l in cache.layers) - - def test_ssm_cache_creation(self, ssm_config): - """Create cache for pure SSM config.""" - cache = Apriel2Cache(ssm_config) - - assert len(cache) == 2 - assert cache.mixer_types == ["mamba", "mamba"] - assert all(isinstance(l, _SSMCache) for l in cache.layers) - - def test_kda_cache_creation(self, kda_config): - """Create cache for pure KDA config.""" - cache = Apriel2Cache(kda_config) - - assert len(cache) == 2 - assert cache.mixer_types == ["kda", "kda"] - assert all(isinstance(l, _SSMCache) for l in cache.layers) - - def test_stochastic_cache_creation(self, stochastic_config): - """Create cache for stochastic mixer config.""" - cache = Apriel2Cache(stochastic_config) - - assert len(cache) == 2 - # Layer 0: pure attention, Layer 1: stochastic (dict) - assert isinstance(cache.layers[0], _AttentionCache) - assert isinstance(cache.layers[1], dict) - assert set(cache.layers[1].keys()) == {"attention", "mamba"} - - def test_swa_window_captured(self, swa_config): - """Verify sliding window size is captured.""" - cache = Apriel2Cache(swa_config) - - assert cache.layers[0].window == 8 - assert cache.is_sliding == [True, True] - - def test_active_mixers_initialized_none(self, stochastic_config): - """Verify active_mixers starts as None for all layers.""" - cache = Apriel2Cache(stochastic_config) - - assert cache.active_mixers == [None, None] - - -class TestCacheProperties: - """Test cache property accessors.""" - - def test_empty_cache_properties(self, tiny_attention_config): - """Test properties of uninitialized cache.""" - cache = Apriel2Cache(tiny_attention_config) - - assert cache.is_initialized == False - assert cache.has_previous_state == False - assert cache.max_batch_size is None - assert cache.max_cache_len is None - assert cache.is_compileable == False - - def test_is_initialized_attention(self, tiny_attention_config, sample_kv): - """is_initialized detects attention cache.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - - assert cache.is_initialized == True - - def test_is_initialized_ssm(self, ssm_config, sample_conv_single): - """is_initialized detects SSM cache.""" - cache = Apriel2Cache(ssm_config) - cache.conv_states[0] = sample_conv_single - - assert cache.is_initialized == True - - def test_has_previous_state_ssm_only(self, ssm_config, sample_conv_single): - """has_previous_state only looks at SSM conv states.""" - cache = Apriel2Cache(ssm_config) - - assert cache.has_previous_state == False - cache.conv_states[0] = sample_conv_single - assert cache.has_previous_state == True - - def test_has_previous_state_ignores_attention(self, tiny_attention_config, sample_kv): - """has_previous_state ignores attention cache.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - - # Attention cache is set, but has_previous_state only checks SSM - assert cache.has_previous_state == False - - def test_max_batch_size_from_attention(self, tiny_attention_config, sample_kv): - """max_batch_size from attention cache.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - - assert cache.max_batch_size == 2 - - def test_max_batch_size_from_ssm(self, ssm_config, sample_conv_single): - """max_batch_size from SSM cache.""" - cache = Apriel2Cache(ssm_config) - 
cache.conv_states[0] = sample_conv_single - - assert cache.max_batch_size == 2 - - def test_max_batch_size_from_kda_tuple(self, kda_config, sample_conv_tuple): - """max_batch_size from KDA tuple conv state.""" - cache = Apriel2Cache(kda_config) - cache.conv_states[0] = sample_conv_tuple - - assert cache.max_batch_size == 2 - - def test_max_cache_len_single_window(self, swa_config): - """max_cache_len with single window size.""" - cache = Apriel2Cache(swa_config) - assert cache.max_cache_len == 8 - - def test_max_cache_len_multiple_windows(self, multi_window_config): - """max_cache_len returns minimum window.""" - cache = Apriel2Cache(multi_window_config) - assert cache.max_cache_len == 512 # min(512, 2048) - - def test_max_cache_len_no_windows(self, tiny_attention_config): - """max_cache_len is None when no windows.""" - cache = Apriel2Cache(tiny_attention_config) - assert cache.max_cache_len is None - - def test_is_sliding_mixed(self, multi_window_config): - """is_sliding reflects per-layer window presence.""" - cache = Apriel2Cache(multi_window_config) - assert cache.is_sliding == [False, True, True] - - -# ============================================================================= -# SECTION 2: ATTENTION CACHE OPERATIONS -# ============================================================================= - - -class TestAttentionCacheBasics: - """Test basic attention cache operations.""" - - def test_update_stores_kv(self, tiny_attention_config, sample_kv): - """update() stores key/value states.""" - cache = Apriel2Cache(tiny_attention_config) - key, value = sample_kv - - k_out, v_out = cache.update(key, value, layer_idx=0) - - torch.testing.assert_close(k_out, key) - torch.testing.assert_close(v_out, value) - assert cache.get_seq_length(0) == 10 - - def test_update_concatenates(self, tiny_attention_config, sample_kv): - """Subsequent updates concatenate.""" - cache = Apriel2Cache(tiny_attention_config) - key, value = sample_kv - - cache.update(key, value, layer_idx=0) - k_out, v_out = cache.update(key, value, layer_idx=0) - - assert k_out.shape[-2] == 20 - assert cache.get_seq_length(0) == 20 - - def test_key_value_cache_accessors(self, tiny_attention_config, sample_kv): - """Test key_cache and value_cache accessors.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - - assert cache.key_cache[0] is not None - assert cache.value_cache[0] is not None - torch.testing.assert_close(cache.key_cache[0], sample_kv[0]) - - -class TestSlidingWindowAttention: - """Test sliding window attention behavior.""" - - def test_initial_within_window(self, swa_config): - """Initial sequence within window is kept.""" - cache = Apriel2Cache(swa_config) - key = torch.randn(2, 4, 5, 16) # seq=5 < window=8 - value = torch.randn(2, 4, 5, 16) - - cache.update(key, value, layer_idx=0) - - assert cache.get_seq_length(0) == 5 - - def test_initial_exceeds_window(self, swa_config): - """Initial sequence > window is truncated to last window tokens.""" - cache = Apriel2Cache(swa_config) - key = torch.arange(12).float().view(1, 1, 12, 1).expand(2, 4, 12, 16) - value = key.clone() - - k_out, v_out = cache.update(key, value, layer_idx=0) - - assert cache.get_seq_length(0) == 8 - # Should keep tokens 4-11 (last 8) - assert k_out[0, 0, 0, 0].item() == 4.0 - - def test_single_token_roll_path(self, swa_config): - """Single token decode with full window uses efficient roll.""" - cache = Apriel2Cache(swa_config) - - # Fill window exactly - key1 = torch.arange(8).float().view(1, 1, 8, 
1).expand(2, 4, 8, 16) - cache.update(key1, key1.clone(), layer_idx=0) - - # Decode single token - key2 = torch.full((2, 4, 1, 16), 8.0) - k_out, _ = cache.update(key2, key2.clone(), layer_idx=0) - - assert cache.get_seq_length(0) == 8 - assert k_out[0, 0, 0, 0].item() == 1.0 # Token 0 rolled out - assert k_out[0, 0, 7, 0].item() == 8.0 # New token at end - - def test_multi_token_cat_slice_path(self, swa_config): - """Multiple tokens use cat+slice path.""" - cache = Apriel2Cache(swa_config) - - # Fill window - key1 = torch.randn(2, 4, 8, 16) - cache.update(key1, key1.clone(), layer_idx=0) - - # Add 3 tokens - key2 = torch.randn(2, 4, 3, 16) - k_out, _ = cache.update(key2, key2.clone(), layer_idx=0) - - assert cache.get_seq_length(0) == 8 - torch.testing.assert_close(k_out[..., -3:, :], key2) - - def test_partial_then_fill_then_overflow(self, swa_config): - """Progressive filling: partial → full → overflow.""" - cache = Apriel2Cache(swa_config) - - cache.update(torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16), layer_idx=0) - assert cache.get_seq_length(0) == 5 - - cache.update(torch.randn(2, 4, 3, 16), torch.randn(2, 4, 3, 16), layer_idx=0) - assert cache.get_seq_length(0) == 8 - - cache.update(torch.randn(2, 4, 2, 16), torch.randn(2, 4, 2, 16), layer_idx=0) - assert cache.get_seq_length(0) == 8 - - def test_contiguous_output(self, swa_config): - """Outputs are contiguous after windowing.""" - cache = Apriel2Cache(swa_config) - - cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0) - cache.update(torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16), layer_idx=0) - - assert cache.layers[0].key.is_contiguous() - assert cache.layers[0].value.is_contiguous() - - -# ============================================================================= -# SECTION 3: SSM CACHE OPERATIONS -# ============================================================================= - - -class TestSSMCacheBasics: - """Test basic SSM cache operations.""" - - def test_conv_states_accessor(self, ssm_config, sample_conv_single): - """Test conv_states accessor.""" - cache = Apriel2Cache(ssm_config) - - cache.conv_states[0] = sample_conv_single - torch.testing.assert_close(cache.conv_states[0], sample_conv_single) - - def test_recurrent_states_accessor(self, ssm_config, sample_recurrent): - """Test recurrent_states accessor.""" - cache = Apriel2Cache(ssm_config) - - cache.recurrent_states[0] = sample_recurrent - torch.testing.assert_close(cache.recurrent_states[0], sample_recurrent) - - def test_ssm_seq_length_always_zero(self, ssm_config, sample_conv_single): - """get_seq_length returns 0 for SSM (no KV cache).""" - cache = Apriel2Cache(ssm_config) - cache.conv_states[0] = sample_conv_single - - assert cache.get_seq_length(0) == 0 - - -class TestKDACache: - """Test KDA-specific cache operations with tuple conv states.""" - - def test_tuple_conv_storage(self, kda_config, sample_conv_tuple): - """KDA stores tuple conv states.""" - cache = Apriel2Cache(kda_config) - - cache.conv_states[0] = sample_conv_tuple - - assert isinstance(cache.conv_states[0], tuple) - assert len(cache.conv_states[0]) == 3 - for i in range(3): - torch.testing.assert_close(cache.conv_states[0][i], sample_conv_tuple[i]) - - def test_tuple_with_recurrent(self, kda_config, sample_conv_tuple, sample_recurrent): - """KDA can have both tuple conv and recurrent states.""" - cache = Apriel2Cache(kda_config) - - cache.conv_states[0] = sample_conv_tuple - cache.recurrent_states[0] = sample_recurrent - - assert 
isinstance(cache.conv_states[0], tuple) - assert cache.recurrent_states[0] is not None - - def test_has_previous_state_detects_tuple(self, kda_config, sample_conv_tuple): - """has_previous_state works with tuple conv states.""" - cache = Apriel2Cache(kda_config) - - assert cache.has_previous_state == False - cache.conv_states[0] = sample_conv_tuple - assert cache.has_previous_state == True - - -# ============================================================================= -# SECTION 4: STOCHASTIC ROUTING -# ============================================================================= - - -class TestStochasticRouting: - """Test stochastic mixer cache routing.""" - - def test_set_active_mixer(self, stochastic_config): - """set_active_mixer sets the pointer.""" - cache = Apriel2Cache(stochastic_config) - - cache.set_active_mixer(1, "attention") - assert cache.active_mixers[1] == "attention" - - cache.set_active_mixer(1, "mamba") - assert cache.active_mixers[1] == "mamba" - - def test_operations_route_to_active(self, stochastic_config, sample_kv): - """Operations route to currently active mixer.""" - cache = Apriel2Cache(stochastic_config) - - cache.set_active_mixer(1, "attention") - cache.update(*sample_kv, layer_idx=1) - attn_len = cache.get_seq_length(1) - - cache.set_active_mixer(1, "mamba") - mamba_len = cache.get_seq_length(1) - - assert attn_len == 10 - assert mamba_len == 0 # Mamba cache is separate and empty - - def test_each_mixer_independent_cache(self, stochastic_config, sample_kv, sample_conv_single): - """Each mixer maintains independent cache.""" - cache = Apriel2Cache(stochastic_config) - - # Fill attention cache - cache.set_active_mixer(1, "attention") - cache.update(*sample_kv, layer_idx=1) - - # Fill mamba cache - cache.set_active_mixer(1, "mamba") - cache.conv_states[1] = sample_conv_single - - # Both preserved - cache.set_active_mixer(1, "attention") - assert cache.get_seq_length(1) == 10 - - cache.set_active_mixer(1, "mamba") - torch.testing.assert_close(cache.conv_states[1], sample_conv_single) - - -class TestMixerSwitching: - """Test behavior when switching between mixers mid-generation.""" - - def test_switch_preserves_previous_state(self, stochastic_config, sample_kv): - """Switching mixers preserves previous mixer's state.""" - cache = Apriel2Cache(stochastic_config) - - cache.set_active_mixer(1, "attention") - cache.update(*sample_kv, layer_idx=1) - original_key = cache.layers[1]["attention"].key.clone() - - # Switch to mamba, do something - cache.set_active_mixer(1, "mamba") - cache.conv_states[1] = torch.randn(2, 128, 4) - - # Switch back - attention unchanged - cache.set_active_mixer(1, "attention") - torch.testing.assert_close(cache.layers[1]["attention"].key, original_key) - - def test_switch_does_not_copy_state(self, stochastic_config, sample_kv): - """Switching does NOT copy state between mixers.""" - cache = Apriel2Cache(stochastic_config) - - # Fill attention with 10 tokens - cache.set_active_mixer(1, "attention") - cache.update(*sample_kv, layer_idx=1) - - # Switch to mamba - it has NO history from attention - cache.set_active_mixer(1, "mamba") - assert cache.conv_states[1] is None - assert cache.recurrent_states[1] is None - - def test_has_previous_state_checks_all_sub_caches(self, stochastic_config): - """has_previous_state checks ALL sub-caches, not just active.""" - cache = Apriel2Cache(stochastic_config) - - cache.set_active_mixer(1, "mamba") - cache.conv_states[1] = torch.randn(2, 128, 4) - - # Even if we switch away, has_previous_state still 
detects it - cache.set_active_mixer(1, "attention") - assert cache.has_previous_state == True - - -class TestAllMixerTypes: - """Test cache isolation across all 5 mixer types.""" - - def test_all_five_mixer_types_isolated(self, all_mixers_config): - """All 5 mixer types maintain isolated caches.""" - cache = Apriel2Cache(all_mixers_config) - layer_idx = 1 # Stochastic layer - - # Fill each mixer's cache - cache.set_active_mixer(layer_idx, "attention") - attn_kv = (torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16)) - cache.update(*attn_kv, layer_idx=layer_idx) - - cache.set_active_mixer(layer_idx, "swa") - swa_kv = (torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16)) - cache.update(*swa_kv, layer_idx=layer_idx) - - cache.set_active_mixer(layer_idx, "mamba") - mamba_conv = torch.randn(2, 128, 4) - cache.conv_states[layer_idx] = mamba_conv - - cache.set_active_mixer(layer_idx, "gdn") - gdn_conv = torch.randn(2, 64, 3) - cache.conv_states[layer_idx] = gdn_conv - - cache.set_active_mixer(layer_idx, "kda") - kda_conv = (torch.randn(2, 64, 3), torch.randn(2, 64, 3), torch.randn(2, 64, 3)) - cache.conv_states[layer_idx] = kda_conv - - # Verify all preserved - cache.set_active_mixer(layer_idx, "attention") - assert cache.get_seq_length(layer_idx) == 10 - - cache.set_active_mixer(layer_idx, "swa") - assert cache.get_seq_length(layer_idx) == 5 - - cache.set_active_mixer(layer_idx, "mamba") - torch.testing.assert_close(cache.conv_states[layer_idx], mamba_conv) - - cache.set_active_mixer(layer_idx, "gdn") - torch.testing.assert_close(cache.conv_states[layer_idx], gdn_conv) - - cache.set_active_mixer(layer_idx, "kda") - assert isinstance(cache.conv_states[layer_idx], tuple) - - -# ============================================================================= -# SECTION 5: CACHE INVALIDATION -# ============================================================================= - - -class TestCacheInvalidation: - """Test cache invalidation and reset semantics. - - Key principle: Each mixer maintains independent state. 
To invalidate: - - reset() clears ALL caches across ALL layers and mixers - - There is no per-mixer reset (by design - each mixer is independent) - """ - - def test_reset_clears_attention(self, tiny_attention_config, sample_kv): - """reset() clears attention cache.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - - cache.reset() - - assert cache.is_initialized == False - assert cache.get_seq_length(0) == 0 - - def test_reset_clears_ssm(self, ssm_config, sample_conv_single, sample_recurrent): - """reset() clears SSM cache.""" - cache = Apriel2Cache(ssm_config) - cache.conv_states[0] = sample_conv_single - cache.recurrent_states[0] = sample_recurrent - - cache.reset() - - assert cache.has_previous_state == False - assert cache.conv_states[0] is None - assert cache.recurrent_states[0] is None - - def test_reset_clears_kda_tuple(self, kda_config, sample_conv_tuple): - """reset() clears KDA tuple conv states.""" - cache = Apriel2Cache(kda_config) - cache.conv_states[0] = sample_conv_tuple - - cache.reset() - - assert cache.conv_states[0] is None - - def test_reset_clears_all_stochastic_mixers(self, all_mixers_config): - """reset() clears ALL mixer caches in stochastic layer.""" - cache = Apriel2Cache(all_mixers_config) - layer_idx = 1 - - # Fill all mixers - cache.set_active_mixer(layer_idx, "attention") - cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=layer_idx) - - cache.set_active_mixer(layer_idx, "mamba") - cache.conv_states[layer_idx] = torch.randn(2, 128, 4) - - cache.set_active_mixer(layer_idx, "kda") - cache.conv_states[layer_idx] = (torch.randn(2, 64, 3),) * 3 - - cache.reset() - - # All cleared - assert cache.layers[layer_idx]["attention"].key is None - assert cache.layers[layer_idx]["mamba"].conv is None - assert cache.layers[layer_idx]["kda"].conv is None - - def test_crop_truncates_attention(self, tiny_attention_config, sample_kv): - """crop() truncates attention cache to max_length.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - - cache.crop(5) - - assert cache.get_seq_length(0) == 5 - - def test_crop_affects_all_layers(self, tiny_attention_config, sample_kv): - """crop() affects all layers.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - cache.update(*sample_kv, layer_idx=1) - - cache.crop(3) - - assert cache.get_seq_length(0) == 3 - assert cache.get_seq_length(1) == 3 - - def test_crop_ignores_ssm(self, ssm_config, sample_conv_single): - """crop() only affects attention, not SSM.""" - cache = Apriel2Cache(ssm_config) - cache.conv_states[0] = sample_conv_single - - cache.crop(5) # Should not crash - - # Conv state unchanged - torch.testing.assert_close(cache.conv_states[0], sample_conv_single) - - -# ============================================================================= -# SECTION 6: BEAM SEARCH OPERATIONS -# ============================================================================= - - -class TestBatchRepeatInterleave: - """Test batch_repeat_interleave for beam search expansion.""" - - def test_repeat_attention(self, tiny_attention_config, sample_kv): - """Repeat attention cache for beam search.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - - cache.batch_repeat_interleave(3) - - assert cache.max_batch_size == 6 # 2 * 3 - - def test_repeat_ssm(self, ssm_config, sample_conv_single, sample_recurrent): - """Repeat SSM cache for beam search.""" - cache = 
Apriel2Cache(ssm_config) - cache.conv_states[0] = sample_conv_single - cache.recurrent_states[0] = sample_recurrent - - cache.batch_repeat_interleave(4) - - assert cache.conv_states[0].shape[0] == 8 # 2 * 4 - assert cache.recurrent_states[0].shape[0] == 8 - - def test_repeat_kda_tuple(self, kda_config, sample_conv_tuple): - """Repeat KDA tuple conv states.""" - cache = Apriel2Cache(kda_config) - cache.conv_states[0] = sample_conv_tuple - - cache.batch_repeat_interleave(3) - - for c in cache.conv_states[0]: - assert c.shape[0] == 6 - - def test_repeat_stochastic_all_mixers(self, all_mixers_config): - """Repeat all mixer caches in stochastic layer.""" - cache = Apriel2Cache(all_mixers_config) - layer_idx = 1 - - cache.set_active_mixer(layer_idx, "attention") - cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=layer_idx) - - cache.set_active_mixer(layer_idx, "mamba") - cache.conv_states[layer_idx] = torch.randn(2, 128, 4) - - cache.batch_repeat_interleave(2) - - cache.set_active_mixer(layer_idx, "attention") - assert cache.layers[layer_idx]["attention"].key.shape[0] == 4 - - cache.set_active_mixer(layer_idx, "mamba") - assert cache.conv_states[layer_idx].shape[0] == 4 - - def test_repeat_skips_none(self, tiny_attention_config): - """Repeat gracefully skips None caches.""" - cache = Apriel2Cache(tiny_attention_config) - # Don't fill anything - - cache.batch_repeat_interleave(3) # Should not crash - - assert cache.max_batch_size is None - - -class TestReorderCache: - """Test reorder_cache for beam search hypothesis selection.""" - - def test_reorder_attention(self, tiny_attention_config, sample_kv): - """Reorder attention cache.""" - cache = Apriel2Cache(tiny_attention_config) - key, value = sample_kv - # Make batches distinguishable - key = torch.arange(2).float().view(2, 1, 1, 1).expand(2, 4, 10, 16) - cache.update(key, key.clone(), layer_idx=0) - - beam_idx = torch.tensor([1, 0]) - cache.reorder_cache(beam_idx) - - assert cache.layers[0].key[0, 0, 0, 0].item() == 1.0 - assert cache.layers[0].key[1, 0, 0, 0].item() == 0.0 - - def test_reorder_ssm(self, ssm_config): - """Reorder SSM cache.""" - cache = Apriel2Cache(ssm_config) - conv = torch.arange(2).float().view(2, 1, 1).expand(2, 128, 4) - cache.conv_states[0] = conv.clone() - - beam_idx = torch.tensor([1, 0]) - cache.reorder_cache(beam_idx) - - assert cache.conv_states[0][0, 0, 0].item() == 1.0 - - def test_reorder_kda_tuple(self, kda_config): - """Reorder KDA tuple conv states.""" - cache = Apriel2Cache(kda_config) - conv_q = torch.arange(2).float().view(2, 1, 1).expand(2, 64, 3) - cache.conv_states[0] = (conv_q.clone(), conv_q.clone(), conv_q.clone()) - - beam_idx = torch.tensor([1, 0]) - cache.reorder_cache(beam_idx) - - for c in cache.conv_states[0]: - assert c[0, 0, 0].item() == 1.0 - - -class TestBatchSelectIndices: - """Test batch_select_indices for beam selection.""" - - def test_select_attention(self, tiny_attention_config, sample_kv): - """Select subset of attention cache.""" - cache = Apriel2Cache(tiny_attention_config) - key = torch.arange(4).float().view(4, 1, 1, 1).expand(4, 4, 10, 16) - cache.update(key, key.clone(), layer_idx=0) - - indices = torch.tensor([0, 3]) - cache.batch_select_indices(indices) - - assert cache.max_batch_size == 2 - assert cache.layers[0].key[0, 0, 0, 0].item() == 0.0 - assert cache.layers[0].key[1, 0, 0, 0].item() == 3.0 - - def test_select_kda_tuple(self, kda_config): - """Select subset of KDA tuple conv states.""" - cache = Apriel2Cache(kda_config) - conv = 
tuple(torch.arange(4).float().view(4, 1, 1).expand(4, 64, 3).clone() for _ in range(3)) - cache.conv_states[0] = conv - - indices = torch.tensor([1, 2]) - cache.batch_select_indices(indices) - - for c in cache.conv_states[0]: - assert c.shape[0] == 2 - assert c[0, 0, 0].item() == 1.0 - - -# ============================================================================= -# SECTION 7: HUGGINGFACE INTEGRATION -# ============================================================================= - - -class TestGetMaskSizes: - """Test get_mask_sizes() for attention mask computation.""" - - def test_empty_cache(self, tiny_attention_config): - """Mask sizes with empty cache.""" - cache = Apriel2Cache(tiny_attention_config) - cache_position = torch.arange(10) - - kv_length, kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0) - - assert kv_length == 10 - assert kv_offset == 0 - - def test_with_cached_tokens(self, tiny_attention_config, sample_kv): - """Mask sizes with cached tokens.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) # 10 tokens - - cache_position = torch.arange(5) - kv_length, kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0) - - assert kv_length == 15 # 10 + 5 - assert kv_offset == 10 - - def test_single_token_decode(self, tiny_attention_config, sample_kv): - """Mask sizes for single token decode.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - - cache_position = torch.arange(1) - kv_length, kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0) - - assert kv_length == 11 - assert kv_offset == 10 - - def test_ssm_returns_query_only(self, ssm_config, sample_conv_single): - """SSM layers return query_length (no KV cache).""" - cache = Apriel2Cache(ssm_config) - cache.conv_states[0] = sample_conv_single - - cache_position = torch.arange(5) - kv_length, kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0) - - assert kv_length == 5 - assert kv_offset == 0 - - -class TestCacheIndexing: - """Test cache[idx] indexing.""" - - def test_attention_returns_kv(self, tiny_attention_config, sample_kv): - """Indexing attention layer returns (key, value).""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - - result = cache[0] - - assert isinstance(result, tuple) - torch.testing.assert_close(result[0], sample_kv[0]) - - def test_empty_returns_empty_tensors(self, tiny_attention_config): - """Indexing empty layer returns empty tensors.""" - cache = Apriel2Cache(tiny_attention_config) - - result = cache[0] - - assert result[0].numel() == 0 - assert result[1].numel() == 0 - - def test_ssm_returns_empty(self, ssm_config, sample_conv_single): - """Indexing SSM layer returns empty (no KV).""" - cache = Apriel2Cache(ssm_config) - cache.conv_states[0] = sample_conv_single - - result = cache[0] - - assert result[0].numel() == 0 - - def test_stochastic_attention_returns_kv(self, stochastic_config, sample_kv): - """Indexing stochastic with attention active returns KV.""" - cache = Apriel2Cache(stochastic_config) - cache.set_active_mixer(1, "attention") - cache.update(*sample_kv, layer_idx=1) - - result = cache[1] - - torch.testing.assert_close(result[0], sample_kv[0]) - - -# ============================================================================= -# SECTION 8: GENERATION PATTERNS -# ============================================================================= - - -class TestGenerationPatterns: - """Test real-world generation patterns.""" - - def 
test_prefill_then_decode(self, tiny_attention_config, sample_kv): - """Prefill with long prompt, then decode token-by-token.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) # Prefill 10 tokens - - for _ in range(5): - new_kv = (torch.randn(2, 4, 1, 16), torch.randn(2, 4, 1, 16)) - cache.update(*new_kv, layer_idx=0) - - assert cache.get_seq_length(0) == 15 - - def test_crop_then_continue(self, tiny_attention_config, sample_kv): - """Crop old context, continue generation.""" - cache = Apriel2Cache(tiny_attention_config) - cache.update(*sample_kv, layer_idx=0) - cache.update(*sample_kv, layer_idx=0) # 20 tokens - - cache.crop(5) # Keep last 5 - cache.update(torch.randn(2, 4, 3, 16), torch.randn(2, 4, 3, 16), layer_idx=0) - - assert cache.get_seq_length(0) == 8 - - def test_reset_between_generations(self, tiny_attention_config, sample_kv): - """Reset between independent generations.""" - cache = Apriel2Cache(tiny_attention_config) - - # First generation - cache.update(*sample_kv, layer_idx=0) - assert cache.is_initialized == True - - # Reset - cache.reset() - assert cache.is_initialized == False - - # Second generation - cache.update(*sample_kv, layer_idx=0) - assert cache.get_seq_length(0) == 10 - - def test_multi_layer_consistency(self, tiny_attention_config, sample_kv): - """All layers updated consistently.""" - cache = Apriel2Cache(tiny_attention_config) - - for layer_idx in range(2): - cache.update(*sample_kv, layer_idx=layer_idx) - cache.update(torch.randn(2, 4, 1, 16), torch.randn(2, 4, 1, 16), layer_idx=layer_idx) - - for layer_idx in range(2): - assert cache.get_seq_length(layer_idx) == 11 - - -# ============================================================================= -# SECTION 9: ERROR HANDLING -# ============================================================================= - - -class TestErrorHandling: - """Test error conditions and guards.""" - - def test_stochastic_update_without_active_mixer(self, stochastic_config): - """update() on stochastic without active_mixer raises.""" - cache = Apriel2Cache(stochastic_config) - - with pytest.raises(RuntimeError, match="needs active_mixer set"): - cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=1) - - def test_stochastic_accessor_without_active_mixer(self, stochastic_config): - """Accessing stochastic cache without active_mixer raises.""" - cache = Apriel2Cache(stochastic_config) - - with pytest.raises(RuntimeError, match="requires set_active_mixer"): - _ = cache.conv_states[1] - - def test_accessor_error_lists_available_mixers(self, stochastic_config): - """Error message lists available mixers.""" - cache = Apriel2Cache(stochastic_config) - - with pytest.raises(RuntimeError, match="Available mixers:"): - _ = cache.key_cache[1] - - def test_invalid_mixer_name(self, stochastic_config): - """Invalid mixer name raises KeyError on access.""" - cache = Apriel2Cache(stochastic_config) - cache.set_active_mixer(1, "nonexistent") - - with pytest.raises(KeyError): - cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=1) - - def test_layer_idx_out_of_bounds(self, tiny_attention_config): - """Out-of-bounds layer_idx raises IndexError.""" - cache = Apriel2Cache(tiny_attention_config) - - with pytest.raises(IndexError): - cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=999) - - -# ============================================================================= -# SECTION 10: INTERNAL CLASSES -# 
============================================================================= - - -class TestAttentionCacheInternal: - """Test internal _AttentionCache class directly.""" - - def test_unbounded_growth(self): - """No window allows unbounded growth.""" - cache = _AttentionCache(window=None) - - for _ in range(10): - cache.update(torch.randn(2, 4, 100, 16), torch.randn(2, 4, 100, 16)) - - assert cache.key.shape[-2] == 1000 - - def test_window_enforced(self): - """Window caps cache size.""" - cache = _AttentionCache(window=50) - - for _ in range(10): - cache.update(torch.randn(2, 4, 100, 16), torch.randn(2, 4, 100, 16)) - - assert cache.key.shape[-2] == 50 - - -class TestSSMCacheInternal: - """Test internal _SSMCache class directly.""" - - def test_initial_none(self): - """Initial states are None.""" - cache = _SSMCache() - - assert cache.conv is None - assert cache.recurrent is None - - def test_stores_tuple(self): - """Can store tuple (for KDA).""" - cache = _SSMCache() - cache.conv = (torch.randn(2, 64, 3),) * 3 - - assert isinstance(cache.conv, tuple) diff --git a/fast_llm_external_models/tests/test_apriel2/test_cache_apriel2_specific.py b/fast_llm_external_models/tests/test_apriel2/test_cache_apriel2_specific.py new file mode 100644 index 000000000..b45779454 --- /dev/null +++ b/fast_llm_external_models/tests/test_apriel2/test_cache_apriel2_specific.py @@ -0,0 +1,341 @@ +"""Tests for Apriel2-specific cache behaviors with no HuggingFace equivalent. + +This module tests features unique to Apriel2Cache that cannot be validated +against upstream HF implementations: + +1. Stochastic mixer routing (switching between attention/SSM per layer) +2. Multi-mixer layer support +3. Error handling and guard rails +4. Beam search operations (batch_repeat, reorder, select) +5. 
Crop operation + +Fixtures used from conftest.py: + - stochastic_config: Stochastic mixer config with attention and mamba + - attention_config: Pure attention config + - ssm_config: Pure SSM config +""" + +import pytest +import torch + +from fast_llm_external_models.apriel2.cache import Apriel2Cache + +# ============================================================================= +# STOCHASTIC MIXER ROUTING +# ============================================================================= + + +class TestStochasticMixerRouting: + """Test routing operations to correct sub-cache in stochastic layers.""" + + def test_set_active_mixer(self, stochastic_config): + """set_active_mixer updates routing for layer.""" + cache = Apriel2Cache(stochastic_config) + + cache.set_active_mixer(0, "attention") + assert cache.active_mixers[0] == "attention" + + cache.set_active_mixer(0, "mamba") + assert cache.active_mixers[0] == "mamba" + + def test_update_routes_to_active_mixer(self, stochastic_config): + """update() stores in correct sub-cache based on active_mixer.""" + cache = Apriel2Cache(stochastic_config) + + # Route to attention + cache.set_active_mixer(0, "attention") + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0) + + # Attention sub-cache should have data + assert cache.layers[0]["attention"].key is not None + # Mamba sub-cache should be empty + assert cache.layers[0]["mamba"].conv is None + + def test_each_mixer_has_independent_cache(self, stochastic_config): + """Each mixer in a stochastic layer has its own independent state.""" + cache = Apriel2Cache(stochastic_config) + + # Store in attention + cache.set_active_mixer(0, "attention") + cache.update(torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16), layer_idx=0) + + # Switch to mamba and store + cache.set_active_mixer(0, "mamba") + cache.layers[0]["mamba"].conv = torch.randn(2, 64, 4) + + # Attention data should be unchanged + assert cache.layers[0]["attention"].cumulative_length == 5 + + def test_switching_preserves_all_states(self, stochastic_config): + """Switching active_mixer doesn't clear other mixer's state.""" + cache = Apriel2Cache(stochastic_config) + + # Build up attention state + cache.set_active_mixer(0, "attention") + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0) + attn_key = cache.layers[0]["attention"].key.clone() + + # Switch to mamba + cache.set_active_mixer(0, "mamba") + + # Attention state preserved + torch.testing.assert_close(cache.layers[0]["attention"].key, attn_key) + + +# ============================================================================= +# ERROR HANDLING +# ============================================================================= + + +class TestErrorHandling: + """Test guard rails and error messages.""" + + def test_update_without_active_mixer_raises(self, stochastic_config): + """update() on stochastic layer without active_mixer raises RuntimeError.""" + cache = Apriel2Cache(stochastic_config) + + with pytest.raises(RuntimeError, match="needs active_mixer set"): + cache.update(torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16), layer_idx=0) + + def test_accessor_without_active_mixer_raises(self, stochastic_config): + """Accessing key_cache/value_cache without active_mixer raises RuntimeError.""" + cache = Apriel2Cache(stochastic_config) + + with pytest.raises(RuntimeError, match="requires set_active_mixer"): + _ = cache.key_cache[0] + + def test_error_message_lists_available_mixers(self, stochastic_config): + """Error message includes 
list of available mixers.""" + cache = Apriel2Cache(stochastic_config) + + with pytest.raises(RuntimeError, match="attention.*mamba|mamba.*attention"): + _ = cache.key_cache[0] + + +# ============================================================================= +# BEAM SEARCH OPERATIONS +# ============================================================================= + + +class TestBeamSearchOperations: + """Test batch manipulation for beam search.""" + + def test_batch_repeat_interleave_attention(self, attention_config): + """batch_repeat_interleave expands batch dimension.""" + cache = Apriel2Cache(attention_config) + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0) + + cache.batch_repeat_interleave(3) + + assert cache.layers[0].key.shape[0] == 6 # 2 * 3 + + def test_batch_repeat_interleave_ssm(self, ssm_config): + """batch_repeat_interleave works for SSM caches.""" + cache = Apriel2Cache(ssm_config) + cache.layers[0].conv = torch.randn(2, 64, 4) + + cache.batch_repeat_interleave(3) + + assert cache.layers[0].conv.shape[0] == 6 + + def test_batch_repeat_interleave_kda_tuple(self, ssm_config): + """batch_repeat_interleave handles KDA tuple conv states.""" + cache = Apriel2Cache(ssm_config) + cache.layers[0].conv = (torch.randn(2, 64, 4),) * 3 + + cache.batch_repeat_interleave(3) + + assert cache.layers[0].conv[0].shape[0] == 6 + + def test_reorder_cache_attention(self, attention_config): + """reorder_cache reorders batch dimension.""" + cache = Apriel2Cache(attention_config) + k = torch.arange(4).float().view(4, 1, 1, 1).expand(4, 4, 10, 16) + cache.update(k, k.clone(), layer_idx=0) + + beam_idx = torch.tensor([3, 2, 1, 0]) + cache.reorder_cache(beam_idx) + + # Check reordering + assert cache.layers[0].key[0, 0, 0, 0].item() == 3.0 + assert cache.layers[0].key[3, 0, 0, 0].item() == 0.0 + + def test_batch_select_indices(self, attention_config): + """batch_select_indices selects subset of batch.""" + cache = Apriel2Cache(attention_config) + cache.update(torch.randn(4, 4, 10, 16), torch.randn(4, 4, 10, 16), layer_idx=0) + + indices = torch.tensor([0, 2]) + cache.batch_select_indices(indices) + + assert cache.layers[0].key.shape[0] == 2 + + def test_reorder_cache_ssm_tuple(self, ssm_config): + """reorder_cache handles KDA tuple conv states.""" + cache = Apriel2Cache(ssm_config) + # Create distinguishable tensors for each batch position + conv0 = torch.full((1, 64, 4), 0.0) + conv1 = torch.full((1, 64, 4), 1.0) + conv2 = torch.full((1, 64, 4), 2.0) + cache.layers[0].conv = ( + torch.cat([conv0, conv1, conv2], dim=0), + torch.cat([conv0, conv1, conv2], dim=0), + torch.cat([conv0, conv1, conv2], dim=0), + ) + + beam_idx = torch.tensor([2, 1, 0]) + cache.reorder_cache(beam_idx) + + # Check reordering: batch[0] should now have value 2.0 + assert cache.layers[0].conv[0][0, 0, 0].item() == 2.0 + assert cache.layers[0].conv[0][2, 0, 0].item() == 0.0 + + def test_batch_select_indices_ssm_tuple(self, ssm_config): + """batch_select_indices handles KDA tuple conv states.""" + cache = Apriel2Cache(ssm_config) + cache.layers[0].conv = (torch.randn(4, 64, 4),) * 3 + + indices = torch.tensor([0, 2]) + cache.batch_select_indices(indices) + + assert cache.layers[0].conv[0].shape[0] == 2 + assert cache.layers[0].conv[1].shape[0] == 2 + assert cache.layers[0].conv[2].shape[0] == 2 + + +# ============================================================================= +# CROP OPERATION +# ============================================================================= + + +class 
TestCropOperation: + """Test cache truncation.""" + + def test_crop_truncates_attention(self, attention_config): + """crop() truncates attention cache.""" + cache = Apriel2Cache(attention_config) + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0) + + cache.crop(5) + + assert cache.layers[0].key.shape[-2] == 5 + assert cache.get_seq_length(0) == 5 + + def test_crop_affects_all_layers(self, attention_config): + """crop() affects all layers.""" + cache = Apriel2Cache(attention_config) + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0) + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=1) + + cache.crop(3) + + assert cache.layers[0].key.shape[-2] == 3 + assert cache.layers[1].key.shape[-2] == 3 + + def test_crop_ignores_ssm(self, ssm_config): + """crop() doesn't affect SSM caches (they don't have seq dimension).""" + cache = Apriel2Cache(ssm_config) + cache.layers[0].conv = torch.randn(2, 64, 4) + + # Should not raise + cache.crop(5) + + # SSM state unchanged + assert cache.layers[0].conv.shape == (2, 64, 4) + + +# ============================================================================= +# CACHE PROPERTIES +# ============================================================================= + + +class TestCacheProperties: + """Test cache property methods.""" + + def test_is_initialized_attention(self, attention_config): + """is_initialized True after update.""" + cache = Apriel2Cache(attention_config) + assert not cache.is_initialized + + cache.update(torch.randn(2, 4, 5, 16), torch.randn(2, 4, 5, 16), layer_idx=0) + assert cache.is_initialized + + def test_is_initialized_ssm(self, ssm_config): + """is_initialized True after setting conv state.""" + cache = Apriel2Cache(ssm_config) + assert not cache.is_initialized + + cache.layers[0].conv = torch.randn(2, 64, 4) + assert cache.is_initialized + + def test_has_previous_state_ssm_only(self, ssm_config): + """has_previous_state checks SSM conv states.""" + cache = Apriel2Cache(ssm_config) + assert not cache.has_previous_state + + cache.layers[0].conv = torch.randn(2, 64, 4) + assert cache.has_previous_state + + def test_has_previous_state_ignores_attention(self, attention_config): + """has_previous_state ignores attention caches.""" + cache = Apriel2Cache(attention_config) + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0) + + # Attention-only cache returns False for has_previous_state + assert not cache.has_previous_state + + def test_reset_clears_ssm_states(self, ssm_config): + """reset() clears SSM conv and recurrent states.""" + cache = Apriel2Cache(ssm_config) + cache.layers[0].conv = torch.randn(2, 64, 4) + cache.layers[0].recurrent = torch.randn(2, 64, 16) + + cache.reset() + + assert cache.layers[0].conv is None + assert cache.layers[0].recurrent is None + + def test_max_batch_size_from_ssm_tuple(self, ssm_config): + """max_batch_size works with KDA tuple conv states.""" + cache = Apriel2Cache(ssm_config) + cache.layers[0].conv = (torch.randn(3, 64, 4),) * 3 + + assert cache.max_batch_size == 3 + + def test_max_batch_size(self, attention_config): + """max_batch_size returns batch dimension.""" + cache = Apriel2Cache(attention_config) + cache.update(torch.randn(3, 4, 10, 16), torch.randn(3, 4, 10, 16), layer_idx=0) + + assert cache.max_batch_size == 3 + + def test_len_returns_num_layers(self, attention_config): + """__len__ returns number of layers.""" + cache = Apriel2Cache(attention_config) + assert len(cache) == 2 + + 
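+# The indexing tests below rely on the HF-style convention that cache[idx] yields
+# a (key, value) tuple for layer idx. A minimal sketch of that access pattern,
+# using only methods exercised in this file:
+#
+#     cache = Apriel2Cache(attention_config)
+#     cache.update(key, value, layer_idx=0)
+#     k, v = cache[0]   # HF-style tuple access to the same tensors
+
+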
+# ============================================================================= +# INDEXING +# ============================================================================= + + +class TestCacheIndexing: + """Test __getitem__ for HF compatibility.""" + + def test_getitem_returns_kv_tuple(self, attention_config): + """cache[idx] returns (key, value) tuple.""" + cache = Apriel2Cache(attention_config) + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0) + + k, v = cache[0] + assert k.shape == (2, 4, 10, 16) + assert v.shape == (2, 4, 10, 16) + + def test_getitem_empty_returns_empty_tensors(self, attention_config): + """cache[idx] on empty cache returns empty tensors.""" + cache = Apriel2Cache(attention_config) + + k, v = cache[0] + assert k.numel() == 0 + assert v.numel() == 0 diff --git a/fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py b/fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py new file mode 100644 index 000000000..8ceabfb91 --- /dev/null +++ b/fast_llm_external_models/tests/test_apriel2/test_cache_contracts.py @@ -0,0 +1,591 @@ +"""Contract tests for Apriel2Cache against HuggingFace cache implementations. + +This module tests that Apriel2Cache components behave equivalently to their +HuggingFace counterparts. This ensures compatibility with HF's generation +infrastructure (mask creation, beam search, etc.). + +Mapping: + Apriel2 Component HuggingFace Equivalent + ----------------- ---------------------- + _AttentionCache (no window) -> DynamicLayer + _AttentionCache (window) -> DynamicSlidingWindowLayer + _SSMCache -> MambaCache (different interface, same concept) + +Apriel2-specific features (stochastic routing, multi-mixer layers) are tested +separately in test_cache_apriel2_specific.py since they have no HF equivalent. + +Fixtures used from conftest.py: + - batch_size, num_heads, head_dim: Tensor dimensions + - hf_dynamic_layer: HuggingFace DynamicLayer + - hf_sliding_layer: HuggingFace DynamicSlidingWindowLayer (parameterized by window_size) + - apriel_attention_cache: Apriel2 _AttentionCache (no window) + - apriel_sliding_cache: Apriel2 _AttentionCache (with window, parameterized) + - window_size: Parameterized window sizes [4, 8, 16, 32] + - attention_config, swa_config: Apriel2 configs +""" + +import pytest +import torch + +from fast_llm_external_models.apriel2.cache import Apriel2Cache, _AttentionCache + +# ============================================================================= +# SECTION 1: FULL ATTENTION - _AttentionCache vs DynamicLayer +# ============================================================================= + + +class TestFullAttentionContract: + """Test _AttentionCache (no window) matches HuggingFace DynamicLayer. + + DynamicLayer is the standard cache for full causal attention. + We test that our cache produces identical mask parameters. 
+ """ + + # ------------------------------------------------------------------------- + # get_seq_length: Must match exactly for generation to work + # ------------------------------------------------------------------------- + + @pytest.mark.parametrize("seq_len", [1, 5, 10, 50, 100]) + def test_get_seq_length_after_prefill( + self, hf_dynamic_layer, apriel_attention_cache, batch_size, num_heads, head_dim, seq_len + ): + """After prefill, cumulative_length matches HF get_seq_length.""" + key = torch.randn(batch_size, num_heads, seq_len, head_dim) + value = torch.randn(batch_size, num_heads, seq_len, head_dim) + + hf_dynamic_layer.update(key.clone(), value.clone()) + apriel_attention_cache.update(key.clone(), value.clone()) + + assert apriel_attention_cache.cumulative_length == hf_dynamic_layer.get_seq_length() + + @pytest.mark.parametrize("prefill_len", [1, 5, 10]) + @pytest.mark.parametrize("decode_steps", [1, 5, 10, 20]) + def test_get_seq_length_during_decode( + self, hf_dynamic_layer, apriel_attention_cache, batch_size, num_heads, head_dim, prefill_len, decode_steps + ): + """During decode, cumulative_length tracks total tokens seen.""" + # Prefill + key = torch.randn(batch_size, num_heads, prefill_len, head_dim) + value = torch.randn(batch_size, num_heads, prefill_len, head_dim) + hf_dynamic_layer.update(key.clone(), value.clone()) + apriel_attention_cache.update(key.clone(), value.clone()) + + # Decode + for step in range(decode_steps): + key = torch.randn(batch_size, num_heads, 1, head_dim) + value = torch.randn(batch_size, num_heads, 1, head_dim) + hf_dynamic_layer.update(key.clone(), value.clone()) + apriel_attention_cache.update(key.clone(), value.clone()) + + assert ( + apriel_attention_cache.cumulative_length == hf_dynamic_layer.get_seq_length() + ), f"Mismatch at decode step {step}" + + # ------------------------------------------------------------------------- + # get_mask_sizes: Verify HF behavior for documentation + # ------------------------------------------------------------------------- + + @pytest.mark.parametrize("prefill_len", [1, 5, 10]) + @pytest.mark.parametrize("decode_steps", [0, 1, 5, 10]) + def test_hf_mask_sizes_kv_length( + self, hf_dynamic_layer, apriel_attention_cache, batch_size, num_heads, head_dim, prefill_len, decode_steps + ): + """Document HF's kv_length behavior and verify cumulative_length tracks correctly. + + For full attention, kv_length = cumulative_length + query_length. + This test verifies our cache tracks tokens identically to HF. 
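+
+        As an illustration (the numbers follow from the formula above, not from
+        HF internals): with prefill_len=5 and decode_steps=5, cumulative_length
+        is 10, so a 1-token decode query gives kv_length = 10 + 1 = 11 and
+        kv_offset = 0.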
+ """ + # Prefill + key = torch.randn(batch_size, num_heads, prefill_len, head_dim) + value = torch.randn(batch_size, num_heads, prefill_len, head_dim) + hf_dynamic_layer.update(key.clone(), value.clone()) + apriel_attention_cache.update(key.clone(), value.clone()) + + # Decode + for _ in range(decode_steps): + key = torch.randn(batch_size, num_heads, 1, head_dim) + value = torch.randn(batch_size, num_heads, 1, head_dim) + hf_dynamic_layer.update(key.clone(), value.clone()) + apriel_attention_cache.update(key.clone(), value.clone()) + + # Verify cumulative_length matches HF + assert apriel_attention_cache.cumulative_length == hf_dynamic_layer.get_seq_length() + + # Verify HF's kv_length follows the expected formula + cache_position = torch.arange(1) # Single token decode + hf_kv_len, hf_kv_offset = hf_dynamic_layer.get_mask_sizes(cache_position) + expected_kv_len = hf_dynamic_layer.get_seq_length() + cache_position.shape[0] + assert hf_kv_len == expected_kv_len + + def test_hf_kv_offset_always_zero(self, hf_dynamic_layer, batch_size, num_heads, head_dim): + """Document that HF DynamicLayer always returns kv_offset=0. + + For full attention, all cached KV pairs map to absolute positions + starting from 0, so kv_offset is always 0. + """ + # Add many tokens + for _ in range(20): + key = torch.randn(batch_size, num_heads, 5, head_dim) + value = torch.randn(batch_size, num_heads, 5, head_dim) + hf_dynamic_layer.update(key.clone(), value.clone()) + + cache_position = torch.arange(1) + _, hf_kv_offset = hf_dynamic_layer.get_mask_sizes(cache_position) + + assert hf_kv_offset == 0, "DynamicLayer always returns kv_offset=0" + + # ------------------------------------------------------------------------- + # update: Output shape and values must match + # ------------------------------------------------------------------------- + + @pytest.mark.parametrize("seq_len", [1, 5, 10]) + def test_update_returns_same_shape( + self, hf_dynamic_layer, apriel_attention_cache, batch_size, num_heads, head_dim, seq_len + ): + """update() returns tensors with matching shapes.""" + key = torch.randn(batch_size, num_heads, seq_len, head_dim) + value = torch.randn(batch_size, num_heads, seq_len, head_dim) + + hf_k, hf_v = hf_dynamic_layer.update(key.clone(), value.clone()) + apr_k, apr_v = apriel_attention_cache.update(key.clone(), value.clone()) + + assert hf_k.shape == apr_k.shape + assert hf_v.shape == apr_v.shape + + def test_update_concatenates_identically( + self, hf_dynamic_layer, apriel_attention_cache, batch_size, num_heads, head_dim + ): + """Multiple updates produce identical concatenated states.""" + # Use deterministic values for comparison + k1 = torch.arange(10).float().view(1, 1, 10, 1).expand(batch_size, num_heads, 10, head_dim) + v1 = k1.clone() + + hf_dynamic_layer.update(k1.clone(), v1.clone()) + apriel_attention_cache.update(k1.clone(), v1.clone()) + + k2 = torch.arange(10, 15).float().view(1, 1, 5, 1).expand(batch_size, num_heads, 5, head_dim) + v2 = k2.clone() + + hf_k, hf_v = hf_dynamic_layer.update(k2.clone(), v2.clone()) + apr_k, apr_v = apriel_attention_cache.update(k2.clone(), v2.clone()) + + torch.testing.assert_close(hf_k, apr_k) + torch.testing.assert_close(hf_v, apr_v) + + +# ============================================================================= +# SECTION 2: SLIDING WINDOW - _AttentionCache vs DynamicSlidingWindowLayer +# ============================================================================= + + +class TestSlidingWindowContract: + """Test _AttentionCache (with window) 
matches HuggingFace DynamicSlidingWindowLayer. + + DynamicSlidingWindowLayer is used for sliding window attention (e.g., Mistral). + Critical behaviors: + - cumulative_length tracks ALL tokens seen (not just cached) + - kv_offset increases once window is exceeded + - kv_length is capped at window size + + Uses fixtures from conftest.py: + - window_size: parameterized [4, 8, 16, 32] + - hf_sliding_layer: DynamicSlidingWindowLayer + - apriel_sliding_cache: _AttentionCache with window + """ + + # ------------------------------------------------------------------------- + # cumulative_length: Must track total tokens, not cached tokens + # ------------------------------------------------------------------------- + + @pytest.mark.parametrize("prefill_len", [1, 3, 5, 10, 20]) + def test_cumulative_length_matches_after_prefill( + self, hf_sliding_layer, apriel_sliding_cache, batch_size, num_heads, head_dim, prefill_len + ): + """cumulative_length matches HF get_seq_length after prefill.""" + key = torch.randn(batch_size, num_heads, prefill_len, head_dim) + value = torch.randn(batch_size, num_heads, prefill_len, head_dim) + + hf_sliding_layer.update(key.clone(), value.clone()) + apriel_sliding_cache.update(key.clone(), value.clone()) + + assert apriel_sliding_cache.cumulative_length == hf_sliding_layer.get_seq_length() + + def test_cumulative_length_continues_past_window( + self, hf_sliding_layer, apriel_sliding_cache, window_size, batch_size, num_heads, head_dim + ): + """cumulative_length keeps growing even after window is full.""" + total_tokens = window_size * 3 # Way past window + + for i in range(total_tokens): + key = torch.randn(batch_size, num_heads, 1, head_dim) + value = torch.randn(batch_size, num_heads, 1, head_dim) + + hf_sliding_layer.update(key.clone(), value.clone()) + apriel_sliding_cache.update(key.clone(), value.clone()) + + expected = i + 1 + assert apriel_sliding_cache.cumulative_length == expected + assert hf_sliding_layer.get_seq_length() == expected + + # ------------------------------------------------------------------------- + # get_mask_sizes: kv_offset must increase once window is exceeded + # ------------------------------------------------------------------------- + + def test_kv_offset_zero_before_window_full( + self, hf_sliding_layer, apriel_sliding_cache, window_size, batch_size, num_heads, head_dim + ): + """kv_offset is 0 while cumulative < window. + + Before the window is full, kv_offset should be 0 because all cached tokens + correspond to absolute positions starting from 0. + """ + # Add tokens up to window-1 + for i in range(window_size - 1): + key = torch.randn(batch_size, num_heads, 1, head_dim) + value = torch.randn(batch_size, num_heads, 1, head_dim) + + hf_sliding_layer.update(key.clone(), value.clone()) + apriel_sliding_cache.update(key.clone(), value.clone()) + + cache_position = torch.arange(1) + hf_kv_len, hf_kv_offset = hf_sliding_layer.get_mask_sizes(cache_position) + + # Verify HF returns 0 offset before window full + assert hf_kv_offset == 0, f"HF offset should be 0 at step {i}" + # Verify Apriel cache tracks cumulative correctly + assert apriel_sliding_cache.cumulative_length == i + 1 + + def test_kv_offset_increases_after_window_full( + self, hf_sliding_layer, apriel_sliding_cache, window_size, batch_size, num_heads, head_dim + ): + """kv_offset increases once cumulative >= window. + + Once the window is full, the cache discards oldest tokens. kv_offset tracks + which absolute position KV[0] corresponds to. 
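+
+        Worked example (window=4, following the offset arithmetic asserted
+        below): after 6 tokens the layer keeps the last window-1 = 3 KV entries,
+        so KV[0] corresponds to absolute position 3 and
+        kv_offset = 6 - 4 + 1 = 3.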
+ """ + # Fill to exactly window + for _ in range(window_size): + key = torch.randn(batch_size, num_heads, 1, head_dim) + value = torch.randn(batch_size, num_heads, 1, head_dim) + hf_sliding_layer.update(key.clone(), value.clone()) + apriel_sliding_cache.update(key.clone(), value.clone()) + + cache_position = torch.arange(1) + hf_kv_len, hf_kv_offset = hf_sliding_layer.get_mask_sizes(cache_position) + + # At window boundary, offset should be 1 + assert hf_kv_offset == 1, "HF offset should be 1 at window boundary" + assert apriel_sliding_cache.cumulative_length == window_size + + # Add more tokens and verify offset keeps increasing with HF + for i in range(5): + key = torch.randn(batch_size, num_heads, 1, head_dim) + value = torch.randn(batch_size, num_heads, 1, head_dim) + hf_sliding_layer.update(key.clone(), value.clone()) + apriel_sliding_cache.update(key.clone(), value.clone()) + + hf_kv_len, hf_kv_offset = hf_sliding_layer.get_mask_sizes(cache_position) + + expected_offset = i + 2 + assert hf_kv_offset == expected_offset + assert apriel_sliding_cache.cumulative_length == window_size + i + 1 + + def test_kv_length_capped_at_window( + self, hf_sliding_layer, apriel_sliding_cache, window_size, batch_size, num_heads, head_dim + ): + """kv_length is capped at window size once exceeded. + + For a query of length 1 after the window is full, kv_length = window + (window-1 cached tokens + 1 query token). + """ + # Way past window + for _ in range(window_size * 2): + key = torch.randn(batch_size, num_heads, 1, head_dim) + value = torch.randn(batch_size, num_heads, 1, head_dim) + hf_sliding_layer.update(key.clone(), value.clone()) + apriel_sliding_cache.update(key.clone(), value.clone()) + + cache_position = torch.arange(1) + hf_kv_len, _ = hf_sliding_layer.get_mask_sizes(cache_position) + + # HF returns window (window-1 cached + 1 query) + assert hf_kv_len == window_size + # Verify our cache tracked cumulative correctly + assert apriel_sliding_cache.cumulative_length == window_size * 2 + + # ------------------------------------------------------------------------- + # Full sequence length tracking through generation + # ------------------------------------------------------------------------- + + @pytest.mark.parametrize("prefill_len", [1, 3, 5, 10, 20]) + def test_cumulative_length_tracks_all_tokens( + self, hf_sliding_layer, apriel_sliding_cache, window_size, batch_size, num_heads, head_dim, prefill_len + ): + """cumulative_length tracks total tokens seen through prefill + decode. + + This is the foundation for correct mask size computation. We verify that + our _AttentionCache tracks tokens identically to HuggingFace's DynamicSlidingWindowLayer. + The actual get_mask_sizes computation is tested in TestApriel2CacheIntegration. 
+ """ + # Prefill + key = torch.randn(batch_size, num_heads, prefill_len, head_dim) + value = torch.randn(batch_size, num_heads, prefill_len, head_dim) + hf_sliding_layer.update(key.clone(), value.clone()) + apriel_sliding_cache.update(key.clone(), value.clone()) + + assert apriel_sliding_cache.cumulative_length == hf_sliding_layer.get_seq_length() + + # Decode past window + for i in range(window_size + 10): + key = torch.randn(batch_size, num_heads, 1, head_dim) + value = torch.randn(batch_size, num_heads, 1, head_dim) + hf_sliding_layer.update(key.clone(), value.clone()) + apriel_sliding_cache.update(key.clone(), value.clone()) + + assert ( + apriel_sliding_cache.cumulative_length == hf_sliding_layer.get_seq_length() + ), f"cumulative_length mismatch at step {i}" + + +# ============================================================================= +# SECTION 3: SSM CACHE - _SSMCache vs MambaCache concept +# ============================================================================= + + +class TestSSMCacheContract: + """Document _SSMCache interface and verify basic contract. + + Unlike attention caches which have HF equivalents (DynamicLayer, DynamicSlidingWindowLayer), + SSM caches have no direct HF counterpart with matching interface. HF's MambaCache uses + different methods (update_conv_state, update_ssm_state), so we can't do direct comparison. + + These tests document the interface contract: + 1. `conv` and `recurrent` attributes for storing states + 2. Both support None (lazy initialization) + 3. `conv` can be tuple (for KDA which has separate q/k/v conv states) + + Higher-level operations (reorder, batch_repeat, reset) are tested in + TestBeamSearchOperations in test_cache_apriel2_specific.py. + """ + + def test_conv_state_storage(self, ssm_cache): + """conv attribute stores conv states (batch, intermediate, kernel_size).""" + conv = torch.randn(2, 64, 4) + ssm_cache.conv = conv + torch.testing.assert_close(ssm_cache.conv, conv) + + def test_recurrent_state_storage(self, ssm_cache): + """recurrent attribute stores SSM states (batch, intermediate, state_size).""" + recurrent = torch.randn(2, 64, 16) + ssm_cache.recurrent = recurrent + torch.testing.assert_close(ssm_cache.recurrent, recurrent) + + def test_conv_state_tuple_for_kda(self, ssm_cache): + """conv can be tuple for KDA's separate q/k/v convolutions.""" + conv_tuple = (torch.randn(2, 64, 4), torch.randn(2, 64, 4), torch.randn(2, 64, 4)) + ssm_cache.conv = conv_tuple + assert isinstance(ssm_cache.conv, tuple) + assert len(ssm_cache.conv) == 3 + + def test_initial_states_none(self, ssm_cache): + """States are None initially (lazy initialization pattern).""" + assert ssm_cache.conv is None + assert ssm_cache.recurrent is None + + def test_states_independent(self, ssm_cache): + """conv and recurrent states are independent.""" + ssm_cache.conv = torch.randn(2, 64, 4) + assert ssm_cache.recurrent is None # recurrent unchanged + + ssm_cache.recurrent = torch.randn(2, 64, 16) + assert ssm_cache.conv is not None # conv unchanged + + +# ============================================================================= +# SECTION 4: APRIEL2CACHE INTEGRATION +# ============================================================================= + + +class TestApriel2CacheIntegration: + """Test Apriel2Cache correctly delegates to underlying caches. 
+ + Uses fixtures from conftest.py: + - attention_config: Pure attention config + - swa_config: Sliding window attention config (window=8) + """ + + def test_get_seq_length_matches_dynamic_layer(self, attention_config): + """Apriel2Cache.get_seq_length matches DynamicLayer for full attention.""" + from transformers.cache_utils import DynamicLayer + + cache = Apriel2Cache(attention_config) + hf_layer = DynamicLayer() + + key = torch.randn(2, 4, 10, 16) + value = torch.randn(2, 4, 10, 16) + + cache.update(key.clone(), value.clone(), layer_idx=0) + hf_layer.update(key.clone(), value.clone()) + + assert cache.get_seq_length(0) == hf_layer.get_seq_length() + + def test_get_mask_sizes_matches_dynamic_layer(self, attention_config): + """Apriel2Cache.get_mask_sizes matches DynamicLayer.""" + from transformers.cache_utils import DynamicLayer + + cache = Apriel2Cache(attention_config) + hf_layer = DynamicLayer() + + key = torch.randn(2, 4, 10, 16) + value = torch.randn(2, 4, 10, 16) + + cache.update(key.clone(), value.clone(), layer_idx=0) + hf_layer.update(key.clone(), value.clone()) + + cache_position = torch.arange(1) + hf_kv_len, hf_kv_offset = hf_layer.get_mask_sizes(cache_position) + apr_kv_len, apr_kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0) + + assert apr_kv_len == hf_kv_len + assert apr_kv_offset == hf_kv_offset + + def test_get_mask_sizes_matches_sliding_layer(self, swa_config): + """Apriel2Cache.get_mask_sizes matches DynamicSlidingWindowLayer.""" + from transformers.cache_utils import DynamicSlidingWindowLayer + + cache = Apriel2Cache(swa_config) + hf_layer = DynamicSlidingWindowLayer(sliding_window=8) + + # Fill past window + for _ in range(15): + key = torch.randn(2, 4, 1, 16) + value = torch.randn(2, 4, 1, 16) + cache.update(key.clone(), value.clone(), layer_idx=0) + hf_layer.update(key.clone(), value.clone()) + + cache_position = torch.arange(1) + hf_kv_len, hf_kv_offset = hf_layer.get_mask_sizes(cache_position) + apr_kv_len, apr_kv_offset = cache.get_mask_sizes(cache_position, layer_idx=0) + + assert apr_kv_len == hf_kv_len + assert apr_kv_offset == hf_kv_offset + + def test_reset_clears_cumulative_length(self, attention_config): + """reset() clears cumulative_length (matches DynamicLayer.reset).""" + cache = Apriel2Cache(attention_config) + + cache.update(torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16), layer_idx=0) + assert cache.get_seq_length(0) == 10 + + cache.reset() + assert cache.get_seq_length(0) == 0 + + +# ============================================================================= +# SECTION 5: MASK CORRECTNESS (SEMANTIC TESTS) +# ============================================================================= + + +class TestMaskCorrectness: + """Test that mask parameters produce semantically correct masks. + + These tests verify the END RESULT: masks created with our parameters + allow the correct attention patterns. 
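+
+    As a concrete illustration of the sliding-window case below: with window=4
+    and a decode query at position 8, the expected pattern allows exactly
+    absolute positions 5..8 (abs_pos > 8 - 4 and abs_pos <= 8).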
+ """ + + def test_full_attention_decode_can_attend_to_all(self): + """During decode, query can attend to all cached positions.""" + from transformers.masking_utils import causal_mask_function, sdpa_mask + + cache = _AttentionCache(window=None) + + # Prefill + decode + for _ in range(10): + cache.update(torch.randn(1, 1, 1, 16), torch.randn(1, 1, 1, 16)) + + # Mask for decode step + cache_position = torch.tensor([10]) # Position of new token + kv_length = cache.cumulative_length + 1 + kv_offset = 0 + + mask = sdpa_mask( + batch_size=1, + cache_position=cache_position, + kv_length=kv_length, + kv_offset=kv_offset, + mask_function=causal_mask_function, + ) + + if mask is not None: + # Query at position 10 should attend to positions 0-10 + query_mask = mask[0, 0, 0, :] + for kv_idx in range(kv_length): + assert query_mask[kv_idx].item() == True, f"Should attend to position {kv_idx}" + + @pytest.mark.parametrize("window_size", [4, 8, 16]) + def test_sliding_window_decode_respects_window(self, window_size): + """During decode, query only attends within sliding window.""" + from transformers.masking_utils import sdpa_mask, sliding_window_causal_mask_function + + cache = _AttentionCache(window=window_size) + + # Fill way past window + total_tokens = window_size * 2 + for _ in range(total_tokens): + cache.update(torch.randn(1, 1, 1, 16), torch.randn(1, 1, 1, 16)) + + # Mask for decode step + cache_position = torch.tensor([total_tokens]) + cumulative = cache.cumulative_length + kv_offset = max(cumulative - window_size + 1, 0) + kv_length = window_size - 1 + 1 # cached + query + + mask = sdpa_mask( + batch_size=1, + cache_position=cache_position, + kv_length=kv_length, + kv_offset=kv_offset, + mask_function=sliding_window_causal_mask_function(window_size), + ) + + if mask is not None: + query_mask = mask[0, 0, 0, :] + query_pos = cache_position[0].item() + + for kv_idx in range(kv_length): + abs_pos = kv_offset + kv_idx + in_window = abs_pos > query_pos - window_size + causal = abs_pos <= query_pos + expected = in_window and causal + + assert ( + query_mask[kv_idx].item() == expected + ), f"Position {abs_pos}: expected {expected}, got {query_mask[kv_idx].item()}" + + def test_prefill_has_causal_pattern(self): + """During prefill, mask has proper causal (lower triangular) pattern.""" + from transformers.masking_utils import causal_mask_function, sdpa_mask + + cache = _AttentionCache(window=None) + cache.update(torch.randn(1, 1, 5, 16), torch.randn(1, 1, 5, 16)) + + cache_position = torch.arange(5) + kv_length = cache.cumulative_length + kv_offset = 0 + + mask = sdpa_mask( + batch_size=1, + cache_position=cache_position, + kv_length=kv_length, + kv_offset=kv_offset, + mask_function=causal_mask_function, + allow_is_causal_skip=False, # Force mask creation + ) + + if mask is not None: + # Check causal pattern + for q_idx in range(5): + for kv_idx in range(5): + expected = kv_idx <= q_idx + actual = mask[0, 0, q_idx, kv_idx].item() + assert actual == expected, f"q={q_idx}, kv={kv_idx}: expected {expected}" diff --git a/fast_llm_external_models/tests/test_apriel2/test_causal_conv1d.py b/fast_llm_external_models/tests/test_apriel2/test_causal_conv1d.py index ec6abc1d2..0567cd76e 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_causal_conv1d.py +++ b/fast_llm_external_models/tests/test_apriel2/test_causal_conv1d.py @@ -24,7 +24,6 @@ from fast_llm_external_models.apriel2.modeling_apriel2 import CausalConv1d, _causal_conv1d_fn - # 
============================================================================= # Fixtures # ============================================================================= @@ -63,6 +62,7 @@ def kernel_size(): def to_device(conv: CausalConv1d, device: str) -> CausalConv1d: """Create a copy of conv on the specified device.""" import copy + return copy.deepcopy(conv).to(device) @@ -71,7 +71,9 @@ def prefill(conv: CausalConv1d, x: torch.Tensor, state: torch.Tensor = None) -> return conv(x, conv_state=state, return_final_state=True) -def decode_sequence(conv: CausalConv1d, tokens: torch.Tensor, state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: +def decode_sequence( + conv: CausalConv1d, tokens: torch.Tensor, state: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: """Decode multiple tokens one-by-one, return (stacked_outputs, final_state). Args: @@ -223,7 +225,7 @@ def test_chunked_prefill_cpu(self, conv, dim, total_len, chunk_size): outputs = [] state = None for start in range(0, total_len, chunk_size): - chunk = x[:, :, start:start + chunk_size] + chunk = x[:, :, start : start + chunk_size] out, state = prefill(conv, chunk, state) outputs.append(out) @@ -248,7 +250,7 @@ def test_chunked_prefill_cuda(self, conv, dim, total_len, chunk_size): outputs = [] state = None for start in range(0, total_len, chunk_size): - chunk = x[:, :, start:start + chunk_size].cuda() + chunk = x[:, :, start : start + chunk_size].cuda() out, state = prefill(conv_cuda, chunk, state) outputs.append(out) @@ -329,7 +331,7 @@ def test_all_cpu_paths_match(self, conv, dim): outputs = [] state = None for start in range(0, total_len, chunk_size): - chunk = x[:, :, start:start + chunk_size] + chunk = x[:, :, start : start + chunk_size] out, state = prefill(conv, chunk, state) outputs.append(out) path1 = torch.cat(outputs, dim=-1) @@ -374,7 +376,7 @@ def test_all_paths_match_cross_device(self, conv, dim): # CPU chunked outputs, state = [], None for start in range(0, total_len, chunk_size): - out, state = prefill(conv, x[:, :, start:start + chunk_size], state) + out, state = prefill(conv, x[:, :, start : start + chunk_size], state) outputs.append(out) results["cpu_chunked"] = torch.cat(outputs, dim=-1) @@ -393,7 +395,7 @@ def test_all_paths_match_cross_device(self, conv, dim): # CUDA chunked outputs, state = [], None for start in range(0, total_len, chunk_size): - out, state = prefill(conv_cuda, x[:, :, start:start + chunk_size].cuda(), state) + out, state = prefill(conv_cuda, x[:, :, start : start + chunk_size].cuda(), state) outputs.append(out.cpu()) results["cuda_chunked"] = torch.cat(outputs, dim=-1) @@ -431,8 +433,7 @@ def test_all_paths_match_cross_device(self, conv, dim): for name, result in results.items(): tol = tolerances[name] torch.testing.assert_close( - result, reference, atol=tol, rtol=tol, - msg=f"Path '{name}' diverged from reference" + result, reference, atol=tol, rtol=tol, msg=f"Path '{name}' diverged from reference" ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") @@ -468,8 +469,8 @@ def test_long_decode_no_drift(self, conv, dim): # Check no systematic drift (errors shouldn't consistently increase) decode_errors = errors[prefill_len:] - first_half = decode_errors[:len(decode_errors)//2].mean() - second_half = decode_errors[len(decode_errors)//2:].mean() + first_half = decode_errors[: len(decode_errors) // 2].mean() + second_half = decode_errors[len(decode_errors) // 2 :].mean() assert second_half < first_half * 2, "Errors growing over decode steps (drift detected)" diff 
--git a/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py b/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py index 0bd6ac88d..3413b9d25 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py +++ b/fast_llm_external_models/tests/test_apriel2/test_compose_configs.py @@ -20,7 +20,7 @@ import yaml from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config -from fast_llm_external_models.apriel2.conversion.config import apply_surgery, compose_configs +from fast_llm_external_models.apriel2.conversion.config import compose_configs class TestComposeConfigsLaws: @@ -75,14 +75,10 @@ def source_config(self): }, } - def test_identity_empty_surgery(self, source_config): - """Law 1: compose_configs(config, {}) == config""" - result = compose_configs(source_config, {}) - assert result == source_config - - def test_identity_none_surgery(self, source_config): - """Law 1: compose_configs(config, None) == config""" - result = compose_configs(source_config, None) + @pytest.mark.parametrize("empty_surgery", [{}, None]) + def test_identity(self, source_config, empty_surgery): + """Law 1: compose_configs(config, empty) == config for empty in [{}, None]""" + result = compose_configs(source_config, empty_surgery) assert result == source_config def test_override_explicit_values(self, source_config): @@ -114,7 +110,7 @@ def test_same_type_inheritance(self, source_config): assert mixer["head_size"] == 32 # Inherited assert mixer["rope_theta"] == 10000.0 # Inherited assert mixer["window_size"] == 512 # Added - assert "init" not in mixer # Stripped by apply_surgery + # init is preserved for plan_surgery to see (stripped only at final output) def test_cross_type_attention_to_gdn(self, source_config): """Law 5: attention→gdn derives GDN dims from attention geometry.""" @@ -239,8 +235,14 @@ def test_null_deletion(self, source_config): assert "vision_encoder" not in result - def test_init_stripped_from_result(self, source_config): - """Verify `init` keys are stripped from final result.""" + def test_init_preserved_for_plan_surgery(self, source_config): + """Verify `init` keys are preserved so plan_surgery can see them. + + The `init` field controls weight initialization (transfer vs random). + It's preserved through composition and only stripped at final output. 
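+
+        For example (mirroring the assertions below), a surgery entry such as
+        {"attention": {"init": "transfer"}} remains visible in the composed
+        config and is only removed by strip_init_fields().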
+ """ + from fast_llm_external_models.apriel2.conversion.config import strip_init_fields + surgery = { "decoder": { "block": { @@ -252,20 +254,20 @@ def test_init_stripped_from_result(self, source_config): "gdn": {"type": "gdn", "init": "random", "convolution_layer": {"kernel_size": 4}}, }, }, - "mlp": {"init": "transfer"}, - "normalization": {"init": "transfer"}, }, }, } result = compose_configs(source_config, surgery) - def check_no_init(d, path=""): - assert "init" not in d, f"Found 'init' key at {path}" - for k, v in d.items(): - if isinstance(v, dict): - check_no_init(v, f"{path}.{k}") + # init is preserved in composed config + mixers = result["decoder"]["block"]["mixer"]["mixers"] + assert mixers["attention"].get("init") == "transfer" + assert mixers["gdn"].get("init") == "random" - check_no_init(result) + # strip_init_fields removes them for final output + stripped = strip_init_fields(result) + assert "init" not in stripped["decoder"]["block"]["mixer"]["mixers"]["attention"] + assert "init" not in stripped["decoder"]["block"]["mixer"]["mixers"]["gdn"] def test_init_random_still_inherits_config(self, source_config): """init: random is for weights only - config params still inherited.""" @@ -287,6 +289,212 @@ def test_init_random_still_inherits_config(self, source_config): assert mixer["head_groups"] == 4 assert mixer["window_size"] == 512 + # ========================================================================= + # Monoid Laws: compose_configs forms a monoid action on configs + # ========================================================================= + + def test_surgery_monoid_associativity(self): + """MONOID: merge(merge(A, B), C) == merge(A, merge(B, C)) for partial configs.""" + surgery_a = {"decoder": {"block": {"mixer": {"type": "stochastic", "main_mixer_name": "attention"}}}} + surgery_b = {"decoder": {"block": {"mixer": {"mixers": {"sliding_window": {"window_size": 512}}}}}} + surgery_c = {"decoder": {"block": {"mixer": {"mixers": {"gdn": {"type": "gdn"}}}}}} + + # Left-associated: (A ∘ B) ∘ C + ab_c = compose_configs(compose_configs(surgery_a, surgery_b), surgery_c) + # Right-associated: A ∘ (B ∘ C) + a_bc = compose_configs(surgery_a, compose_configs(surgery_b, surgery_c)) + + assert ab_c == a_bc, "Surgery monoid should be associative" + + @pytest.mark.parametrize("num_surgeries", [2, 3]) + def test_monoid_action_compatibility(self, source_config, num_surgeries): + """MONOID ACTION: apply(apply(c, A), B) == apply(c, merge(A, B)) + + This is the key law: applying surgeries sequentially equals merging first. + Parameterized to test with 2 and 3 surgeries. + """ + surgeries = [ + { + "decoder": { + "block": { + "mixer": {"type": "stochastic", "main_mixer_name": "attention", "mixers": {"attention": {}}} + } + } + }, + {"decoder": {"block": {"mixer": {"mixers": {"sliding_window": {"window_size": 512}}}}}}, + {"decoder": {"block": {"mixer": {"mixers": {"gdn": {"type": "gdn"}}}}}}, + ][:num_surgeries] + + # Sequential: ((c ⊳ A) ⊳ B) ⊳ ... + result_sequential = source_config + for s in surgeries: + result_sequential = compose_configs(result_sequential, s) + + # Merged: c ⊳ (A ∘ B ∘ ...) + merged = surgeries[0] + for s in surgeries[1:]: + merged = compose_configs(merged, s) + result_merged = compose_configs(source_config, merged) + + assert result_sequential == result_merged, f"Monoid action compatibility failed for {num_surgeries} surgeries" + + +class TestBiasConfigInheritance: + """Test per-layer bias inheritance through surgery composition. 
+ + These tests verify that the per-layer bias configuration (mirroring Fast-LLM's + AffineLinearConfig) is correctly inherited through surgery operations: + - query_layer.bias.enabled, key_layer.bias.enabled, etc. for attention + - layer_1.bias.enabled, layer_2.bias.enabled for MLP + """ + + @pytest.fixture + def source_config_with_bias(self): + """Source config with Qwen-style bias (QKV enabled, O disabled).""" + return { + "model_type": "apriel2", + "architectures": ["Apriel2ForCausalLM"], + "hidden_size": 256, + "vocab_size": 1000, + "decoder": { + "type": "fixed", + "num_blocks": 4, + "block": { + "mixer": { + "type": "attention", + "heads": 8, + "head_groups": 4, + "head_size": 32, + "rotary": {"type": "mistral_1d", "theta": 10000.0}, + # Qwen-style per-layer bias + "query_layer": {"bias": {"enabled": True}}, + "key_layer": {"bias": {"enabled": True}}, + "value_layer": {"bias": {"enabled": True}}, + "dense_layer": {"bias": {"enabled": False}}, + }, + "mlp": { + "type": "mlp", + "intermediate_size": 512, + "gated": False, + # Per-layer MLP bias + "layer_1": {"bias": {"enabled": True}}, + "layer_2": {"bias": {"enabled": False}}, + }, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + } + + def test_same_type_inherits_attention_bias(self, source_config_with_bias): + """Same-type surgery inherits per-layer attention bias settings.""" + surgery = { + "decoder": { + "block": { + "mixer": { + "window_size": 512, # Add sliding window behavior + }, + }, + }, + } + result = compose_configs(source_config_with_bias, surgery) + + mixer = result["decoder"]["block"]["mixer"] + assert mixer["query_layer"]["bias"]["enabled"] is True + assert mixer["key_layer"]["bias"]["enabled"] is True + assert mixer["value_layer"]["bias"]["enabled"] is True + assert mixer["dense_layer"]["bias"]["enabled"] is False + + def test_same_type_inherits_mlp_bias(self, source_config_with_bias): + """Same-type surgery inherits per-layer MLP bias settings.""" + surgery = { + "decoder": { + "block": { + "mlp": { + "intermediate_size": 1024, # Change size + }, + }, + }, + } + result = compose_configs(source_config_with_bias, surgery) + + mlp = result["decoder"]["block"]["mlp"] + assert mlp["layer_1"]["bias"]["enabled"] is True + assert mlp["layer_2"]["bias"]["enabled"] is False + assert mlp["intermediate_size"] == 1024 + + def test_cross_type_attention_to_sliding_window_preserves_bias(self, source_config_with_bias): + """attention→sliding_window cross-type preserves per-layer bias.""" + surgery = { + "decoder": { + "block": { + "mixer": { + "type": "sliding_window", # Cross-type derivation + "window_size": 512, + }, + }, + }, + } + result = compose_configs(source_config_with_bias, surgery) + + mixer = result["decoder"]["block"]["mixer"] + assert mixer["type"] == "sliding_window" + # Bias settings preserved through cross-type + assert mixer["query_layer"]["bias"]["enabled"] is True + assert mixer["key_layer"]["bias"]["enabled"] is True + assert mixer["value_layer"]["bias"]["enabled"] is True + assert mixer["dense_layer"]["bias"]["enabled"] is False + + def test_stochastic_wrapper_inherits_bias(self, source_config_with_bias): + """Wrapping in stochastic inherits bias settings to all sub-mixers.""" + surgery = { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + "sliding_window": { + "type": "sliding_window", + "window_size": 512, + "init": "transfer", + }, + }, + }, + }, + }, + } + result = 
compose_configs(source_config_with_bias, surgery) + + mixers = result["decoder"]["block"]["mixer"]["mixers"] + + # Attention sub-mixer inherits bias + assert mixers["attention"]["query_layer"]["bias"]["enabled"] is True + assert mixers["attention"]["dense_layer"]["bias"]["enabled"] is False + + # Sliding window sub-mixer also inherits bias + assert mixers["sliding_window"]["query_layer"]["bias"]["enabled"] is True + assert mixers["sliding_window"]["dense_layer"]["bias"]["enabled"] is False + + def test_surgery_can_override_bias(self, source_config_with_bias): + """Surgery can explicitly override inherited bias settings.""" + surgery = { + "decoder": { + "block": { + "mixer": { + "dense_layer": {"bias": {"enabled": True}}, # Override O bias + }, + }, + }, + } + result = compose_configs(source_config_with_bias, surgery) + + mixer = result["decoder"]["block"]["mixer"] + # Q/K/V unchanged + assert mixer["query_layer"]["bias"]["enabled"] is True + # O bias overridden + assert mixer["dense_layer"]["bias"]["enabled"] is True + class TestComposeConfigsRealYAML: """Test compose_configs with real YAML surgery files.""" @@ -398,160 +606,12 @@ def test_build_plan_returns_complete_config(self, llava_pixtral_checkpoint): mixer = config.decoder["block"]["mixer"] assert mixer["type"] == "stochastic" - # Each sub-mixer should have complete config (no init keys) + # Each sub-mixer should have complete config + # (init is preserved for plan_surgery, stripped only at final output) for name, sub_mixer in mixer["mixers"].items(): - assert "init" not in sub_mixer, f"Sub-mixer {name} still has 'init' key" assert "type" in sub_mixer -class TestMonoidLaws: - """Test the algebraic laws of compose_configs. - - Surgery specs form a MONOID under deep-merge: - - Identity: {} - - Operation: deep merge (overlay wins) - - Associativity: merge(merge(A, B), C) == merge(A, merge(B, C)) - - compose_configs is a MONOID ACTION on configs: - - Identity action: apply(config, {}) == config - - Compatibility: apply(apply(c, A), B) == apply(c, merge(A, B)) - """ - - @pytest.fixture - def complete_config(self): - """A complete Apriel2 config.""" - return { - "model_type": "apriel2", - "architectures": ["Apriel2ForConditionalGeneration"], - "hidden_size": 256, - "vocab_size": 1000, - "bos_token_id": 1, - "eos_token_id": 2, - "tie_word_embeddings": False, - "image_token_index": 100, - "decoder": { - "type": "fixed", - "num_blocks": 4, - "block": { - "mixer": { - "type": "attention", - "heads": 8, - "head_groups": 4, - "head_size": 32, - "rope_theta": 10000.0, - }, - "mlp": {"type": "mlp", "intermediate_size": 512}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - } - - @pytest.fixture - def surgery_a(self): - """First surgery: wrap in stochastic with attention.""" - return { - "decoder": { - "block": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": { - "attention": {"init": "transfer"}, - }, - }, - }, - }, - } - - @pytest.fixture - def surgery_b(self): - """Second surgery: add sliding window mixer.""" - return { - "decoder": { - "block": { - "mixer": { - "mixers": { - "sliding_window": {"init": "transfer", "window_size": 512}, - }, - }, - }, - }, - } - - def test_identity_action(self, complete_config): - """apply(config, {}) == config""" - result = compose_configs(complete_config, {}) - assert result == complete_config - - def test_surgery_monoid_associativity(self, surgery_a, surgery_b): - """merge(merge(A, B), C) == merge(A, merge(B, C)) for partial configs.""" - surgery_c = { - 
"decoder": { - "block": { - "mixer": { - "mixers": { - "gdn": {"type": "gdn", "init": "transfer", "convolution_layer": {"kernel_size": 4}}, - }, - }, - }, - }, - } - - # Left-associated: (A ∘ B) ∘ C - ab = compose_configs(surgery_a, surgery_b) - ab_c = compose_configs(ab, surgery_c) - - # Right-associated: A ∘ (B ∘ C) - bc = compose_configs(surgery_b, surgery_c) - a_bc = compose_configs(surgery_a, bc) - - assert ab_c == a_bc, "Surgery monoid should be associative" - - def test_monoid_action_compatibility(self, complete_config, surgery_a, surgery_b): - """apply(apply(c, A), B) == apply(c, merge(A, B)) - - This is the key law: applying surgeries sequentially should equal - merging the surgeries first, then applying once. - """ - # Sequential application: (c ⊳ A) ⊳ B - result_sequential = compose_configs(compose_configs(complete_config, surgery_a), surgery_b) - - # Merged application: c ⊳ (A ∘ B) - merged_surgery = compose_configs(surgery_a, surgery_b) - result_merged = compose_configs(complete_config, merged_surgery) - - # These should be equivalent - assert result_sequential == result_merged, "Monoid action should satisfy compatibility law" - - def test_three_way_compatibility(self, complete_config, surgery_a, surgery_b): - """Test with three surgeries for stronger confidence.""" - surgery_c = { - "decoder": { - "block": { - "mixer": { - "mixers": { - "gdn": {"type": "gdn", "init": "transfer", "convolution_layer": {"kernel_size": 4}}, - }, - }, - }, - }, - } - - # Sequential: ((c ⊳ A) ⊳ B) ⊳ C - seq = compose_configs( - compose_configs(compose_configs(complete_config, surgery_a), surgery_b), - surgery_c - ) - - # Merged: c ⊳ ((A ∘ B) ∘ C) - merged = compose_configs( - complete_config, - compose_configs(compose_configs(surgery_a, surgery_b), surgery_c) - ) - - assert seq == merged, "Three-way monoid action should satisfy compatibility" - - class TestCompositionTortureTest: """Comprehensive stress test for config composition. @@ -650,19 +710,29 @@ def test_final_config_structure(self, complete_config, additive_surgery_chain): assert mixer["mixers"]["sliding_window"]["window_size"] == 512 assert mixer["mixers"]["gdn"]["value_heads"] == 16 - def test_no_init_keys_in_result(self, complete_config, additive_surgery_chain): - """Verify no 'init' keys leak through.""" + def test_init_keys_preserved_for_planning(self, complete_config, additive_surgery_chain): + """Verify 'init' keys are preserved for plan_surgery to see. - def check_no_init(d, path=""): - if isinstance(d, dict): - assert "init" not in d, f"Found 'init' key at {path}" - for k, v in d.items(): - check_no_init(v, f"{path}.{k}") + The `init` field is metadata for weight initialization. It's preserved + through composition and only stripped when saving final output. 
+ """ + from fast_llm_external_models.apriel2.conversion.config import strip_init_fields result = complete_config for i, surgery in enumerate(additive_surgery_chain): result = compose_configs(result, surgery) - check_no_init(result, f"step_{i+1}") + + # init should be in the composed config + mixer = result["decoder"]["block"]["mixer"] + if "mixers" in mixer: + has_init = any("init" in m for m in mixer["mixers"].values()) + assert has_init, "init should be preserved in composed config" + + # strip_init_fields removes them + stripped = strip_init_fields(result) + mixer = stripped["decoder"]["block"]["mixer"] + if "mixers" in mixer: + assert all("init" not in m for m in mixer["mixers"].values()) def test_full_torture_chain(self, complete_config, torture_surgery_chain): """Test the full 10-step torture chain produces valid configs.""" diff --git a/fast_llm_external_models/tests/test_apriel2/test_plan_composition_torture.py b/fast_llm_external_models/tests/test_apriel2/test_conversion_e2e.py similarity index 78% rename from fast_llm_external_models/tests/test_apriel2/test_plan_composition_torture.py rename to fast_llm_external_models/tests/test_apriel2/test_conversion_e2e.py index 3b4adc7f5..b91fb7e51 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_plan_composition_torture.py +++ b/fast_llm_external_models/tests/test_apriel2/test_conversion_e2e.py @@ -1,6 +1,6 @@ -"""End-to-end torture test for plan composition. +"""test_conversion_e2e.py - End-to-end conversion integration tests. -This tests the FULL pipeline at every step of a surgery chain: +Tests the FULL pipeline at every step of a surgery chain: 1. Config composition produces valid configs 2. Plan building works for each surgery 3. Plan execution produces valid weights @@ -16,21 +16,12 @@ import pytest import torch -from fast_llm_external_models.tests.test_apriel2.conftest import requires_cuda - from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config -from fast_llm_external_models.apriel2.conversion import ( - compose, - compose_configs, - execute, - plan_surgery, -) -from fast_llm_external_models.apriel2.conversion.llava import ( - convert_config as convert_llava_config, - plan_llava_to_apriel2, -) +from fast_llm_external_models.apriel2.conversion import compose, compose_configs, execute, plan_surgery +from fast_llm_external_models.apriel2.conversion.llava import convert_config as convert_llava_config +from fast_llm_external_models.apriel2.conversion.llava import plan_llava_to_apriel2 from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForConditionalGeneration - +from fast_llm_external_models.tests.test_apriel2.conftest import requires_cuda # ============================================================================= # Cycling Surgery Generation @@ -87,40 +78,20 @@ def generate_cycling_surgeries(config: dict) -> list[tuple[dict, str]]: if sub_name != main_mixer: # Build surgery path based on block_path if block_path == "block": - surgery = { - "decoder": { - "block": {"mixer": {"main_mixer_name": sub_name}} - } - } + surgery = {"decoder": {"block": {"mixer": {"main_mixer_name": sub_name}}}} else: # block_path is "blocks.block_name" block_name = block_path.split(".")[1] - surgery = { - "decoder": { - "blocks": { - block_name: {"mixer": {"main_mixer_name": sub_name}} - } - } - } + surgery = {"decoder": {"blocks": {block_name: {"mixer": {"main_mixer_name": sub_name}}}}} surgeries.append((surgery, f"cycle {block_path} to {sub_name}")) # Restore original main_mixer_name if 
any(sub_name != main_mixer for sub_name in sub_mixer_names): if block_path == "block": - restore = { - "decoder": { - "block": {"mixer": {"main_mixer_name": main_mixer}} - } - } + restore = {"decoder": {"block": {"mixer": {"main_mixer_name": main_mixer}}}} else: block_name = block_path.split(".")[1] - restore = { - "decoder": { - "blocks": { - block_name: {"mixer": {"main_mixer_name": main_mixer}} - } - } - } + restore = {"decoder": {"blocks": {block_name: {"mixer": {"main_mixer_name": main_mixer}}}}} surgeries.append((restore, f"restore {block_path} to {main_mixer}")) return surgeries @@ -194,9 +165,7 @@ def source_config(self, llava_pixtral_checkpoint): with open(llava_pixtral_checkpoint / "config.json") as f: return json.load(f) - def test_initial_conversion_produces_working_model( - self, source_config, source_weights - ): + def test_initial_conversion_produces_working_model(self, source_config, source_weights): """Test that Llava → Apriel2 conversion produces a working model.""" # Convert config apriel2_config_dict = convert_llava_config(source_config) @@ -219,9 +188,7 @@ def test_initial_conversion_produces_working_model( assert outputs.logits.shape == (1, 8, config.vocab_size) - def test_each_surgery_step_produces_working_model( - self, source_config, source_weights, additive_surgery_chain - ): + def test_each_surgery_step_produces_working_model(self, source_config, source_weights, additive_surgery_chain): """Test that each surgery step produces a model that can forward pass. Key insight: Surgery plans reference Apriel2 keys, so we must COMPOSE @@ -290,9 +257,7 @@ def test_each_surgery_step_produces_working_model( except Exception as e: pytest.fail(f"Step {i+1}: Forward pass failed - {e}") - def test_all_stochastic_submixers_via_cycling( - self, source_config, source_weights, additive_surgery_chain - ): + def test_all_stochastic_submixers_via_cycling(self, source_config, source_weights, additive_surgery_chain): """Test ALL sub-mixers in stochastic blocks, not just the main mixer. Problem: Forward pass only exercises the main_mixer_name. Other sub-mixers @@ -312,9 +277,7 @@ def test_all_stochastic_submixers_via_cycling( conversion_plan = plan_llava_to_apriel2(source_config) # Expand surgery chain with cycling - expanded_chain = expand_surgery_chain_with_cycling( - additive_surgery_chain, apriel2_config - ) + expanded_chain = expand_surgery_chain_with_cycling(additive_surgery_chain, apriel2_config) # Build cumulative plan: conversion | surgery_1 | cycling_1a | ... | restore_1 | surgery_2 | ... current_plan = conversion_plan @@ -359,9 +322,7 @@ def test_all_stochastic_submixers_via_cycling( except Exception as e: pytest.fail(f"{desc}: Forward pass failed - {e}") - def test_composed_plan_equals_sequential_execution( - self, source_config, source_weights, additive_surgery_chain - ): + def test_composed_plan_equals_sequential_execution(self, source_config, source_weights, additive_surgery_chain): """Test that composing plans gives same result as sequential execution. 
This verifies plan composition associativity: @@ -399,13 +360,9 @@ def test_composed_plan_equals_sequential_execution( # Compare weights for key in seq_weights: if key in composed_weights: - assert torch.allclose( - seq_weights[key], composed_weights[key], atol=1e-5 - ), f"Weight mismatch for {key}" + assert torch.allclose(seq_weights[key], composed_weights[key], atol=1e-5), f"Weight mismatch for {key}" - def test_final_model_structure( - self, source_config, source_weights, additive_surgery_chain - ): + def test_final_model_structure(self, source_config, source_weights, additive_surgery_chain): """Verify the final model has the expected structure.""" # Initial conversion current_config = convert_llava_config(source_config) @@ -504,9 +461,7 @@ def base_setup(self, llava_pixtral_checkpoint): """Set up base config and weights after Llava conversion.""" from safetensors.torch import load_file - from fast_llm_external_models.apriel2.conversion.llava import ( - convert_config as convert_llava_config, - ) + from fast_llm_external_models.apriel2.conversion.llava import convert_config as convert_llava_config # Load source config and weights with open(llava_pixtral_checkpoint / "config.json") as f: @@ -534,9 +489,7 @@ def _merge_surgeries(self, surgeries: list[dict]) -> dict: result = _deep_merge(result, s) return result - def _build_incremental_plans( - self, base_config: dict, surgeries: list[dict] - ) -> tuple[list, list[dict]]: + def _build_incremental_plans(self, base_config: dict, surgeries: list[dict]) -> tuple[list, list[dict]]: """Build incremental plans for each surgery step. Returns (plans, configs) where configs[i] is the config after surgery i. @@ -552,9 +505,7 @@ def _build_incremental_plans( config = target_config return plans, configs - def test_incremental_equals_direct_full_chain( - self, base_setup, additive_surgery_chain - ): + def test_incremental_equals_direct_full_chain(self, base_setup, additive_surgery_chain): """Test that composing all incremental plans equals one direct plan. 
compose(P1, P2, ..., Pn) ≡ plan_surgery(base, final) @@ -575,9 +526,7 @@ def test_incremental_equals_direct_full_chain( direct_plan = plan_surgery(base_config, final_config) # Verify same target keys - assert set(composed_plan.mappings.keys()) == set( - direct_plan.mappings.keys() - ), "Plan keys should match" + assert set(composed_plan.mappings.keys()) == set(direct_plan.mappings.keys()), "Plan keys should match" # Execute both and compare weights composed_weights = execute(composed_plan, base_weights, seed=0) @@ -611,9 +560,7 @@ def test_every_prefix_consistency(self, base_setup, additive_surgery_chain): direct = plan_surgery(base_config, configs[k]) # Verify keys match - assert set(composed.mappings.keys()) == set( - direct.mappings.keys() - ), f"Prefix {k}: keys don't match" + assert set(composed.mappings.keys()) == set(direct.mappings.keys()), f"Prefix {k}: keys don't match" # Execute and compare composed_weights = execute(composed, base_weights, seed=0) @@ -781,9 +728,7 @@ def torture_setup(self, llava_pixtral_checkpoint): """Set up for comprehensive torture tests.""" from safetensors.torch import load_file - from fast_llm_external_models.apriel2.conversion.llava import ( - convert_config as convert_llava_config, - ) + from fast_llm_external_models.apriel2.conversion.llava import convert_config as convert_llava_config # Load source with open(llava_pixtral_checkpoint / "config.json") as f: @@ -801,9 +746,7 @@ def torture_setup(self, llava_pixtral_checkpoint): return base_config, base_weights - def test_each_step_produces_valid_config( - self, torture_setup, comprehensive_torture_chain - ): + def test_each_step_produces_valid_config(self, torture_setup, comprehensive_torture_chain): """Test that each surgery step produces a valid config.""" base_config, _ = torture_setup @@ -818,9 +761,7 @@ def test_each_step_produces_valid_config( pytest.fail(f"Step {i+1} produced invalid config: {e}") @requires_cuda - def test_each_step_produces_working_model( - self, torture_setup, comprehensive_torture_chain - ): + def test_each_step_produces_working_model(self, torture_setup, comprehensive_torture_chain): """Test that each surgery step produces a model that can forward pass. This is the ultimate integration test - config composition + plan building @@ -875,9 +816,7 @@ def test_each_step_produces_working_model( current_weights = new_weights @requires_cuda - def test_final_supernet_structure( - self, torture_setup, comprehensive_torture_chain - ): + def test_final_supernet_structure(self, torture_setup, comprehensive_torture_chain): """Verify the final architecture has supernet blocks with all 4 mixer types.""" base_config, base_weights = torture_setup @@ -914,9 +853,7 @@ def test_final_supernet_structure( assert outputs.logits.shape == (1, 8, config.vocab_size) @requires_cuda - def test_plan_config_consistency_comprehensive( - self, torture_setup, comprehensive_torture_chain - ): + def test_plan_config_consistency_comprehensive(self, torture_setup, comprehensive_torture_chain): """Test that incremental plan composition works for the comprehensive chain. 
Note: We cannot compare to a "direct plan" because the comprehensive chain @@ -1083,66 +1020,6 @@ def mamba_config(self): }, } - def test_config_composition_identical_regardless_of_init_mode(self, base_config): - """Config composition produces same structure with init: transfer vs init: random.""" - # Surgery with init: transfer - surgery_transfer = { - "decoder": { - "block": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": { - "attention": {"type": "attention", "init": "transfer"}, - "swa": { - "type": "attention", - "init": "transfer", - "sliding_window": 512, - }, - }, - }, - }, - }, - } - - # Surgery with init: random - surgery_random = { - "decoder": { - "block": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": { - "attention": {"type": "attention", "init": "random"}, - "swa": { - "type": "attention", - "init": "random", - "sliding_window": 512, - }, - }, - }, - }, - }, - } - - # Compose configs - result_transfer = compose_configs(base_config, surgery_transfer) - result_random = compose_configs(base_config, surgery_random) - - # Both should produce identical structure (init is stripped) - assert result_transfer == result_random, ( - "Config composition should produce identical structure regardless of init mode" - ) - - # Verify the structure is correct - mixer = result_transfer["decoder"]["block"]["mixer"] - assert mixer["type"] == "stochastic" - assert "attention" in mixer["mixers"] - assert "swa" in mixer["mixers"] - # init should be stripped - assert "init" not in mixer["mixers"]["attention"] - assert "init" not in mixer["mixers"]["swa"] - def test_plan_surgery_random_succeeds_for_any_type_pair(self, mamba_config): """plan_surgery with init: random should succeed even for mamba -> attention.""" # This surgery changes mamba to attention with random init @@ -1166,7 +1043,7 @@ def test_plan_surgery_random_succeeds_for_any_type_pair(self, mamba_config): plan = plan_surgery(mamba_config, surgery) # Verify the plan has the expected target keys - target_keys = set(str(k) for k in plan.mappings.keys()) + target_keys = {str(k) for k in plan.mappings.keys()} assert any("mixer.q_proj" in k for k in target_keys) def test_plan_surgery_transfer_fails_for_unsupported_type_pair(self, mamba_config): @@ -1219,7 +1096,7 @@ def test_plan_surgery_transfer_succeeds_for_supported_type_pair(self, base_confi plan = plan_surgery(base_config, surgery) # Verify the plan has mamba target keys - target_keys = set(str(k) for k in plan.mappings.keys()) + target_keys = {str(k) for k in plan.mappings.keys()} assert any("mixer.in_proj" in k for k in target_keys) def test_stochastic_init_random_succeeds_for_any_submixer_type(self, mamba_config): @@ -1259,7 +1136,7 @@ def test_stochastic_init_random_succeeds_for_any_submixer_type(self, mamba_confi plan = plan_surgery(mamba_config, surgery) # Verify both sub-mixers have target keys - target_keys = set(str(k) for k in plan.mappings.keys()) + target_keys = {str(k) for k in plan.mappings.keys()} assert any("mixers.attention.q_proj" in k for k in target_keys) assert any("mixers.swa.q_proj" in k for k in target_keys) @@ -1294,7 +1171,7 @@ def test_mixed_init_modes_in_stochastic(self, base_config): plan = plan_surgery(base_config, surgery) # Verify both sub-mixers have target keys - target_keys = set(str(k) for k in plan.mappings.keys()) + target_keys = {str(k) for k in plan.mappings.keys()} assert any("mixers.attention.q_proj" in k for k in target_keys) assert any("mixers.gdn.in_proj_qkvz" in k for k 
in target_keys) @@ -1313,8 +1190,8 @@ class TestMarkovianProperty: """ @pytest.fixture - def attention_config(self): - """Base config with attention.""" + def attention_config_dict(self): + """Base config dict with attention mixer for compose_configs tests.""" return { "model_type": "apriel2", "hidden_size": 256, @@ -1335,43 +1212,7 @@ def attention_config(self): }, } - @pytest.fixture - def stochastic_config(self): - """Config with stochastic mixer.""" - return { - "model_type": "apriel2", - "hidden_size": 256, - "vocab_size": 1000, - "decoder": { - "type": "fixed", - "num_blocks": 2, - "block": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": { - "attention": { - "type": "attention", - "heads": 8, - "head_groups": 4, - "head_size": 32, - }, - "swa": { - "type": "sliding_window", - "heads": 8, - "head_groups": 4, - "head_size": 32, - "window_size": 512, - }, - }, - }, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - } - - def test_different_paths_same_config_same_plan(self, attention_config): + def test_different_paths_same_config_same_plan(self, attention_config_dict): """Two different paths to the same config produce identical plans. Path A: attention -> stochastic{att, swa} @@ -1398,7 +1239,7 @@ def test_different_paths_same_config_same_plan(self, attention_config): }, }, } - config_a = compose_configs(attention_config, surgery_a) + config_a = compose_configs(attention_config_dict, surgery_a) # Path B: First add attention only, then add swa surgery_b1 = { @@ -1414,7 +1255,7 @@ def test_different_paths_same_config_same_plan(self, attention_config): }, }, } - intermediate_config = compose_configs(attention_config, surgery_b1) + intermediate_config = compose_configs(attention_config_dict, surgery_b1) surgery_b2 = { "decoder": { @@ -1465,11 +1306,11 @@ def test_different_paths_same_config_same_plan(self, attention_config): plan_from_b = plan_surgery(config_b, final_surgery) # Compare plan mappings - keys_a = set(str(k) for k in plan_from_a.mappings.keys()) - keys_b = set(str(k) for k in plan_from_b.mappings.keys()) + keys_a = {str(k) for k in plan_from_a.mappings.keys()} + keys_b = {str(k) for k in plan_from_b.mappings.keys()} assert keys_a == keys_b, "Plans from same config via different paths should be identical" - def test_init_in_source_config_does_not_affect_plan(self, attention_config): + def test_init_in_source_config_does_not_affect_plan(self, attention_config_dict): """Manually injecting init into source config doesn't change the plan. This tests that plan_surgery reads init from surgery, not source. 
@@ -1479,8 +1320,8 @@ def test_init_in_source_config_does_not_affect_plan(self, attention_config): import copy # Create two copies of the config - config_with_init = copy.deepcopy(attention_config) - config_without_init = copy.deepcopy(attention_config) + config_with_init = copy.deepcopy(attention_config_dict) + config_without_init = copy.deepcopy(attention_config_dict) # Manually inject init into one (bypassing compose_configs) config_with_init["decoder"]["block"]["mixer"]["init"] = "random" @@ -1504,238 +1345,12 @@ def test_init_in_source_config_does_not_affect_plan(self, attention_config): plan_with = plan_surgery(config_with_init, surgery) plan_without = plan_surgery(config_without_init, surgery) - keys_with = set(str(k) for k in plan_with.mappings.keys()) - keys_without = set(str(k) for k in plan_without.mappings.keys()) + keys_with = {str(k) for k in plan_with.mappings.keys()} + keys_without = {str(k) for k in plan_without.mappings.keys()} # Plans should be identical - source's init field is ignored assert keys_with == keys_without, "Plan should not depend on init in source config" - def test_associativity_of_surgery_composition(self, attention_config): - """Verify associativity: (A ∘ B) ∘ C == A ∘ (B ∘ C) for surgery specs. - - This tests that composing surgeries is associative, which is - equivalent to Markovianity for plan creation. - """ - surgery_a = { - "decoder": { - "block": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": { - "attention": {"type": "attention", "init": "transfer"}, - }, - }, - }, - }, - } - - surgery_b = { - "decoder": { - "block": { - "mixer": { - "mixers": { - "swa": { - "type": "sliding_window", - "init": "transfer", - "window_size": 512, - }, - }, - }, - }, - }, - } - - surgery_c = { - "decoder": { - "block": { - "mixer": { - "mixers": { - "gdn": { - "type": "gdn", - "init": "random", - "value_heads": 8, - "key_heads": 4, - "key_head_dim": 32, - "value_head_dim": 32, - "convolution_layer": {"kernel_size": 4}, - }, - }, - }, - }, - }, - } - - # Left association: ((attention_config ∘ A) ∘ B) ∘ C - left_1 = compose_configs(attention_config, surgery_a) - left_2 = compose_configs(left_1, surgery_b) - left_result = compose_configs(left_2, surgery_c) - - # Right association: (attention_config ∘ A) ∘ (B ∘ C) - # Note: B ∘ C is partial ∘ partial = deep merge of surgery specs - bc_merged = compose_configs(surgery_b, surgery_c) - right_1 = compose_configs(attention_config, surgery_a) - right_result = compose_configs(right_1, bc_merged) - - assert left_result == right_result, "Surgery composition should be associative" - - def test_complete_configs_have_no_init_fields(self, attention_config): - """Verify that compose_configs strips init from complete configs. 
- - This is the key invariant that enables Markovianity: - - Complete configs (states) have no init fields - - Surgery specs (transitions) have init fields - - Plans read init from surgery, not state - """ - surgery_with_init = { - "decoder": { - "block": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": { - "attention": {"type": "attention", "init": "transfer"}, - "swa": {"type": "sliding_window", "init": "random", "window_size": 512}, - }, - }, - }, - }, - } - - result = compose_configs(attention_config, surgery_with_init) - - # Recursively check for init fields - def has_init(obj): - if isinstance(obj, dict): - if "init" in obj: - return True - return any(has_init(v) for v in obj.values()) - if isinstance(obj, list): - return any(has_init(v) for v in obj) - return False - - assert not has_init(result), "Complete configs should have no init fields" - - def test_monoid_action_law_additive_surgeries(self): - """Monoid action law HOLDS for additive surgeries. - - Additive surgeries (no type: declaration) support: - apply(apply(s, t1), t2) == apply(s, t1 ∘ t2) - - This is because additive operations commute nicely: - "add {a}" then "add {b}" == "add {a, b}" - """ - # Start with stochastic (additive surgery target) - s = { - "model_type": "apriel2", - "hidden_size": 256, - "vocab_size": 1000, - "decoder": { - "type": "fixed", - "num_blocks": 2, - "block": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": { - "attention": {"type": "attention", "heads": 8, "head_groups": 4, "head_size": 32}, - }, - }, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - } - - # Additive surgeries (no type: declaration) - t1 = {"decoder": {"block": {"mixer": {"mixers": {"swa": {"type": "sliding_window", "window_size": 512}}}}}} - t2 = {"decoder": {"block": {"mixer": {"mixers": {"mamba": {"type": "mamba", "d_inner": 512}}}}}} - - # Path A: Sequential - s_prime = compose_configs(s, t1) - s_double_prime_A = compose_configs(s_prime, t2) - - # Path B: Composed - t1_t2 = compose_configs(t1, t2) - s_double_prime_B = compose_configs(s, t1_t2) - - assert s_double_prime_A == s_double_prime_B, "Monoid action law should hold for additive surgeries" - - def test_monoid_action_law_replacement_surgeries_fails(self): - """Monoid action law FAILS for replacement surgeries (by design). - - Replacement surgeries (type: stochastic declared) have: - apply(apply(s, t1), t2) != apply(s, t1 ∘ t2) - - This is FUNDAMENTAL, not a bug: - - Sequential: "set to {a}" then "set to {b}" → {b} (second wins) - - Composed: merge({a}, {b}) = {a,b}, then apply → {a,b} - - These are genuinely different semantics. The failure documents - the distinction between declarative composition (merge) and - operational composition (function application). 
- """ - s = { - "model_type": "apriel2", - "hidden_size": 256, - "vocab_size": 1000, - "decoder": { - "type": "fixed", - "num_blocks": 2, - "block": { - "mixer": {"type": "attention", "heads": 8, "head_groups": 4, "head_size": 32}, - "mlp": {"type": "mlp", "intermediate_size": 256}, - "normalization": {"type": "rms_norm", "epsilon": 1e-5}, - }, - }, - } - - # Replacement surgeries (both declare type: stochastic) - t1 = { - "decoder": { - "block": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "attention", - "mixers": {"attention": {"type": "attention"}}, - } - } - } - } - t2 = { - "decoder": { - "block": { - "mixer": { - "type": "stochastic", - "main_mixer_name": "swa", - "mixers": {"swa": {"type": "sliding_window", "window_size": 512}}, - } - } - } - } - - # Path A: Sequential (second replacement wins) - s_prime = compose_configs(s, t1) - s_double_prime_A = compose_configs(s_prime, t2) - - # Path B: Composed (declarations merged) - t1_t2 = compose_configs(t1, t2) - s_double_prime_B = compose_configs(s, t1_t2) - - # They should be DIFFERENT (law fails) - assert s_double_prime_A != s_double_prime_B, ( - "Monoid action law should FAIL for replacement surgeries" - ) - - # Verify the specific difference: - # Sequential: only swa (second replacement wins) - # Composed: both attention and swa (merged declarations) - mixers_A = set(s_double_prime_A["decoder"]["block"]["mixer"]["mixers"].keys()) - mixers_B = set(s_double_prime_B["decoder"]["block"]["mixer"]["mixers"].keys()) - - assert mixers_A == {"swa"}, "Sequential: second replacement wins" - assert mixers_B == {"attention", "swa"}, "Composed: declarations merged" - class TestCyclingSurgeryGeneration: """Tests for the cycling surgery generation functions. @@ -1936,7 +1551,7 @@ def test_expand_surgery_chain_adds_cycling(self): # Verify restore flag assert expanded[0][2] is False # surgery - not restore assert expanded[1][2] is False # cycle - not restore - assert expanded[2][2] is True # restore + assert expanded[2][2] is True # restore def test_expand_surgery_chain_preserves_invariant(self): """Test that cycling leaves the chain state invariant.""" @@ -1980,3 +1595,151 @@ def test_expand_surgery_chain_preserves_invariant(self): # After cycling and restore, we should be back to the same state assert current_config == config_after_original + + +class TestBiasSurgeryChain: + """Torture tests for per-layer bias inheritance through surgery operations. 
+ + Uses apriel2_config_with_bias + bias_surgery_chain to test that: + - Qwen-style per-layer attention bias (QKV enabled, O disabled) survives surgery + - Non-gated MLP per-layer bias (layer_1 enabled, layer_2 disabled) survives surgery + - Bias settings are correctly inherited by new sub-mixers + - Bias is correctly tracked in surgery plans + """ + + @pytest.fixture + def bias_source_config(self, apriel2_config_with_bias): + """Convert Apriel2Config to dict for surgery operations.""" + return apriel2_config_with_bias.to_dict() + + def test_bias_survives_stochastic_wrapper(self, bias_source_config, bias_surgery_chain): + """Test that bias settings survive wrapping in stochastic mixer.""" + # Apply first surgery (wrap in stochastic) + result = compose_configs(bias_source_config, bias_surgery_chain[0]) + + # Check attention sub-mixer inherited bias settings + mixer = result["decoder"]["block"]["mixer"] + assert mixer["type"] == "stochastic" + + attn_mixer = mixer["mixers"]["attention"] + assert attn_mixer["query_layer"]["bias"]["enabled"] is True + assert attn_mixer["key_layer"]["bias"]["enabled"] is True + assert attn_mixer["value_layer"]["bias"]["enabled"] is True + assert attn_mixer["dense_layer"]["bias"]["enabled"] is False + + # Check MLP bias survived + mlp = result["decoder"]["block"]["mlp"] + assert mlp["layer_1"]["bias"]["enabled"] is True + assert mlp["layer_2"]["bias"]["enabled"] is False + + def test_new_submixer_inherits_bias(self, bias_source_config, bias_surgery_chain): + """Test that new sub-mixers inherit bias from source attention.""" + # Apply S1 + S2 (wrap in stochastic, add sliding_window) + config = bias_source_config + for surgery in bias_surgery_chain[:2]: + config = compose_configs(config, surgery) + + # sliding_window should inherit bias from source attention + mixer = config["decoder"]["block"]["mixer"] + sw_mixer = mixer["mixers"]["sliding_window"] + + assert sw_mixer["query_layer"]["bias"]["enabled"] is True + assert sw_mixer["key_layer"]["bias"]["enabled"] is True + assert sw_mixer["value_layer"]["bias"]["enabled"] is True + assert sw_mixer["dense_layer"]["bias"]["enabled"] is False + + def test_full_bias_chain_produces_valid_config(self, bias_source_config, bias_surgery_chain): + """Test that full bias surgery chain produces valid config.""" + config = bias_source_config + for surgery in bias_surgery_chain: + config = compose_configs(config, surgery) + + # Verify final config structure + mixer = config["decoder"]["block"]["mixer"] + assert mixer["type"] == "stochastic" + assert "attention" in mixer["mixers"] + assert "sliding_window" in mixer["mixers"] + assert "full_bias_attn" in mixer["mixers"] + + # attention and sliding_window inherit Qwen-style bias + for name in ["attention", "sliding_window"]: + sub = mixer["mixers"][name] + assert sub["query_layer"]["bias"]["enabled"] is True + assert sub["dense_layer"]["bias"]["enabled"] is False + + # full_bias_attn has add_linear_biases=True but per-layer settings inherited from + # source take precedence, so O bias is still disabled + full_bias = mixer["mixers"]["full_bias_attn"] + assert full_bias.get("add_linear_biases") is True + # Per-layer dense_layer.bias.enabled=False inherited from source takes precedence + assert full_bias["dense_layer"]["bias"]["enabled"] is False + + def test_bias_plan_has_correct_mappings(self, bias_source_config, bias_surgery_chain): + """Test that surgery plan correctly includes/excludes bias weight mappings.""" + # Compose config first to get full target config with inherited bias 
settings + target_config = compose_configs(bias_source_config, bias_surgery_chain[0]) + plan = plan_surgery(bias_source_config, target_config) + mapping_strs = [str(k) for k in plan.mappings.keys()] + + # Should have q_proj.bias (enabled) + q_bias = [m for m in mapping_strs if "q_proj.bias" in m] + assert len(q_bias) > 0, "Should have q_proj.bias mappings" + + # Should NOT have o_proj.bias (disabled) + o_bias = [m for m in mapping_strs if "o_proj.bias" in m] + assert len(o_bias) == 0, "Should not have o_proj.bias mappings" + + # Should have up_proj.bias (layer_1 enabled) + up_bias = [m for m in mapping_strs if "up_proj.bias" in m] + assert len(up_bias) > 0, "Should have up_proj.bias mappings" + + # Should NOT have down_proj.bias (layer_2 disabled) + down_bias = [m for m in mapping_strs if "down_proj.bias" in m] + assert len(down_bias) == 0, "Should not have down_proj.bias mappings" + + def test_bias_chain_produces_working_model(self, bias_source_config, bias_surgery_chain): + """Test that bias surgery chain produces a working model.""" + from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM + + # Apply full chain + config = bias_source_config + for surgery in bias_surgery_chain: + config = compose_configs(config, surgery) + + # Create model + apriel_config = Apriel2Config(**config) + model = Apriel2ForCausalLM(apriel_config) + model.eval() + + # Verify model structure has correct biases + block = model.model.decoder.blocks[0] + + # attention sub-mixer should have QKV bias, no O bias + attn = block.mixer.mixers["attention"] + assert attn.q_proj.bias is not None + assert attn.k_proj.bias is not None + assert attn.v_proj.bias is not None + assert attn.o_proj.bias is None + + # sliding_window should also inherit bias settings + sw = block.mixer.mixers["sliding_window"] + assert sw.q_proj.bias is not None + assert sw.o_proj.bias is None + + # full_bias_attn inherits per-layer bias from source (even with add_linear_biases=True, + # per-layer settings take precedence in same-type inheritance) + full_bias = block.mixer.mixers["full_bias_attn"] + assert full_bias.q_proj.bias is not None + # O bias is disabled because inherited per-layer dense_layer.bias.enabled=False + # takes precedence over add_linear_biases=True + assert full_bias.o_proj.bias is None + + # MLP should have layer_1 bias, no layer_2 bias + assert block.mlp.up_proj.bias is not None + assert block.mlp.down_proj.bias is None + + # Forward pass should work + input_ids = torch.randint(0, config["vocab_size"], (1, 10)) + with torch.no_grad(): + outputs = model(input_ids, use_cache=False) + assert outputs.logits.shape == (1, 10, config["vocab_size"]) diff --git a/fast_llm_external_models/tests/test_apriel2/test_convert_from_llava.py b/fast_llm_external_models/tests/test_apriel2/test_convert_from_llava.py index a437f920d..f96f5ac40 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_convert_from_llava.py +++ b/fast_llm_external_models/tests/test_apriel2/test_convert_from_llava.py @@ -14,23 +14,15 @@ """ import json -from pathlib import Path -import pytest import torch from safetensors import safe_open -from safetensors.torch import save_file from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config -from fast_llm_external_models.apriel2.conversion import ( - convert_llava_config as convert_config, - execute, - plan_llava_to_apriel2, - plan_surgery, -) +from fast_llm_external_models.apriel2.conversion import convert_llava_config as convert_config +from 
fast_llm_external_models.apriel2.conversion import execute, plan_llava_to_apriel2, plan_surgery from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForConditionalGeneration - # ============================================================================= # Config Conversion Tests # ============================================================================= @@ -330,9 +322,9 @@ def test_plan_keys_match_model_state_dict(self, llava_pixtral_checkpoint): extra_in_plan = plan_keys - model_keys # Filter out expected missing keys (caches, positions, etc.) - missing_in_plan = {k for k in missing_in_plan if not any( - skip in k.lower() for skip in ["cache", "position", "mask"] - )} + missing_in_plan = { + k for k in missing_in_plan if not any(skip in k.lower() for skip in ["cache", "position", "mask"]) + } assert not missing_in_plan, f"Model keys not in plan: {sorted(missing_in_plan)[:10]}" assert not extra_in_plan, f"Plan keys not in model: {sorted(extra_in_plan)[:10]}" diff --git a/fast_llm_external_models/tests/test_apriel2/test_equivalence.py b/fast_llm_external_models/tests/test_apriel2/test_equivalence.py index c59ed2000..9b3eb4efe 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_equivalence.py +++ b/fast_llm_external_models/tests/test_apriel2/test_equivalence.py @@ -23,9 +23,6 @@ import torch from transformers import LlavaForConditionalGeneration -from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForConditionalGeneration - - # ============================================================================= # Input Configuration # ============================================================================= @@ -487,8 +484,10 @@ def test_batch_processing_behavior(self, model_pair): batch_tgt = target.get_image_features(pixel_values).view(-1, batch_src.shape[-1]) # Sequential processing - singles_src = [get_pixtral_vision_features(source, pixel_values[i:i+1]) for i in range(3)] - singles_tgt = [target.get_image_features(pixel_values[i:i+1]).view(-1, batch_src.shape[-1]) for i in range(3)] + singles_src = [get_pixtral_vision_features(source, pixel_values[i : i + 1]) for i in range(3)] + singles_tgt = [ + target.get_image_features(pixel_values[i : i + 1]).view(-1, batch_src.shape[-1]) for i in range(3) + ] single_concat_src = torch.cat(singles_src, dim=0) single_concat_tgt = torch.cat(singles_tgt, dim=0) @@ -500,9 +499,9 @@ def test_batch_processing_behavior(self, model_pair): print(f"Apriel2 batch vs sequential: {tgt_diff:.6f}") # Both should have the same behavior (within FP tolerance) - assert abs(src_diff - tgt_diff) < 1e-6, ( - f"Batch processing behavior differs: src={src_diff:.6f}, tgt={tgt_diff:.6f}" - ) + assert ( + abs(src_diff - tgt_diff) < 1e-6 + ), f"Batch processing behavior differs: src={src_diff:.6f}, tgt={tgt_diff:.6f}" if __name__ == "__main__": diff --git a/fast_llm_external_models/tests/test_apriel2/test_expr_plan.py b/fast_llm_external_models/tests/test_apriel2/test_expr_plan.py index c487ab3a3..2dccac5ad 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_expr_plan.py +++ b/fast_llm_external_models/tests/test_apriel2/test_expr_plan.py @@ -1,15 +1,13 @@ """Tests for the expression-based plan system.""" import json + import pytest import torch -from fast_llm_external_models.tests.test_apriel2.conftest import requires_cuda - from fast_llm_external_models.apriel2.conversion import ( Concat, EvalKwargs, - Expr, ExprAdapter, ExprPlan, Init, @@ -18,10 +16,9 @@ Slice, StreamingExecutor, W, - compose, execute, - fuse, 
full_slice, + fuse, make_slice, plan_dil_attention_to_gdn, plan_kil_attention_to_kda, @@ -31,6 +28,7 @@ slice_spec, substitute, ) +from fast_llm_external_models.tests.test_apriel2.conftest import requires_cuda def make_eval_kwargs( @@ -219,10 +217,13 @@ def test_substitute_init_unchanged(self): def test_substitute_complex(self): """Substitute handles complex nested expressions.""" # Concat of Slice(Ref) and Init - expr = Concat(exprs=( - Slice(expr=Ref(key=W("a")), slices=((0, 5, None),)), - Init(shape=(5,), init_type="zeros"), - ), dim=0) + expr = Concat( + exprs=( + Slice(expr=Ref(key=W("a")), slices=((0, 5, None),)), + Init(shape=(5,), init_type="zeros"), + ), + dim=0, + ) bindings = {W("a"): Ref(key=W("source"))} result = substitute(expr, bindings) @@ -238,7 +239,13 @@ class TestFuse: def test_fuse_flatten_concat(self): """Fuse flattens nested Concat with same dim.""" inner = Concat(exprs=(Ref(key=W("a")), Ref(key=W("b"))), dim=0) - outer = Concat(exprs=(inner, Ref(key=W("c")),), dim=0) + outer = Concat( + exprs=( + inner, + Ref(key=W("c")), + ), + dim=0, + ) result = fuse(outer) assert isinstance(result, Concat) @@ -250,7 +257,13 @@ def test_fuse_flatten_concat(self): def test_fuse_no_flatten_different_dim(self): """Fuse doesn't flatten Concat with different dim.""" inner = Concat(exprs=(Ref(key=W("a")), Ref(key=W("b"))), dim=1) - outer = Concat(exprs=(inner, Ref(key=W("c")),), dim=0) + outer = Concat( + exprs=( + inner, + Ref(key=W("c")), + ), + dim=0, + ) result = fuse(outer) assert isinstance(result, Concat) @@ -340,28 +353,34 @@ class TestExprPlan: def test_plan_define_and_access(self): """Plan stores and retrieves expressions.""" - plan = ExprPlan(mappings={ - W("target"): Ref(key=W("source")), - }) + plan = ExprPlan( + mappings={ + W("target"): Ref(key=W("source")), + } + ) assert W("target") in plan assert isinstance(plan[W("target")], Ref) def test_plan_source_keys(self): """Plan identifies all source references.""" - plan = ExprPlan(mappings={ - W("a"): Ref(key=W("x")), - W("b"): Concat(exprs=(Ref(key=W("y")), Ref(key=W("z"))), dim=0), - W("c"): Init(shape=(10,), init_type="zeros"), - }) + plan = ExprPlan( + mappings={ + W("a"): Ref(key=W("x")), + W("b"): Concat(exprs=(Ref(key=W("y")), Ref(key=W("z"))), dim=0), + W("c"): Init(shape=(10,), init_type="zeros"), + } + ) assert plan.source_keys() == {W("x"), W("y"), W("z")} def test_plan_target_keys(self): """Plan identifies all target keys.""" - plan = ExprPlan(mappings={ - W("a"): Ref(key=W("x")), - W("b"): Ref(key=W("y")), - }) + plan = ExprPlan( + mappings={ + W("a"): Ref(key=W("x")), + W("b"): Ref(key=W("y")), + } + ) assert plan.target_keys() == {W("a"), W("b")} @@ -386,9 +405,17 @@ def test_plan_summary(self): def test_plan_fuse(self): """Plan fuse applies optimizations.""" inner = Concat(exprs=(Ref(key=W("a")), Ref(key=W("b"))), dim=0) - plan = ExprPlan(mappings={ - W("out"): Concat(exprs=(inner, Ref(key=W("c")),), dim=0), - }) + plan = ExprPlan( + mappings={ + W("out"): Concat( + exprs=( + inner, + Ref(key=W("c")), + ), + dim=0, + ), + } + ) fused = plan.fuse() assert isinstance(fused[W("out")], Concat) @@ -532,9 +559,11 @@ class TestStreamingExecution: def test_execute_simple(self): """Execute simple plan.""" - plan = ExprPlan(mappings={ - W("out"): Ref(key=W("in")), - }) + plan = ExprPlan( + mappings={ + W("out"): Ref(key=W("in")), + } + ) sources = {W("in"): torch.tensor([1.0, 2.0, 3.0])} result = execute(plan, sources, seed=42) @@ -544,9 +573,11 @@ def test_execute_simple(self): def test_execute_concat(self): 
"""Execute plan with Concat.""" - plan = ExprPlan(mappings={ - W("combined"): Concat(exprs=(Ref(key=W("a")), Ref(key=W("b"))), dim=0), - }) + plan = ExprPlan( + mappings={ + W("combined"): Concat(exprs=(Ref(key=W("a")), Ref(key=W("b"))), dim=0), + } + ) sources = { W("a"): torch.ones(2, 3), @@ -559,14 +590,19 @@ def test_execute_concat(self): def test_execute_mil_like(self): """Execute MIL-like Concat of Slices and Init.""" # Simulated MIL: in_proj = [z, x, B, C] - plan = ExprPlan(mappings={ - W("in_proj"): Concat(exprs=( - Init(shape=(4, 8), init_type="zeros"), # z - Slice(expr=Ref(key=W("v")), slices=((0, 2, None), (None, None, None))), # x - Slice(expr=Ref(key=W("k")), slices=((0, 2, None), (None, None, None))), # B - Slice(expr=Ref(key=W("q")), slices=((0, 4, None), (None, None, None))), # C - ), dim=0), - }) + plan = ExprPlan( + mappings={ + W("in_proj"): Concat( + exprs=( + Init(shape=(4, 8), init_type="zeros"), # z + Slice(expr=Ref(key=W("v")), slices=((0, 2, None), (None, None, None))), # x + Slice(expr=Ref(key=W("k")), slices=((0, 2, None), (None, None, None))), # B + Slice(expr=Ref(key=W("q")), slices=((0, 4, None), (None, None, None))), # C + ), + dim=0, + ), + } + ) sources = { W("q"): torch.ones(4, 8), @@ -583,11 +619,13 @@ def test_execute_mil_like(self): def test_streaming_execution(self): """Streaming executor processes all targets.""" - plan = ExprPlan(mappings={ - W("out1"): Ref(key=W("shared")), - W("out2"): Ref(key=W("shared")), - W("out3"): Ref(key=W("unique")), - }) + plan = ExprPlan( + mappings={ + W("out1"): Ref(key=W("shared")), + W("out2"): Ref(key=W("shared")), + W("out3"): Ref(key=W("unique")), + } + ) load_calls = [] @@ -858,25 +896,23 @@ def test_plan_dil_execution(self): key_dim = 64 value_dim = 64 - head_k_dim = 16 - head_v_dim = 16 conv_dim = 2 * key_dim + value_dim # 192 # Create attention weights with per-head distinctive values # Q: each head gets value (head_idx + 1) q_weight = torch.zeros(64, 64) for h in range(4): - q_weight[h*16:(h+1)*16, :] = float(h + 1) + q_weight[h * 16 : (h + 1) * 16, :] = float(h + 1) # K: each head gets value (head_idx + 1) * 10 k_weight = torch.zeros(64, 64) for h in range(4): - k_weight[h*16:(h+1)*16, :] = float((h + 1) * 10) + k_weight[h * 16 : (h + 1) * 16, :] = float((h + 1) * 10) # V: each head gets value (head_idx + 1) * 100 v_weight = torch.zeros(64, 64) for h in range(4): - v_weight[h*16:(h+1)*16, :] = float((h + 1) * 100) + v_weight[h * 16 : (h + 1) * 16, :] = float((h + 1) * 100) sources = { W("attn.q_proj.weight"): q_weight, @@ -894,30 +930,23 @@ def test_plan_dil_execution(self): # Q_all (rows 0-63): heads 0,1,2,3 concatenated for h in range(4): - assert torch.allclose( - in_proj_qkvz[h*16:(h+1)*16], - torch.full((16, 64), float(h + 1)) - ) + assert torch.allclose(in_proj_qkvz[h * 16 : (h + 1) * 16], torch.full((16, 64), float(h + 1))) # K_all (rows 64-127): heads 0,1,2,3 concatenated for h in range(4): assert torch.allclose( - in_proj_qkvz[key_dim + h*16:key_dim + (h+1)*16], - torch.full((16, 64), float((h + 1) * 10)) + in_proj_qkvz[key_dim + h * 16 : key_dim + (h + 1) * 16], torch.full((16, 64), float((h + 1) * 10)) ) # V_all (rows 128-191): heads 0,1,2,3 concatenated for h in range(4): assert torch.allclose( - in_proj_qkvz[2*key_dim + h*16:2*key_dim + (h+1)*16], - torch.full((16, 64), float((h + 1) * 100)) + in_proj_qkvz[2 * key_dim + h * 16 : 2 * key_dim + (h + 1) * 16], + torch.full((16, 64), float((h + 1) * 100)), ) # Z_all (rows 192-255): zeros - assert torch.allclose( - in_proj_qkvz[2*key_dim + 
value_dim:], - torch.zeros(value_dim, 64) - ) + assert torch.allclose(in_proj_qkvz[2 * key_dim + value_dim :], torch.zeros(value_dim, 64)) # in_proj_ba should be zeros in_proj_ba = result[W("in_proj_ba.weight")] @@ -971,17 +1000,17 @@ def test_plan_dil_execution_gqa(self): # Q: 4 heads, each with value (head_idx + 1) q_weight = torch.zeros(64, 64) for h in range(4): - q_weight[h*16:(h+1)*16, :] = float(h + 1) + q_weight[h * 16 : (h + 1) * 16, :] = float(h + 1) # K: 2 kv_heads, each with value (head_idx + 1) * 10 k_weight = torch.zeros(32, 64) for h in range(2): - k_weight[h*16:(h+1)*16, :] = float((h + 1) * 10) + k_weight[h * 16 : (h + 1) * 16, :] = float((h + 1) * 10) # V: 2 kv_heads, each with value (head_idx + 1) * 100 v_weight = torch.zeros(32, 64) for h in range(2): - v_weight[h*16:(h+1)*16, :] = float((h + 1) * 100) + v_weight[h * 16 : (h + 1) * 16, :] = float((h + 1) * 100) sources = { W("attn.q_proj.weight"): q_weight, @@ -1007,22 +1036,22 @@ def test_plan_dil_execution_gqa(self): # K_all (rows 32-63): k_heads 0,1 (maps to source K heads 0,1 via modulo) # k_head 0 → source K head 0 (value 10) - assert torch.allclose(in_proj_qkvz[key_dim:key_dim+16], torch.full((16, 64), 10.0)) + assert torch.allclose(in_proj_qkvz[key_dim : key_dim + 16], torch.full((16, 64), 10.0)) # k_head 1 → source K head 1 (value 20) - assert torch.allclose(in_proj_qkvz[key_dim+16:key_dim+32], torch.full((16, 64), 20.0)) + assert torch.allclose(in_proj_qkvz[key_dim + 16 : key_dim + 32], torch.full((16, 64), 20.0)) # V_all (rows 64-127): 4 v_heads, tiled from 2 source KV heads via modulo # v_head 0 → src_v_head 0 (value 100) - assert torch.allclose(in_proj_qkvz[2*key_dim:2*key_dim+16], torch.full((16, 64), 100.0)) + assert torch.allclose(in_proj_qkvz[2 * key_dim : 2 * key_dim + 16], torch.full((16, 64), 100.0)) # v_head 1 → src_v_head 1 (value 200) - assert torch.allclose(in_proj_qkvz[2*key_dim+16:2*key_dim+32], torch.full((16, 64), 200.0)) + assert torch.allclose(in_proj_qkvz[2 * key_dim + 16 : 2 * key_dim + 32], torch.full((16, 64), 200.0)) # v_head 2 → src_v_head 0 (value 100, tiled) - assert torch.allclose(in_proj_qkvz[2*key_dim+32:2*key_dim+48], torch.full((16, 64), 100.0)) + assert torch.allclose(in_proj_qkvz[2 * key_dim + 32 : 2 * key_dim + 48], torch.full((16, 64), 100.0)) # v_head 3 → src_v_head 1 (value 200, tiled) - assert torch.allclose(in_proj_qkvz[2*key_dim+48:2*key_dim+64], torch.full((16, 64), 200.0)) + assert torch.allclose(in_proj_qkvz[2 * key_dim + 48 : 2 * key_dim + 64], torch.full((16, 64), 200.0)) # Z_all (rows 128-191): zeros - assert torch.allclose(in_proj_qkvz[2*key_dim+value_dim:], torch.zeros(value_dim, 64)) + assert torch.allclose(in_proj_qkvz[2 * key_dim + value_dim :], torch.zeros(value_dim, 64)) def test_plan_kil_attention_to_kda(self): """AIK plan produces correct structure for attention → KDA conversion.""" @@ -1188,6 +1217,7 @@ def test_compose_llava_to_mamba(self, llava_pixtral_config, apriel2_config_stoch # Build surgery plan (need intermediate config) from fast_llm_external_models.apriel2.conversion.llava import convert_config + intermediate_config = convert_config(llava_pixtral_config) target_config = apriel2_config_stochastic.to_dict() surgery_plan = plan_surgery(intermediate_config, target_config) @@ -1210,6 +1240,7 @@ def test_execute_composed_pipeline(self, llava_pixtral_checkpoint): """ import json from pathlib import Path + from safetensors.torch import load_file # Load config @@ -1448,10 +1479,9 @@ def test_comprehensive_conversion_all_mixer_types(self, 
llava_pixtral_checkpoint the conversion produced correct keys and shapes. """ import json - from pathlib import Path from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config - from fast_llm_external_models.apriel2.convert import build_plan, convert + from fast_llm_external_models.apriel2.convert import convert from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForConditionalGeneration # Load LLaVA config @@ -1477,11 +1507,11 @@ def test_comprehensive_conversion_all_mixer_types(self, llava_pixtral_checkpoint "type": "pattern", "num_blocks": 5, "pattern": [ - "attn", # 0: attention → attention (passthrough) - "mamba", # 1: attention → mamba (MIL) - "gdn", # 2: attention → gated_delta_net (DIL) - "stoch_am", # 3: attention → stochastic(attention + mamba) - "stoch_sg", # 4: attention → stochastic(swa + gdn) + "attn", # 0: attention → attention (passthrough) + "mamba", # 1: attention → mamba (MIL) + "gdn", # 2: attention → gated_delta_net (DIL) + "stoch_am", # 3: attention → stochastic(attention + mamba) + "stoch_sg", # 4: attention → stochastic(swa + gdn) ], "blocks": { # Pure attention (passthrough from source) @@ -1609,7 +1639,8 @@ def test_comprehensive_conversion_all_mixer_types(self, llava_pixtral_checkpoint "type": "attention", "heads": llava_config["vision_config"]["num_attention_heads"], "head_groups": llava_config["vision_config"]["num_attention_heads"], - "head_size": llava_config["vision_config"]["hidden_size"] // llava_config["vision_config"]["num_attention_heads"], + "head_size": llava_config["vision_config"]["hidden_size"] + // llava_config["vision_config"]["num_attention_heads"], "add_linear_biases": False, "causal": False, "rotary": { @@ -1688,7 +1719,6 @@ def test_conversion_plan_targets_match_model_state_dict(self, llava_pixtral_conf This test validates the plan WITHOUT executing it, by comparing plan target keys against what the model expects. """ - import json from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config from fast_llm_external_models.apriel2.convert import build_plan @@ -1703,7 +1733,7 @@ def test_conversion_plan_targets_match_model_state_dict(self, llava_pixtral_conf expected_keys = set(model.state_dict().keys()) # Get plan target keys - plan_target_keys = set(str(k) for k in plan.target_keys()) + plan_target_keys = {str(k) for k in plan.target_keys()} # Compare missing_from_plan = expected_keys - plan_target_keys @@ -1711,3 +1741,214 @@ def test_conversion_plan_targets_match_model_state_dict(self, llava_pixtral_conf assert not missing_from_plan, f"Plan missing keys that model expects: {sorted(missing_from_plan)}" assert not extra_in_plan, f"Plan has extra keys model doesn't expect: {sorted(extra_in_plan)}" + + +class TestBiasPlanGeneration: + """Test that surgery plans correctly handle per-layer bias configurations. + + These tests verify that plan_surgery correctly includes/excludes bias + weight mappings based on the per-layer bias settings: + - query_layer.bias.enabled, key_layer.bias.enabled, etc. 
for attention + - layer_1.bias.enabled, layer_2.bias.enabled for MLP + """ + + @pytest.fixture + def source_config_with_bias(self): + """Source config with Qwen-style bias (QKV enabled, O disabled).""" + return { + "model_type": "apriel2", + "hidden_size": 256, + "vocab_size": 1000, + "decoder": { + "type": "fixed", + "num_blocks": 2, + "block": { + "mixer": { + "type": "attention", + "heads": 8, + "head_groups": 4, + "head_size": 32, + "rotary": {"type": "mistral_1d", "theta": 10000.0}, + # Qwen-style: QKV bias enabled, O bias disabled + "query_layer": {"bias": {"enabled": True}}, + "key_layer": {"bias": {"enabled": True}}, + "value_layer": {"bias": {"enabled": True}}, + "dense_layer": {"bias": {"enabled": False}}, + }, + "mlp": { + "type": "mlp", + "intermediate_size": 512, + "gated": False, + # Per-layer MLP bias: layer_1 enabled, layer_2 disabled + "layer_1": {"bias": {"enabled": True}}, + "layer_2": {"bias": {"enabled": False}}, + }, + "normalization": {"type": "rms_norm", "epsilon": 1e-5}, + }, + }, + } + + def test_plan_includes_enabled_attention_biases(self, source_config_with_bias): + """Surgery plan includes bias mappings for enabled attention biases.""" + from fast_llm_external_models.apriel2.conversion.config import compose_configs + from fast_llm_external_models.apriel2.conversion.converters import plan_surgery + + target_config = compose_configs( + source_config_with_bias, + { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + }, + }, + "mlp": {"init": "transfer"}, + }, + }, + }, + ) + + plan = plan_surgery(source_config_with_bias, target_config) + mapping_strs = [str(k) for k in plan.mappings.keys()] + + # Should have q_proj.bias, k_proj.bias, v_proj.bias mappings + q_bias = [m for m in mapping_strs if "q_proj.bias" in m] + k_bias = [m for m in mapping_strs if "k_proj.bias" in m] + v_bias = [m for m in mapping_strs if "v_proj.bias" in m] + + assert len(q_bias) > 0, "Should have q_proj.bias mappings" + assert len(k_bias) > 0, "Should have k_proj.bias mappings" + assert len(v_bias) > 0, "Should have v_proj.bias mappings" + + def test_plan_excludes_disabled_attention_biases(self, source_config_with_bias): + """Surgery plan excludes bias mappings for disabled attention biases.""" + from fast_llm_external_models.apriel2.conversion.config import compose_configs + from fast_llm_external_models.apriel2.conversion.converters import plan_surgery + + target_config = compose_configs( + source_config_with_bias, + { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + }, + }, + "mlp": {"init": "transfer"}, + }, + }, + }, + ) + + plan = plan_surgery(source_config_with_bias, target_config) + mapping_strs = [str(k) for k in plan.mappings.keys()] + + # Should NOT have o_proj.bias mappings (disabled) + o_bias = [m for m in mapping_strs if "o_proj.bias" in m] + assert len(o_bias) == 0, f"Should not have o_proj.bias mappings, found: {o_bias}" + + def test_plan_includes_enabled_mlp_biases(self, source_config_with_bias): + """Surgery plan includes bias mappings for enabled MLP biases.""" + from fast_llm_external_models.apriel2.conversion.config import compose_configs + from fast_llm_external_models.apriel2.conversion.converters import plan_surgery + + target_config = compose_configs( + source_config_with_bias, + { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": 
"attention", + "mixers": { + "attention": {"init": "transfer"}, + }, + }, + "mlp": {"init": "transfer"}, + }, + }, + }, + ) + + plan = plan_surgery(source_config_with_bias, target_config) + mapping_strs = [str(k) for k in plan.mappings.keys()] + + # Should have up_proj.bias (layer_1) mappings + up_bias = [m for m in mapping_strs if "up_proj.bias" in m] + assert len(up_bias) > 0, "Should have up_proj.bias mappings" + + def test_plan_excludes_disabled_mlp_biases(self, source_config_with_bias): + """Surgery plan excludes bias mappings for disabled MLP biases.""" + from fast_llm_external_models.apriel2.conversion.config import compose_configs + from fast_llm_external_models.apriel2.conversion.converters import plan_surgery + + target_config = compose_configs( + source_config_with_bias, + { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + }, + }, + "mlp": {"init": "transfer"}, + }, + }, + }, + ) + + plan = plan_surgery(source_config_with_bias, target_config) + mapping_strs = [str(k) for k in plan.mappings.keys()] + + # Should NOT have down_proj.bias (layer_2) mappings + down_bias = [m for m in mapping_strs if "down_proj.bias" in m] + assert len(down_bias) == 0, f"Should not have down_proj.bias mappings, found: {down_bias}" + + def test_plan_random_init_creates_init_expressions_for_bias(self, source_config_with_bias): + """Random init creates Init expressions for bias weights.""" + from fast_llm_external_models.apriel2.conversion.converters import plan_surgery + + # Surgery spec - pass directly to plan_surgery (NOT composed, to preserve init) + surgery = { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + "new_attention": { + "type": "attention", + "init": "random", # This triggers random init + "heads": 8, + "head_groups": 4, + "head_size": 32, + "rotary": {"type": "mistral_1d", "theta": 10000.0}, + "add_linear_biases": True, # All biases enabled + }, + }, + }, + }, + }, + } + + # Pass surgery spec directly - init fields are preserved + plan = plan_surgery(source_config_with_bias, surgery) + + # Check that new_attention biases use Init expressions + new_mixer_bias_keys = [k for k in plan.mappings.keys() if "new_attention" in str(k) and "bias" in str(k)] + + assert len(new_mixer_bias_keys) > 0, "Should have bias mappings for new_attention" + + for key in new_mixer_bias_keys: + expr = plan.mappings[key] + assert isinstance(expr, Init), f"{key} should be Init, got {type(expr)}" diff --git a/fast_llm_external_models/tests/test_apriel2/test_integration.py b/fast_llm_external_models/tests/test_apriel2/test_integration.py new file mode 100644 index 000000000..e84fa06ef --- /dev/null +++ b/fast_llm_external_models/tests/test_apriel2/test_integration.py @@ -0,0 +1,330 @@ +"""Integration tests for Qwen2 -> Apriel2 -> Fast-LLM conversion pipeline. + +Tests verify the full conversion chain: +1. Qwen2 -> Apriel2 (external module conversion) +2. Apriel2 + Surgery -> Supernet (stochastic mixer creation) +3. 
Supernet -> Fast-LLM -> Supernet (roundtrip through training format) + +Test Strategy: +- Use real HuggingFace model (Qwen2.5-0.5B) for meaningful validation +- Separate config preservation tests from numerical equivalence tests +- Parameterize both conversion stages AND input variations +- Single test implementation applied across all stages +""" + +import json +import tempfile +from pathlib import Path + +import pytest +import torch + +from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config +from fast_llm_external_models.apriel2.conversion import compose, compose_configs, execute, plan_surgery +from fast_llm_external_models.apriel2.conversion.expr import W +from fast_llm_external_models.apriel2.conversion.qwen2.config import convert_config as convert_qwen2_config +from fast_llm_external_models.apriel2.conversion.qwen2.plan import plan_qwen2_to_apriel2 +from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM + +from .conftest import requires_fastllm + +# ============================================================================= +# Test Input Variations +# ============================================================================= + +TEST_INPUTS = pytest.mark.parametrize( + "prompts,max_new_tokens", + [ + pytest.param(["Hello world"], 10, id="single_short"), + pytest.param(["Hi", "The quick brown fox jumps over the lazy dog"], 20, id="batch_varied"), + pytest.param(["Once upon a time"], 50, id="long_generation"), + ], +) + + +# ============================================================================= +# Conversion Fixtures +# ============================================================================= + + +@pytest.fixture(scope="module") +def qwen2_source(): + """Load Qwen2.5-0.5B as the source/reference model.""" + from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer + + model_name = "Qwen/Qwen2.5-0.5B" + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32, trust_remote_code=True) + config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + model.eval() + + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "left" + + return { + "model": model, + "tokenizer": tokenizer, + "config_dict": config.to_dict(), + "state_dict": model.state_dict(), + } + + +@pytest.fixture(scope="module") +def apriel2_converted(qwen2_source): + """Stage 1: Qwen2 -> Apriel2.""" + config_dict = convert_qwen2_config(qwen2_source["config_dict"]) + plan = plan_qwen2_to_apriel2(qwen2_source["config_dict"]) + weights = execute(plan, {W(k): v for k, v in qwen2_source["state_dict"].items()}, seed=42) + + config = Apriel2Config(**config_dict) + model = Apriel2ForCausalLM(config) + model.load_state_dict({str(k): v for k, v in weights.items()}, strict=False) + model.eval() + + return {"model": model, "config_dict": config_dict, "plan": plan, "name": "Apriel2"} + + +@pytest.fixture(scope="module") +def supernet_converted(qwen2_source, apriel2_converted): + """Stage 2: Apriel2 + Surgery -> Supernet.""" + surgery_spec = { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"type": "attention", "init": "transfer"}, + "sliding_window": { + "type": "attention", + "init": "transfer", + "window_size": 4096, + }, + }, + }, + }, + }, + } + + apriel_config = apriel2_converted["config_dict"] + supernet_config 
= compose_configs(apriel_config, surgery_spec) + + full_plan = compose( + apriel2_converted["plan"], + plan_surgery(apriel_config, supernet_config), + ) + + weights = execute(full_plan, {W(k): v for k, v in qwen2_source["state_dict"].items()}, seed=42) + + config = Apriel2Config(**supernet_config) + model = Apriel2ForCausalLM(config) + model.load_state_dict({str(k): v for k, v in weights.items()}, strict=False) + model.eval() + + return {"model": model, "config_dict": supernet_config, "name": "Supernet"} + + +@pytest.fixture(scope="module") +def roundtrip_converted(supernet_converted, qwen2_source): + """Stage 3: Supernet -> Fast-LLM -> Supernet.""" + if not torch.cuda.is_available(): + pytest.skip("Roundtrip conversion requires CUDA (integration tests need realistic hardware)") + + from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, CheckpointSaveConfig, FastLLMCheckpointFormat + from fast_llm.engine.checkpoint.convert import ConvertConfig + from fast_llm.models.gpt.config import GPTModelConfig + from fast_llm.models.gpt.conversion.config import Apriel2TextCheckpointFormat + + with tempfile.TemporaryDirectory() as tmpdir: + tmpdir = Path(tmpdir) + supernet_path = tmpdir / "supernet" + fastllm_path = tmpdir / "fastllm" + roundtrip_path = tmpdir / "roundtrip" + + supernet_converted["model"].save_pretrained(supernet_path) + qwen2_source["tokenizer"].save_pretrained(supernet_path) + + ConvertConfig( + model=GPTModelConfig, + input=CheckpointLoadConfig(path=supernet_path, format=Apriel2TextCheckpointFormat), + output=CheckpointSaveConfig(path=fastllm_path, format=FastLLMCheckpointFormat), + ).run() + + ConvertConfig( + model=GPTModelConfig, + input=CheckpointLoadConfig(path=fastllm_path, format=FastLLMCheckpointFormat), + output=CheckpointSaveConfig(path=roundtrip_path, format=Apriel2TextCheckpointFormat), + ).run() + + model = Apriel2ForCausalLM.from_pretrained(roundtrip_path) + model.eval() + + with open(roundtrip_path / "config.json") as f: + config_dict = json.load(f) + + yield {"model": model, "config_dict": config_dict, "name": "Roundtrip"} + + +# ============================================================================= +# Parameterized Fixture: All Conversion Stages +# ============================================================================= + + +@pytest.fixture(params=["apriel2", "supernet", "roundtrip"]) +def converted_model(request, apriel2_converted, supernet_converted): + """Parameterized fixture providing each conversion stage for testing. + + This allows a single test to run against all stages automatically. 
+ """ + if request.param == "roundtrip": + pytest.importorskip("fast_llm") + if not torch.cuda.is_available(): + pytest.skip("Roundtrip tests require CUDA (integration tests need realistic hardware)") + # Lazy-load to avoid fixture evaluation when CUDA unavailable + roundtrip_converted = request.getfixturevalue("roundtrip_converted") + return roundtrip_converted + + return { + "apriel2": apriel2_converted, + "supernet": supernet_converted, + }[request.param] + + +# ============================================================================= +# Config Preservation Tests +# ============================================================================= + + +@pytest.mark.slow +class TestConfigPreservation: + """Verify configs are correctly preserved through the conversion chain.""" + + def test_apriel2_structure(self, qwen2_source, apriel2_converted): + """Qwen2 -> Apriel2 preserves model dimensions.""" + qwen = qwen2_source["config_dict"] + apriel = apriel2_converted["config_dict"] + + assert apriel["hidden_size"] == qwen["hidden_size"] + assert apriel["vocab_size"] == qwen["vocab_size"] + assert apriel["decoder"]["num_blocks"] == qwen["num_hidden_layers"] + + def test_apriel2_bias_pattern(self, apriel2_converted): + """Qwen2 -> Apriel2 preserves Qwen-style bias (QKV yes, O no).""" + mixer = apriel2_converted["config_dict"]["decoder"]["block"]["mixer"] + + assert mixer["query_layer"]["bias"]["enabled"] is True + assert mixer["key_layer"]["bias"]["enabled"] is True + assert mixer["value_layer"]["bias"]["enabled"] is True + assert mixer["dense_layer"]["bias"]["enabled"] is False + + def test_supernet_structure(self, supernet_converted): + """Surgery creates correct stochastic mixer structure.""" + mixer = supernet_converted["config_dict"]["decoder"]["block"]["mixer"] + + assert mixer["type"] == "stochastic" + assert mixer["main_mixer_name"] == "attention" + assert set(mixer["mixers"].keys()) == {"attention", "sliding_window"} + + def test_supernet_bias_inheritance(self, supernet_converted): + """Submixers inherit bias settings from source.""" + mixer = supernet_converted["config_dict"]["decoder"]["block"]["mixer"] + + for name in ["attention", "sliding_window"]: + assert mixer["mixers"][name]["query_layer"]["bias"]["enabled"] is True + assert mixer["mixers"][name]["dense_layer"]["bias"]["enabled"] is False + + @requires_fastllm + def test_roundtrip_structure(self, roundtrip_converted): + """Fast-LLM roundtrip preserves stochastic mixer structure.""" + mixer = roundtrip_converted["config_dict"]["decoder"]["block"]["mixer"] + + assert mixer["type"] == "stochastic" + assert mixer["main_mixer_name"] == "attention" + assert set(mixer["mixers"].keys()) == {"attention", "sliding_window"} + + @requires_fastllm + def test_roundtrip_bias_preservation(self, roundtrip_converted): + """Fast-LLM roundtrip preserves per-layer bias settings.""" + mixer = roundtrip_converted["config_dict"]["decoder"]["block"]["mixer"] + + for name in ["attention", "sliding_window"]: + assert mixer["mixers"][name]["query_layer"]["bias"]["enabled"] is True + assert mixer["mixers"][name]["dense_layer"]["bias"]["enabled"] is False + + +# ============================================================================= +# Numerical Equivalence Tests +# ============================================================================= + + +@pytest.mark.slow +class TestNumericalEquivalence: + """Verify all conversion stages produce numerically identical outputs. 
+ + Uses parameterized fixtures to test all stages with all input variations, + giving us 3 stages × 3 inputs = 9 test cases from a single test function. + """ + + @TEST_INPUTS + def test_logits_match(self, qwen2_source, converted_model, prompts, max_new_tokens): + """Converted model produces identical logits to source.""" + tokenizer = qwen2_source["tokenizer"] + ref_model = qwen2_source["model"] + test_model = converted_model["model"] + stage = converted_model["name"] + + inputs = tokenizer(prompts, return_tensors="pt", padding=True) + ref_device = next(ref_model.parameters()).device + test_device = next(test_model.parameters()).device + + with torch.no_grad(): + ref_logits = ref_model( + input_ids=inputs.input_ids.to(ref_device), + attention_mask=inputs.attention_mask.to(ref_device), + ).logits.cpu() + + test_logits = test_model( + input_ids=inputs.input_ids.to(test_device), + attention_mask=inputs.attention_mask.to(test_device), + ).logits.cpu() + + max_diff = (ref_logits - test_logits).abs().max().item() + assert torch.allclose( + ref_logits, test_logits, rtol=1e-4, atol=1e-4 + ), f"{stage} logits mismatch: max diff = {max_diff:.6f}" + + @TEST_INPUTS + def test_generation_match(self, qwen2_source, converted_model, prompts, max_new_tokens): + """Converted model produces identical generation to source.""" + tokenizer = qwen2_source["tokenizer"] + ref_model = qwen2_source["model"] + test_model = converted_model["model"] + stage = converted_model["name"] + + inputs = tokenizer(prompts, return_tensors="pt", padding=True) + ref_device = next(ref_model.parameters()).device + test_device = next(test_model.parameters()).device + + with torch.no_grad(): + ref_gen = ref_model.generate( + input_ids=inputs.input_ids.to(ref_device), + attention_mask=inputs.attention_mask.to(ref_device), + max_new_tokens=max_new_tokens, + do_sample=False, + pad_token_id=tokenizer.pad_token_id, + ).cpu() + + test_gen = test_model.generate( + input_ids=inputs.input_ids.to(test_device), + attention_mask=inputs.attention_mask.to(test_device), + max_new_tokens=max_new_tokens, + do_sample=False, + pad_token_id=tokenizer.pad_token_id, + ).cpu() + + assert torch.equal(ref_gen, test_gen), ( + f"{stage} generation mismatch:\n" + f" Reference: {tokenizer.batch_decode(ref_gen, skip_special_tokens=True)}\n" + f" Test: {tokenizer.batch_decode(test_gen, skip_special_tokens=True)}" + ) diff --git a/fast_llm_external_models/tests/test_apriel2/test_mixer_equivalence.py b/fast_llm_external_models/tests/test_apriel2/test_mixer_equivalence.py index 1aa8a56d9..c6f3337e8 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_mixer_equivalence.py +++ b/fast_llm_external_models/tests/test_apriel2/test_mixer_equivalence.py @@ -28,15 +28,7 @@ import torch import torch.nn as nn -from fast_llm_external_models.apriel2.conversion import ( - Concat, - ExprPlan, - Ref, - Slice, - W, - execute, -) - +from fast_llm_external_models.apriel2.conversion import Concat, ExprPlan, Ref, Slice, W, execute # ============================================================================= # Shared Fixtures @@ -69,10 +61,10 @@ def hidden_size(request): @pytest.fixture( params=[ - pytest.param((8, 8, 32), id="mha-8h-32d"), # MHA: 8 heads, 8 kv heads, 32 head_dim - pytest.param((8, 4, 32), id="gqa-8h4kv-32d"), # GQA: 8 heads, 4 kv heads, 32 head_dim - pytest.param((8, 2, 64), id="gqa-8h2kv-64d"), # GQA: 8 heads, 2 kv heads, 64 head_dim - pytest.param((4, 1, 64), id="mqa-4h1kv-64d"), # MQA: 4 heads, 1 kv head, 64 head_dim + pytest.param((8, 8, 32), 
id="mha-8h-32d"), # MHA: 8 heads, 8 kv heads, 32 head_dim + pytest.param((8, 4, 32), id="gqa-8h4kv-32d"), # GQA: 8 heads, 4 kv heads, 32 head_dim + pytest.param((8, 2, 64), id="gqa-8h2kv-64d"), # GQA: 8 heads, 2 kv heads, 64 head_dim + pytest.param((4, 1, 64), id="mqa-4h1kv-64d"), # MQA: 4 heads, 1 kv head, 64 head_dim ] ) def attention_config(request): @@ -90,7 +82,7 @@ def attention_config(request): params=[ pytest.param((8, 4, 32, 32), id="8v-4k-32d"), # 8 value heads, 4 key heads, symmetric dims pytest.param((8, 2, 64, 64), id="8v-2k-64d"), # 8 value heads, 2 key heads, larger dims - pytest.param((4, 2, 32, 64), id="4v-2k-asym"), # Asymmetric key/value dims + pytest.param((4, 2, 32, 64), id="4v-2k-asym"), # Asymmetric key/value dims ] ) def gdn_config(request): @@ -100,9 +92,9 @@ def gdn_config(request): @pytest.fixture( params=[ - pytest.param((4, 8), id="4h-8d"), # 4 heads, 8 head_dim (small) - pytest.param((8, 16), id="8h-16d"), # 8 heads, 16 head_dim (medium) - pytest.param((4, 32), id="4h-32d"), # 4 heads, 32 head_dim (large head_dim) + pytest.param((4, 8), id="4h-8d"), # 4 heads, 8 head_dim (small) + pytest.param((8, 16), id="8h-16d"), # 8 heads, 16 head_dim (medium) + pytest.param((4, 32), id="4h-32d"), # 4 heads, 32 head_dim (large head_dim) ] ) def kda_config(request): @@ -283,9 +275,21 @@ def plan_qwen3next_gdn_to_apriel2( for g in range(num_k_heads): base = g * group_size q_slices.append(Slice(expr=qkvz_ref, slices=((base, base + head_k_dim, None), (None, None, None)))) - k_slices.append(Slice(expr=qkvz_ref, slices=((base + head_k_dim, base + 2 * head_k_dim, None), (None, None, None)))) - v_slices.append(Slice(expr=qkvz_ref, slices=((base + 2 * head_k_dim, base + 2 * head_k_dim + v_per_group, None), (None, None, None)))) - z_slices.append(Slice(expr=qkvz_ref, slices=((base + 2 * head_k_dim + v_per_group, base + group_size, None), (None, None, None)))) + k_slices.append( + Slice(expr=qkvz_ref, slices=((base + head_k_dim, base + 2 * head_k_dim, None), (None, None, None))) + ) + v_slices.append( + Slice( + expr=qkvz_ref, + slices=((base + 2 * head_k_dim, base + 2 * head_k_dim + v_per_group, None), (None, None, None)), + ) + ) + z_slices.append( + Slice( + expr=qkvz_ref, + slices=((base + 2 * head_k_dim + v_per_group, base + group_size, None), (None, None, None)), + ) + ) in_proj_qkvz_expr = Concat( exprs=( @@ -304,8 +308,15 @@ def plan_qwen3next_gdn_to_apriel2( b_slices, a_slices = [], [] for g in range(num_k_heads): base = g * ba_per_group - b_slices.append(Slice(expr=ba_ref, slices=((base, base + num_v_heads // num_k_heads, None), (None, None, None)))) - a_slices.append(Slice(expr=ba_ref, slices=((base + num_v_heads // num_k_heads, base + ba_per_group, None), (None, None, None)))) + b_slices.append( + Slice(expr=ba_ref, slices=((base, base + num_v_heads // num_k_heads, None), (None, None, None))) + ) + a_slices.append( + Slice( + expr=ba_ref, + slices=((base + num_v_heads // num_k_heads, base + ba_per_group, None), (None, None, None)), + ) + ) in_proj_ba_expr = Concat( exprs=(Concat(exprs=tuple(b_slices), dim=0), Concat(exprs=tuple(a_slices), dim=0)), @@ -565,6 +576,7 @@ def test_causal_vs_mistral( ): """Verify Apriel2Attention (causal) matches MistralAttention output.""" from transformers.models.mistral.modeling_mistral import MistralAttention, MistralRotaryEmbedding + from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2Attention mixer_config = apriel2_config.decoder["block"]["mixer"] @@ -593,13 +605,20 @@ def test_causal_vs_mistral( apriel2_attn.eval() with 
torch.no_grad(): - mistral_out = mistral_attn(hidden_states, position_embeddings=position_embeddings, attention_mask=causal_mask)[0] - apriel2_out = apriel2_attn(hidden_states, attention_mask=causal_mask, position_embeddings=position_embeddings)[0] + mistral_out = mistral_attn( + hidden_states, position_embeddings=position_embeddings, attention_mask=causal_mask + )[0] + apriel2_out = apriel2_attn( + hidden_states, attention_mask=causal_mask, position_embeddings=position_embeddings + )[0] rtol, atol = tolerance assert_close( - apriel2_out, mistral_out, rtol=rtol, atol=atol, - msg=f"Apriel2Attention vs MistralAttention (batch={batch_size}, seq={seq_len}, hidden={hidden_size})" + apriel2_out, + mistral_out, + rtol=rtol, + atol=atol, + msg=f"Apriel2Attention vs MistralAttention (batch={batch_size}, seq={seq_len}, hidden={hidden_size})", ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA") @@ -613,8 +632,9 @@ def test_noncausal_vs_pixtral( tolerance, ): """Verify Apriel2Attention (non-causal) matches PixtralAttention output.""" - from transformers.models.pixtral.modeling_pixtral import PixtralAttention, PixtralRotaryEmbedding from transformers.models.pixtral.configuration_pixtral import PixtralVisionConfig + from transformers.models.pixtral.modeling_pixtral import PixtralAttention, PixtralRotaryEmbedding + from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2TextConfig from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2Attention @@ -689,8 +709,11 @@ def test_noncausal_vs_pixtral( rtol, atol = tolerance assert_close( - apriel2_out, pixtral_out, rtol=rtol, atol=atol, - msg=f"Apriel2Attention (non-causal) vs PixtralAttention (batch={batch_size}, seq={seq_len})" + apriel2_out, + pixtral_out, + rtol=rtol, + atol=atol, + msg=f"Apriel2Attention (non-causal) vs PixtralAttention (batch={batch_size}, seq={seq_len})", ) @@ -737,6 +760,7 @@ def test_vs_qwen3next( ): """Verify Apriel2GatedDeltaNet matches Qwen3NextGatedDeltaNet output.""" from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextGatedDeltaNet + from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2GatedDeltaNet value_heads, key_heads, key_head_dim, value_head_dim = gdn_config @@ -758,8 +782,10 @@ def test_vs_qwen3next( # Transfer weights plan = plan_qwen3next_gdn_to_apriel2( - num_k_heads=key_heads, num_v_heads=value_heads, - head_k_dim=key_head_dim, head_v_dim=value_head_dim, + num_k_heads=key_heads, + num_v_heads=value_heads, + head_k_dim=key_head_dim, + head_v_dim=value_head_dim, ) source_weights = extract_module_weights(qwen_gdn) target_weights = execute(plan, source_weights, seed=seed) @@ -778,8 +804,11 @@ def test_vs_qwen3next( rtol, atol = tolerance assert_close( - apriel2_out, qwen_out, rtol=rtol, atol=atol, - msg=f"Apriel2GatedDeltaNet vs Qwen3NextGatedDeltaNet (batch={batch_size}, seq={seq_len})" + apriel2_out, + qwen_out, + rtol=rtol, + atol=atol, + msg=f"Apriel2GatedDeltaNet vs Qwen3NextGatedDeltaNet (batch={batch_size}, seq={seq_len})", ) @@ -803,6 +832,7 @@ def test_vs_fla( ): """Verify Apriel2 KimiDeltaAttention matches FLA KimiDeltaAttention output.""" from fla.layers.kda import KimiDeltaAttention as FLA_KDA + from fast_llm_external_models.apriel2.modeling_apriel2 import KimiDeltaAttention as Apriel2_KDA num_heads, head_dim = kda_config @@ -853,8 +883,11 @@ def test_vs_fla( rtol, atol = tolerance assert_close( - apriel2_out, fla_out, rtol=rtol, atol=atol, - msg=f"Apriel2 KDA vs FLA KDA (batch={batch_size}, seq={seq_len}, 
hidden={hidden_size})" + apriel2_out, + fla_out, + rtol=rtol, + atol=atol, + msg=f"Apriel2 KDA vs FLA KDA (batch={batch_size}, seq={seq_len}, hidden={hidden_size})", ) @@ -913,7 +946,4 @@ def test_gdn_fast_vs_slow(self, gdn_config, batch_size): slow_out = model(hidden_states)[0].clone() # Looser tolerance for kernel vs reference comparison - assert_close( - fast_out, slow_out, rtol=1e-3, atol=1e-3, - msg="GDN fast path (CUDA) vs slow path (PyTorch)" - ) + assert_close(fast_out, slow_out, rtol=1e-3, atol=1e-3, msg="GDN fast path (CUDA) vs slow path (PyTorch)") diff --git a/fast_llm_external_models/tests/test_apriel2/test_model_structure.py b/fast_llm_external_models/tests/test_apriel2/test_model_structure.py index 23856be30..56d2bc6a6 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_model_structure.py +++ b/fast_llm_external_models/tests/test_apriel2/test_model_structure.py @@ -1,9 +1,9 @@ """Tests for Apriel2 model structure and architecture validation.""" -import pytest import torch -from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM + from fast_llm_external_models.apriel2.cache import Apriel2Cache, _AttentionCache, _SSMCache +from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM class TestStochasticMixerStructure: @@ -14,20 +14,27 @@ def test_all_submixers_present(self, apriel2_config_all_mixers): model = Apriel2ForCausalLM(apriel2_config_all_mixers) stochastic_layer = model.model.decoder.blocks[1] # Layer 1 is the "all_mixers" layer - assert hasattr(stochastic_layer.mixer, 'mixers'), "Stochastic mixer should have 'mixers' attribute" + assert hasattr(stochastic_layer.mixer, "mixers"), "Stochastic mixer should have 'mixers' attribute" assert set(stochastic_layer.mixer.mixers.keys()) == { - 'attention', 'swa', 'mamba', 'gdn' + "attention", + "swa", + "mamba", + "gdn", }, "Stochastic mixer should contain all 4 configured mixer types" # Verify each mixer is the correct type from fast_llm_external_models.apriel2.modeling_apriel2 import ( - Apriel2Attention, Apriel2Mamba, Apriel2GatedDeltaNet + Apriel2Attention, + Apriel2GatedDeltaNet, + Apriel2Mamba, ) - assert isinstance(stochastic_layer.mixer.mixers['attention'], Apriel2Attention) - assert isinstance(stochastic_layer.mixer.mixers['swa'], Apriel2Attention) # SWA is Apriel2Attention with sliding_window - assert isinstance(stochastic_layer.mixer.mixers['mamba'], Apriel2Mamba) - assert isinstance(stochastic_layer.mixer.mixers['gdn'], Apriel2GatedDeltaNet) + assert isinstance(stochastic_layer.mixer.mixers["attention"], Apriel2Attention) + assert isinstance( + stochastic_layer.mixer.mixers["swa"], Apriel2Attention + ) # SWA is Apriel2Attention with sliding_window + assert isinstance(stochastic_layer.mixer.mixers["mamba"], Apriel2Mamba) + assert isinstance(stochastic_layer.mixer.mixers["gdn"], Apriel2GatedDeltaNet) def test_main_mixer_is_configured(self, apriel2_config_all_mixers): """Verify main_mixer_name is set correctly.""" @@ -44,7 +51,10 @@ def test_cache_has_all_submixer_slots(self, apriel2_config_all_mixers): assert isinstance(layer_cache, dict), "Stochastic layer cache should be a dict" assert set(layer_cache.keys()) == { - 'attention', 'swa', 'mamba', 'gdn' + "attention", + "swa", + "mamba", + "gdn", }, "Cache should have slots for all 4 mixers" def test_attention_mixers_use_attention_cache(self, apriel2_config_all_mixers): @@ -53,12 +63,12 @@ def test_attention_mixers_use_attention_cache(self, apriel2_config_all_mixers): layer_cache = cache.layers[1] # Attention-based 
mixers use AttentionCache - assert isinstance(layer_cache['attention'], _AttentionCache) - assert isinstance(layer_cache['swa'], _AttentionCache) + assert isinstance(layer_cache["attention"], _AttentionCache) + assert isinstance(layer_cache["swa"], _AttentionCache) # SSM-based mixers use SSMCache - assert isinstance(layer_cache['mamba'], _SSMCache) - assert isinstance(layer_cache['gdn'], _SSMCache) + assert isinstance(layer_cache["mamba"], _SSMCache) + assert isinstance(layer_cache["gdn"], _SSMCache) def test_parameter_counts_differ_by_config(self): """Different configs create models with different parameter counts.""" @@ -74,8 +84,10 @@ def test_parameter_counts_differ_by_config(self): } config_tiny = Apriel2Config( - vocab_size=100, hidden_size=64, - num_attention_heads=4, num_key_value_heads=2, + vocab_size=100, + hidden_size=64, + num_attention_heads=4, + num_key_value_heads=2, decoder={ "type": "fixed", "num_blocks": 2, @@ -88,8 +100,10 @@ def test_parameter_counts_differ_by_config(self): ) config_stochastic = Apriel2Config( - vocab_size=100, hidden_size=64, - num_attention_heads=4, num_key_value_heads=2, + vocab_size=100, + hidden_size=64, + num_attention_heads=4, + num_key_value_heads=2, decoder={ "type": "pattern", "num_blocks": 2, @@ -106,14 +120,14 @@ def test_parameter_counts_differ_by_config(self): "main_mixer_name": "attention", "mixers": { "attention": attn_config, - "mamba": {"type": "mamba", "conv_bias": True, "dt_proj_bias": True} - } + "mamba": {"type": "mamba", "conv_bias": True, "dt_proj_bias": True}, + }, }, "mlp": {"type": "mlp", "intermediate_size": 256, "gated": True}, "normalization": {"type": "rms_norm"}, - } - } - } + }, + }, + }, ) model_tiny = Apriel2ForCausalLM(config_tiny) @@ -122,8 +136,9 @@ def test_parameter_counts_differ_by_config(self): params_tiny = sum(p.numel() for p in model_tiny.parameters()) params_stochastic = sum(p.numel() for p in model_stochastic.parameters()) - assert params_stochastic > params_tiny, \ - "Stochastic mixer should have more parameters (has both attention and mamba)" + assert ( + params_stochastic > params_tiny + ), "Stochastic mixer should have more parameters (has both attention and mamba)" def test_weights_are_initialized(self, apriel2_config_all_mixers): """Verify model weights are initialized (not all zeros/constant).""" @@ -136,9 +151,7 @@ def test_weights_are_initialized(self, apriel2_config_all_mixers): # Basic sanity: at least some parameters should be non-zero non_zero_params = sum( - not torch.all(p == 0) - for mixer in stochastic_layer.mixer.mixers.values() - for p in mixer.parameters() + not torch.all(p == 0) for mixer in stochastic_layer.mixer.mixers.values() for p in mixer.parameters() ) assert non_zero_params > 0, "At least some mixer parameters should be non-zero" diff --git a/fast_llm_external_models/tests/test_apriel2/test_modeling.py b/fast_llm_external_models/tests/test_apriel2/test_modeling.py index 5dbd36159..8e2f610bb 100644 --- a/fast_llm_external_models/tests/test_apriel2/test_modeling.py +++ b/fast_llm_external_models/tests/test_apriel2/test_modeling.py @@ -2,18 +2,23 @@ import pytest import torch + from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM class TestApriel2Modeling: """End-to-end tests for Apriel2 model with different configurations.""" - @pytest.mark.parametrize("config_name", [ - "apriel2_config_tiny", - "apriel2_config_stochastic", - "apriel2_config_multi_mixer", - "apriel2_config_all_mixers" # Tests all 4 mixer types - ]) + @pytest.mark.parametrize( + 
"config_name", + [ + "apriel2_config_tiny", + "apriel2_config_stochastic", + "apriel2_config_multi_mixer", + "apriel2_config_all_mixers", # Tests all 4 mixer types + "apriel2_config_with_bias", # Tests per-layer bias and non-gated MLP + ], + ) def test_model_end_to_end(self, config_name, request): """Test instantiation, forward pass, cache correctness, and generation. @@ -42,7 +47,7 @@ def test_model_end_to_end(self, config_name, request): # 2. Forward pass - basic shape validation outputs = model(input_ids, use_cache=False) assert outputs.logits.shape == (2, seq_len, config.vocab_size) - assert hasattr(outputs, 'logits') + assert hasattr(outputs, "logits") # 3. Verify cache is actually being used (not dormant) split_pos = 30 @@ -52,28 +57,23 @@ def test_model_end_to_end(self, config_name, request): assert outputs_part1.past_key_values is not None outputs_correct_cache = model( - input_ids[:, split_pos:split_pos+1], - past_key_values=outputs_part1.past_key_values, - use_cache=True + input_ids[:, split_pos : split_pos + 1], past_key_values=outputs_part1.past_key_values, use_cache=True ) # Test 1: Empty cache should give different results than filled cache # This verifies cache is being used at all from fast_llm_external_models.apriel2.cache import Apriel2Cache, _AttentionCache + empty_cache = Apriel2Cache(config) outputs_empty_cache = model( - input_ids[:, split_pos:split_pos+1], - past_key_values=empty_cache, - use_cache=True + input_ids[:, split_pos : split_pos + 1], past_key_values=empty_cache, use_cache=True ) - cache_affects_output = not torch.allclose( - outputs_correct_cache.logits, - outputs_empty_cache.logits, - atol=1e-3 - ) - assert cache_affects_output, f"Cache appears dormant for {config_name} - empty cache gives same results as filled cache" + cache_affects_output = not torch.allclose(outputs_correct_cache.logits, outputs_empty_cache.logits, atol=1e-3) + assert ( + cache_affects_output + ), f"Cache appears dormant for {config_name} - empty cache gives same results as filled cache" # Test 2: Corrupted cache (zeros) should give different results than correct cache # This verifies the actual cache VALUES are being used @@ -98,17 +98,15 @@ def test_model_end_to_end(self, config_name, request): corrupted_layer[name].value = torch.zeros_like(correct_sub.value) outputs_corrupted_cache = model( - input_ids[:, split_pos:split_pos+1], - past_key_values=corrupted_cache, - use_cache=True + input_ids[:, split_pos : split_pos + 1], past_key_values=corrupted_cache, use_cache=True ) cache_values_matter = not torch.allclose( - outputs_correct_cache.logits, - outputs_corrupted_cache.logits, - atol=1e-3 + outputs_correct_cache.logits, outputs_corrupted_cache.logits, atol=1e-3 ) - assert cache_values_matter, f"Cache values not used for {config_name} - zeroed cache gives same results as correct cache" + assert ( + cache_values_matter + ), f"Cache values not used for {config_name} - zeroed cache gives same results as correct cache" # 4. 
Cache correctness - validate cache produces same results as no-cache # Compute full sequence without cache @@ -117,18 +115,14 @@ def test_model_end_to_end(self, config_name, request): # Compute in two steps with cache outputs_part1 = model(input_ids[:, :split_pos], use_cache=True) outputs_part2 = model( - input_ids[:, split_pos:split_pos+1], - past_key_values=outputs_part1.past_key_values, - use_cache=True + input_ids[:, split_pos : split_pos + 1], past_key_values=outputs_part1.past_key_values, use_cache=True ) # Logits should match between cached and non-cached # Note: GPU execution with bfloat16/float16 has lower precision than CPU float32, # so we use a looser tolerance here. assert torch.allclose( - outputs_full.logits[:, split_pos, :], - outputs_part2.logits[:, 0, :], - atol=1e-3 + outputs_full.logits[:, split_pos, :], outputs_part2.logits[:, 0, :], atol=1e-3 ), f"Cache correctness failed for {config_name}: cached and non-cached logits differ" # 5. Generation - end-to-end validation diff --git a/fast_llm_external_models/tests/test_apriel2/test_plan_execution.py b/fast_llm_external_models/tests/test_apriel2/test_plan_execution.py new file mode 100644 index 000000000..ca0c8739f --- /dev/null +++ b/fast_llm_external_models/tests/test_apriel2/test_plan_execution.py @@ -0,0 +1,598 @@ +"""test_plan_execution.py - Plan execution and algebraic composition laws. + +This module provides rigorous, parameterized tests for the mathematical properties +that the conversion system must satisfy. Each test class corresponds to one +algebraic structure, and each test method verifies one specific law. + +Conceptual Types +================ + +The conversion system operates on three conceptual types (all ``dict`` at runtime): + +- **S (State)**: Complete config without ``init`` fields +- **P (Partial Surgery)**: Incomplete config, may have ``init`` fields +- **T (Transition Spec)**: Complete config WITH ``init`` fields + +Algebraic Structures +==================== + +1. **Partial Surgeries (P)** form a **Monoid** under deep merge:: + + compose_configs : P × P → P + Identity: {} + Associativity: (p1 ∘ p2) ∘ p3 = p1 ∘ (p2 ∘ p3) + +2. **Surgeries act on States** to produce Transition Specs:: + + compose_configs : S × P → T + compose_configs : T × P → T + + Action law (additive surgeries): (s · p1) · p2 = s · (p1 ∘ p2) + +3. **Plans** form a **Category** with composition:: + + compose : Plan(A→B) × Plan(B→C) → Plan(A→C) + Associativity: (P1 ∘ P2) ∘ P3 = P1 ∘ (P2 ∘ P3) + +4. **plan_surgery is a Functor** from config pairs to plans:: + + plan_surgery : S × T → Plan + Functoriality: compose(plan(S,T1), plan(T1,T2)) ≡ plan(S,T2) + + This is semantic equivalence: both produce identical weights when executed. + +Important Behaviors Tested +========================== + +- **init stripping**: Between surgery iterations, T → S conversion via + ``strip_init_fields()`` ensures ``init: random`` applies only to the surgery + that introduces a component. + +- **Bias inheritance**: Per-layer bias settings propagate through surgery chains. + +- **Plan composition**: Composed plans produce identical weights to direct plans. 
+ +Design Principles +================= + +- Each law gets ONE parameterized test, not multiple similar tests +- Fixtures provide diverse configs (with/without biases) +- Corner cases are covered via parameterization, not test proliferation +- Tests document the laws they verify in their docstrings +""" + +from functools import reduce + +import pytest +import torch + +from fast_llm_external_models.apriel2.conversion import ( + Concat, + ExprPlan, + Init, + Ref, + Slice, + W, + compose, + compose_configs, + execute, + plan_surgery, +) + +# Import shared helper from conftest +from fast_llm_external_models.tests.test_apriel2.conftest import make_weights_for_config + +# ============================================================================= +# Fixtures: Use shared fixtures from conftest.py where possible +# ============================================================================= +# - base_config_dict: Complete config without biases (Llama-style) +# - base_config_with_bias_dict: Complete config with QKV biases +# - additive_surgery_chain: [wrap_stochastic, add_sliding_window, add_gdn] +# ============================================================================= + + +# ============================================================================= +# Test: Plan Composition Associativity +# ============================================================================= + + +class TestPlanCompositionAssociativity: + """ + LAW: Plan composition is associative. + + (P₁ ∘ P₂) ∘ P₃ = P₁ ∘ (P₂ ∘ P₃) + + where ∘ denotes compose(P1, P2). + + This must hold for the AST structure, not just semantic equivalence. + """ + + @pytest.mark.parametrize("expr_type", ["ref_chain", "with_concat", "with_slice", "with_init"]) + def test_associativity(self, expr_type): + """Plan composition is associative for various expression types.""" + # Build three plans that can be composed + if expr_type == "ref_chain": + p1 = ExprPlan(mappings={W("b"): Ref(key=W("a"))}) + p2 = ExprPlan(mappings={W("c"): Ref(key=W("b"))}) + p3 = ExprPlan(mappings={W("d"): Ref(key=W("c"))}) + elif expr_type == "with_concat": + p1 = ExprPlan(mappings={W("x"): Ref(key=W("a")), W("y"): Ref(key=W("b"))}) + p2 = ExprPlan(mappings={W("xy"): Concat(exprs=(Ref(key=W("x")), Ref(key=W("y"))), dim=0)}) + p3 = ExprPlan(mappings={W("final"): Ref(key=W("xy"))}) + elif expr_type == "with_slice": + p1 = ExprPlan(mappings={W("full"): Ref(key=W("src"))}) + p2 = ExprPlan(mappings={W("part"): Slice(expr=Ref(key=W("full")), slices=((0, 5, None),))}) + p3 = ExprPlan(mappings={W("out"): Ref(key=W("part"))}) + elif expr_type == "with_init": + p1 = ExprPlan(mappings={W("x"): Ref(key=W("a"))}) + p2 = ExprPlan( + mappings={W("y"): Concat(exprs=(Ref(key=W("x")), Init(shape=(5,), init_type="zeros")), dim=0)} + ) + p3 = ExprPlan(mappings={W("z"): Ref(key=W("y"))}) + + left = compose(compose(p1, p2), p3) + right = compose(p1, compose(p2, p3)) + + assert left.mappings == right.mappings, f"Associativity failed for {expr_type}" + + +# ============================================================================= +# Test: Functoriality of plan_surgery (THE CRITICAL PROPERTY) +# ============================================================================= + + +class TestPlanSurgeryFunctoriality: + """ + LAW: plan_surgery is functorial with respect to config composition. + + For a surgery chain P₁, P₂, ..., Pₙ applied to base state S₀:: + + T₁ = compose_configs(S₀, P₁) # S × P → T + T₂ = compose_configs(T₁, P₂) # T × P → T (no stripping!) + ... 
+ Tₙ = compose_configs(Tₙ₋₁, Pₙ) + + Plan functoriality says:: + + compose(plan(S₀,T₁), plan(T₁,T₂), ...) ≡ plan(S₀, Tₙ) + + where ≡ denotes semantic equivalence (identical weights when executed). + + NOTE: This tests T × P composition WITHOUT stripping between steps. + This differs from build_plan which strips (T → S) between iterations. + Both patterns are valid: + + - Without stripping: init fields accumulate, testing plan composition purity + - With stripping: init consumed per-step, testing real usage (see + test_build_plan_strips_init_between_iterations) + + The functoriality law holds in both cases because plan composition + correctly substitutes Ref expressions with their definitions. + """ + + @pytest.mark.parametrize("chain_length", [1, 2, 3]) + @pytest.mark.parametrize("use_bias", [True, False]) + def test_functoriality( + self, + chain_length, + use_bias, + base_config_dict, + base_config_with_bias_dict, + additive_surgery_chain, + ): + """ + Composed incremental plans produce same weights as direct plan. + + Parameterized over: + - chain_length: Number of surgeries (1, 2, or 3) + - use_bias: Whether base config has biases + """ + base_config = base_config_with_bias_dict if use_bias else base_config_dict + surgeries = additive_surgery_chain[:chain_length] + + # Build config chain: C₀ → C₁ → ... → Cₙ + configs = [base_config] + for s in surgeries: + configs.append(compose_configs(configs[-1], s)) + + # Build incremental plans: Pₖ = plan_surgery(Cₖ₋₁, Cₖ) + plans = [plan_surgery(configs[i], configs[i + 1]) for i in range(len(surgeries))] + + # Compose all incremental plans + composed_plan = reduce(compose, plans) + + # Build direct plan: plan_surgery(C₀, Cₙ) + direct_plan = plan_surgery(configs[0], configs[-1]) + + # Execute both on same weights + weights = make_weights_for_config(base_config) + composed_weights = execute(composed_plan, weights, seed=42) + direct_weights = execute(direct_plan, weights, seed=42) + + # Verify semantic equivalence + assert set(composed_weights.keys()) == set( + direct_weights.keys() + ), f"Key sets differ for chain_length={chain_length}, use_bias={use_bias}" + + for key in composed_weights: + assert torch.allclose( + composed_weights[key], direct_weights[key], atol=1e-6 + ), f"Weight mismatch for {key} with chain_length={chain_length}, use_bias={use_bias}" + + @pytest.mark.parametrize("split_point", [1, 2]) + def test_arbitrary_grouping( + self, + split_point, + base_config_with_bias_dict, + additive_surgery_chain, + ): + """ + Any grouping of surgery chain produces same result. + + For surgeries [S₁, S₂, S₃], tests that: + - compose(P₁, compose(P₂, P₃)) + - compose(compose(P₁, P₂), P₃) + - plan_surgery(C₀, C₃) + + all produce identical weights. 
+ """ + surgeries = additive_surgery_chain + + # Build config chain + configs = [base_config_with_bias_dict] + for s in surgeries: + configs.append(compose_configs(configs[-1], s)) + + # Build incremental plans + plans = [plan_surgery(configs[i], configs[i + 1]) for i in range(3)] + + # Different groupings + left_grouped = compose(compose(plans[0], plans[1]), plans[2]) + right_grouped = compose(plans[0], compose(plans[1], plans[2])) + direct = plan_surgery(configs[0], configs[-1]) + + # Execute all + weights = make_weights_for_config(base_config_with_bias_dict) + results = { + "left": execute(left_grouped, weights, seed=42), + "right": execute(right_grouped, weights, seed=42), + "direct": execute(direct, weights, seed=42), + } + + # All must match + keys = set(results["left"].keys()) + assert keys == set(results["right"].keys()) == set(results["direct"].keys()) + + for key in keys: + assert torch.allclose(results["left"][key], results["right"][key], atol=1e-6) + assert torch.allclose(results["left"][key], results["direct"][key], atol=1e-6) + + +# ============================================================================= +# Test: Bias Inheritance Preservation (Regression for the specific bug) +# ============================================================================= + + +class TestBiasInheritancePreservation: + """ + PROPERTY: Per-layer bias settings must be preserved through surgery chains. + + When a surgery spec does not mention bias settings, they must be inherited + from the source config. This is the specific failure mode of the build_plan + bug: passing partial surgery specs to plan_surgery lost inherited fields. + + This test verifies the SYMPTOM (missing biases) rather than the LAW + (functoriality). It's kept as a focused regression test. 
+ """ + + @pytest.mark.parametrize("num_surgeries", [1, 2, 3]) + def test_qkv_biases_preserved_through_chain( + self, + num_surgeries, + base_config_with_bias_dict, + additive_surgery_chain, + ): + """QKV biases (enabled in source) appear in plan after N surgeries.""" + surgeries = additive_surgery_chain[:num_surgeries] + + # Build config and plan chain + configs = [base_config_with_bias_dict] + for s in surgeries: + configs.append(compose_configs(configs[-1], s)) + + plans = [plan_surgery(configs[i], configs[i + 1]) for i in range(num_surgeries)] + final_plan = reduce(compose, plans) if len(plans) > 1 else plans[0] + + # Check bias keys present + target_keys = {str(k) for k in final_plan.target_keys()} + + assert any("q_proj.bias" in k for k in target_keys), f"q_proj.bias missing after {num_surgeries} surgeries" + assert any("k_proj.bias" in k for k in target_keys), f"k_proj.bias missing after {num_surgeries} surgeries" + assert any("v_proj.bias" in k for k in target_keys), f"v_proj.bias missing after {num_surgeries} surgeries" + # O bias should NOT be present (disabled in source) + assert not any( + "o_proj.bias" in k for k in target_keys + ), f"o_proj.bias should not be present (disabled in source)" + + def test_bias_values_preserved( + self, + base_config_with_bias_dict, + additive_surgery_chain, + ): + """Bias tensor values are correctly transferred, not just keys.""" + surgery = additive_surgery_chain[0] # wrap_stochastic + c1 = compose_configs(base_config_with_bias_dict, surgery) + plan = plan_surgery(base_config_with_bias_dict, c1) + + weights = make_weights_for_config(base_config_with_bias_dict) + result = execute(plan, weights, seed=42) + + # Verify values match (not just that keys exist) + for i in range(base_config_with_bias_dict["decoder"]["num_blocks"]): + src_key = W(f"model.decoder.blocks.{i}.mixer.q_proj.bias") + dst_key = W(f"model.decoder.blocks.{i}.mixer.mixers.attention.q_proj.bias") + + assert dst_key in result, f"Missing {dst_key}" + assert torch.allclose(weights[src_key], result[dst_key]), f"Bias values differ for block {i}" + + +# ============================================================================= +# Test: build_plan Integration (Regression test for convert.py) +# ============================================================================= + + +class TestBuildPlanIntegration: + """ + REGRESSION: build_plan must compose configs before calling plan_surgery. + + The bug was: + plan_surgery(current_config, surgery_config) # WRONG: partial + + Should be: + target = compose_configs(current_config, surgery_config) + plan_surgery(current_config, target) # CORRECT: complete + + This test verifies the fix in convert.py's build_plan function. 
+ """ + + @pytest.mark.parametrize("num_surgeries", [1, 2]) + def test_build_plan_preserves_inherited_fields( + self, + num_surgeries, + base_config_with_bias_dict, + additive_surgery_chain, + ): + """build_plan produces plans with inherited bias mappings.""" + from fast_llm_external_models.apriel2.convert import build_plan + + surgeries = additive_surgery_chain[:num_surgeries] + + plan, final_config = build_plan( + base_config_with_bias_dict, + surgeries, + source_format="apriel2", + ) + + # Verify inherited biases in config + if num_surgeries >= 1: + attn = final_config["decoder"]["block"]["mixer"]["mixers"]["attention"] + assert attn.get("query_layer", {}).get("bias", {}).get("enabled") is True + + # Verify bias mappings in plan + target_keys = {str(k) for k in plan.target_keys()} + assert any( + "q_proj.bias" in k for k in target_keys + ), f"build_plan with {num_surgeries} surgeries missing q_proj.bias" + + +# ============================================================================= +# Test: init Field Preservation (Critical for random initialization) +# ============================================================================= + + +class TestInitFieldPreservation: + """ + PROPERTY: The `init` field must be visible to plan_surgery. + + The `init` field controls weight initialization mode: + - `init: transfer` → use weight transfer/conversion + - `init: random` → use random initialization + + compose_configs must preserve `init` so plan_surgery can see it. + Stripping happens only at final output (when saving to disk). + """ + + def test_init_random_produces_init_expression(self, base_config_with_bias_dict): + """Surgery with init: random produces Init expressions in plan.""" + surgery = { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + "gdn": {"type": "gdn", "init": "random", "convolution_layer": {"kernel_size": 4}}, + }, + }, + }, + }, + } + + target = compose_configs(base_config_with_bias_dict, surgery) + plan = plan_surgery(base_config_with_bias_dict, target) + + # Check that GDN weights use Init expressions (random init) + target_keys = {str(k) for k in plan.target_keys()} + gdn_keys = [k for k in target_keys if "gdn" in k.lower()] + + assert len(gdn_keys) > 0, "No GDN keys in plan" + + # Verify at least one GDN weight uses Init (random initialization) + has_init_expr = False + for key in plan.target_keys(): + if "gdn" in str(key).lower(): + expr = plan.mappings[key] + if isinstance(expr, Init): + has_init_expr = True + break + # Also check inside Concat/other composite expressions + if hasattr(expr, "exprs"): + for sub in expr.exprs: + if isinstance(sub, Init): + has_init_expr = True + break + + assert has_init_expr, "init: random should produce Init expressions for GDN weights" + + def test_init_transfer_produces_ref_expression(self, base_config_with_bias_dict): + """Surgery with init: transfer produces Ref expressions (weight transfer).""" + surgery = { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + }, + }, + }, + }, + } + + target = compose_configs(base_config_with_bias_dict, surgery) + plan = plan_surgery(base_config_with_bias_dict, target) + + # Check that attention weights use Ref expressions (transfer) + has_ref = False + for key in plan.target_keys(): + if "attention" in str(key) and "q_proj.weight" in str(key): + expr = plan.mappings[key] + if isinstance(expr, 
Ref): + has_ref = True + break + + assert has_ref, "init: transfer should produce Ref expressions for attention weights" + + def test_build_plan_respects_init_random(self, base_config_with_bias_dict): + """build_plan correctly uses init: random for weight initialization.""" + from fast_llm_external_models.apriel2.convert import build_plan + + # Mamba requires many config fields for random init + surgery = { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + "mamba": { + "type": "mamba", + "init": "random", + "d_inner": 512, + "d_state": 16, + "dt_rank": 16, + "d_xb": 64, + "d_conv": 4, + "repeat_kv_before_conv": False, + "conv_bias": True, + "dt_proj_bias": True, + "dt_min": 0.001, + "dt_max": 0.1, + "dt_init_floor": 1e-4, + }, + }, + }, + }, + }, + } + + plan, final_config = build_plan( + base_config_with_bias_dict, + [surgery], + source_format="apriel2", + ) + + # Verify mamba weights use Init (random init) + has_mamba_init = False + for key in plan.target_keys(): + key_str = str(key) + if "mamba" in key_str: + expr = plan.mappings[key] + if isinstance(expr, Init): + has_mamba_init = True + break + + assert has_mamba_init, "build_plan should use Init for init: random components" + + def test_build_plan_strips_init_between_iterations(self, base_config_with_bias_dict): + """build_plan strips init between iterations (T → S conversion). + + This tests that the intermediate state between surgeries has no init fields. + The composed plan will show Init expressions because plan composition + substitutes Ref → Init, but the semantics are correct: GDN is initialized + once (in surgery 1), not re-randomized in surgery 2. + """ + from fast_llm_external_models.apriel2.conversion import compose_configs, plan_surgery, strip_init_fields + + # Surgery 1: Add GDN with random init + surgery1 = { + "decoder": { + "block": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"init": "transfer"}, + "gdn": { + "type": "gdn", + "init": "random", + "convolution_layer": {"kernel_size": 4}, + }, + }, + }, + }, + }, + } + + # Surgery 2: Add sliding window (doesn't mention GDN) + surgery2 = { + "decoder": { + "block": { + "mixer": { + "mixers": { + "sliding_window": {"init": "transfer", "window_size": 512}, + }, + }, + }, + }, + } + + # Simulate build_plan's iteration loop + s0 = base_config_with_bias_dict + + # Iteration 1 + t1 = compose_configs(s0, surgery1) + assert t1["decoder"]["block"]["mixer"]["mixers"]["gdn"].get("init") == "random" + s1 = strip_init_fields(t1) + assert s1["decoder"]["block"]["mixer"]["mixers"]["gdn"].get("init") is None + + # Iteration 2: s1 has no init for GDN + t2 = compose_configs(s1, surgery2) + assert ( + t2["decoder"]["block"]["mixer"]["mixers"]["gdn"].get("init") is None + ), "GDN should have no init in T2 (wasn't in surgery2, stripped from s1)" + + # plan_surgery(s1, t2) should use Ref for GDN (transfer, not random) + plan2 = plan_surgery(s1, t2) + gdn_uses_ref = False + for key in plan2.target_keys(): + if "gdn" in str(key): + expr = plan2.mappings[key] + if isinstance(expr, Ref): + gdn_uses_ref = True + break + + assert gdn_uses_ref, "plan_surgery(s1, t2) should use Ref for GDN (transfer from s1)" diff --git a/setup.py b/setup.py index b273e077e..5c4d0def6 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,6 @@ -import sys -import re import pathlib +import re +import sys try: import pybind11 @@ -18,6 +18,7 @@ print(f"Error: setuptools 
version {_SETUPTOOLS_MIN_VERSION} " "or greater is required") sys.exit(1) + def get_version(): """Read version from fast_llm/__init__.py""" init_file = pathlib.Path(__file__).parent.joinpath("fast_llm", "__init__.py").read_text() @@ -26,6 +27,7 @@ def get_version(): return version_match.group(1) raise RuntimeError("Unable to find version string in fast_llm/__init__.py") + cpp_extension = setuptools.Extension( "fast_llm.csrc.data", sources=["fast_llm/csrc/data.cpp"], diff --git a/tests/data/test_tokenizer.py b/tests/data/test_tokenizer.py index c7fdef9ca..4e9e2fdd5 100644 --- a/tests/data/test_tokenizer.py +++ b/tests/data/test_tokenizer.py @@ -40,3 +40,263 @@ def test_tokenize_with_spans(common_tokenizer, spans, expected_token_spans, expe expected_token_spans = [(begin + 1, end + 1) for begin, end in expected_token_spans] Assert.eq(tokens.tolist(), expected_tokens) Assert.eq(token_spans, expected_token_spans) + + +def test_validate_chat_template_no_template(common_tokenizer): + """Tokenizer without chat template raises.""" + with pytest.raises(ValueError, match="does not have a chat template"): + common_tokenizer.validate_chat_template() + + +def test_validate_chat_template_no_markers(common_tokenizer): + """Tokenizer with chat template but no markers raises.""" + common_tokenizer.tokenizer.chat_template = "{{ messages }}" + with pytest.raises(ValueError, match="does not contain.*generation"): + common_tokenizer.validate_chat_template() + + +def test_validate_chat_template_with_markers(common_tokenizer): + """Tokenizer with generation markers validates.""" + common_tokenizer.tokenizer.chat_template = "{% generation %}{{ m }}{% endgeneration %}" + common_tokenizer.validate_chat_template() + + +# Realistic chat template following HF conventions (e.g., SmolLM3): +# The generation block includes the full assistant turn: opening tag, content, and closing tag. +# This ensures the model learns to emit the closing tag. 
+CHAT_TEMPLATE = ( + "{% for message in messages %}" + "{% if message.role == 'assistant' %}" + "{% generation %}{{ message.content }}{% endgeneration %}" + "{% else %}" + "<{{ message.role }}>{{ message.content }}" + "{% endif %}" + "{% endfor %}" +) + + +@pytest.mark.parametrize( + ("messages", "expected_tokens", "expected_loss_masking_spans"), + ( + # Single turn: full assistant turn (Hello) is trainable + # 15 tokens, trainable indices 7-13, loss mask spans cover 0-6 and 14 + ( + [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello"}], + [49152, 27, 789, 29, 16946, 750, 789, 2293, 17822, 29, 7371, 750, 17822, 29, 49152], + [(0, 7), (14, 15)], + ), + # Multi-turn: both assistant turns are fully trainable + # 27 tokens, trainable indices 7-13 and 19-25 + ( + [ + {"role": "user", "content": "A"}, + {"role": "assistant", "content": "B"}, + {"role": "user", "content": "C"}, + {"role": "assistant", "content": "D"}, + ], + [ + 49152, + 27, + 789, + 29, + 32, + 750, + 789, + 2293, + 17822, + 29, + 33, + 750, + 17822, + 2293, + 789, + 29, + 34, + 750, + 789, + 2293, + 17822, + 29, + 35, + 750, + 17822, + 29, + 49152, + ], + [(0, 7), (14, 19), (26, 27)], + ), + # System + user + assistant: full assistant turn trainable + # 23 tokens, trainable indices 15-21 + ( + [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello"}, + ], + [ + 49152, + 27, + 3144, + 29, + 5815, + 1139, + 44569, + 6928, + 3144, + 2293, + 789, + 29, + 16946, + 750, + 789, + 2293, + 17822, + 29, + 7371, + 750, + 17822, + 29, + 49152, + ], + [(0, 15), (22, 23)], + ), + # User only: no trainable tokens + # 9 tokens, no trainable indices + ( + [{"role": "user", "content": "Hi"}], + [49152, 27, 789, 29, 16946, 750, 789, 29, 49152], + [(0, 9)], + ), + # Long multi-turn (85 tokens, 3 assistant responses with tags, tests span machinery) + # Trainable: indices 27-40, 49-62, 70-83 + ( + [ + {"role": "system", "content": "You are a helpful assistant that answers questions."}, + {"role": "user", "content": "What is the capital of France?"}, + {"role": "assistant", "content": "The capital of France is Paris."}, + {"role": "user", "content": "What about Germany?"}, + {"role": "assistant", "content": "The capital of Germany is Berlin."}, + {"role": "user", "content": "And Italy?"}, + {"role": "assistant", "content": "The capital of Italy is Rome."}, + ], + [ + 49152, + 27, + 3144, + 29, + 5815, + 1139, + 373, + 44569, + 2424, + 11886, + 954, + 15737, + 14516, + 6928, + 3144, + 2293, + 789, + 29, + 13938, + 438, + 331, + 25016, + 457, + 12409, + 562, + 35838, + 789, + 2293, + 17822, + 29, + 2111, + 25016, + 457, + 12409, + 562, + 438, + 4235, + 280, + 6928, + 17822, + 2293, + 789, + 29, + 13938, + 5028, + 759, + 42226, + 35838, + 789, + 2293, + 17822, + 29, + 2111, + 25016, + 457, + 759, + 42226, + 438, + 29784, + 3556, + 6928, + 17822, + 2293, + 789, + 29, + 1996, + 4413, + 3326, + 35838, + 789, + 2293, + 17822, + 29, + 2111, + 25016, + 457, + 4413, + 3326, + 438, + 613, + 1361, + 6928, + 17822, + 29, + 49152, + ], + [(0, 27), (41, 49), (63, 70), (84, 85)], + ), + ), +) +def test_tokenize_chat(common_tokenizer, messages, expected_tokens, expected_loss_masking_spans): + common_tokenizer.tokenizer.chat_template = CHAT_TEMPLATE + tokens, loss_masking_spans = common_tokenizer.tokenize_chat(messages) + Assert.eq(tokens.tolist(), expected_tokens) + Assert.eq(loss_masking_spans, expected_loss_masking_spans) + + +@pytest.mark.parametrize( + 
("train_mask", "expected_loss_spans"), + ( + # All masked (no trainable tokens) + ([False, False, False], [(0, 3)]), + # All trainable (no spans) + ([True, True, True], []), + # Single trainable at start + ([True, False, False], [(1, 3)]), + # Single trainable at end + ([False, False, True], [(0, 2)]), + # Single trainable in middle + ([False, True, False], [(0, 1), (2, 3)]), + # Multiple trainable regions (simulates multi-turn conversation) + ([False, False, True, True, False, False, True, True, True, False], [(0, 2), (4, 6), (9, 10)]), + # Alternating + ([False, True, False, True, False], [(0, 1), (2, 3), (4, 5)]), + ), +) +def test_train_mask_to_loss_spans(train_mask, expected_loss_spans): + from fast_llm.data.preprocessing.tokenizer import _train_mask_to_loss_spans + + Assert.eq(_train_mask_to_loss_spans(train_mask), expected_loss_spans) From 4d074943766c93020f2ce10bbbbe4a81f5e92683 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 6 Jan 2026 00:20:00 -0500 Subject: [PATCH 10/12] stuff --- fast_llm/engine/checkpoint/config.py | 3 +- fast_llm/engine/training/config.py | 132 +++---- fast_llm/engine/training/streaming.py | 70 ++++ fast_llm/engine/training/trainer.py | 19 +- fast_llm/engine/training/trainer_events.py | 109 ------ tests/data/test_streaming.py | 13 +- tests/models/test_checkpoint.py | 14 +- tests/models/test_events.py | 210 +++++++++++ tests/trainer/events_fake_consumer.py | 105 ------ tests/trainer/test_events.py | 407 --------------------- tests/utils/model_configs.py | 57 +-- tests/utils/run_test_script.py | 13 +- tests/utils/subtest.py | 8 + 13 files changed, 406 insertions(+), 754 deletions(-) create mode 100644 fast_llm/engine/training/streaming.py delete mode 100644 fast_llm/engine/training/trainer_events.py create mode 100644 tests/models/test_events.py delete mode 100644 tests/trainer/events_fake_consumer.py delete mode 100644 tests/trainer/test_events.py diff --git a/fast_llm/engine/checkpoint/config.py b/fast_llm/engine/checkpoint/config.py index 3f1970538..98303539e 100644 --- a/fast_llm/engine/checkpoint/config.py +++ b/fast_llm/engine/checkpoint/config.py @@ -141,11 +141,12 @@ class CheckpointSaveConfigBase(CheckpointConfigBase): @config_class() class CheckpointStateSaveConfigBase(CheckpointSaveConfigBase, CheckpointStateConfigBase): + _abstract = False model_weights: bool = FieldUpdate(desc="Save the model weights.") optimizer_state: bool = FieldUpdate(desc="Save the optimizer state. 
Default: save if supported by the `format`.") def _validate(self) -> None: - if self.optimizer_state is None: + if self.optimizer_state is None and hasattr(self.format, "support_optimizer"): with self._set_implicit_default(): # TODO: Make sure it's a type self.optimizer_state = self.format.support_optimizer diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 02829c580..ba52e6b5a 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -7,6 +7,7 @@ from fast_llm.config import ( Config, + Configurable, Field, FieldHint, FieldUpdate, @@ -25,6 +26,7 @@ ) from fast_llm.engine.config_utils.run import ExperimentConfig from fast_llm.engine.config_utils.runnable import RunnableConfig +from fast_llm.engine.distributed.config import DistributedBackend from fast_llm.engine.evaluation.config import EvaluatorConfig, EvaluatorConfigBase from fast_llm.engine.multi_stage.config import PretrainedFastLLMModelConfig from fast_llm.engine.optimizer.config import OptimizerConfig @@ -33,6 +35,8 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: + from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel + from fast_llm.engine.training.streaming import StreamingTrainerCallback from fast_llm.engine.training.trainer import Trainer, TrainingEvaluator @@ -322,100 +326,64 @@ def _validate(self) -> None: self.wandb.alert.assert_sub_interval(self.logs) -@config_class() -class TrainerEvent(Config): - enabled: bool = Field( - default=False, - desc="Flag indicating whether this event is enabled. If False, the event will be skipped.", - hint=FieldHint.feature, - ) - - -@config_class() -class WeightsBroadcastEventConfig(TrainerEvent): - """ - Event sent to indicate that updated weights are ready for broadcast. - """ - - initial_weights_step_message_type: str = Field( - default="initial_weights_step", - desc="Message indicating that weights the training starting/ continuing from.", - hint=FieldHint.feature, - ) +@config_class(registry=True) +class TrainerCallbackConfig(Config): + def get_callback(self, model: "FastLLMModel") -> "TrainerCallback": + raise NotImplementedError() - initial_weights_step_message_includes_weights: bool = Field( - default=False, - desc=( - "Whether to include the loaded model weights in the initial event message. " - "Useful when training restarts from an internal checkpoint format that " - "which does not have an exported checkpoint for that step." - ), - hint=FieldHint.feature, - ) + def setup(self, config: "TrainerConfig") -> None: + pass - weights_ready_message_type: str = Field( - default="weights_ready", - desc="Message indicating that weights are ready to be broadcast.", - hint=FieldHint.feature, - ) - # NCCL rendezvous details - rdvz_master_address: str | None = Field( - default=None, +@config_class() +class WeightsBroadcastConfig(Config): + # TODO: Have the external model send these instead? 
+ host: str = Field( + default="localhost", desc="Master address for the external NCCL process group.", hint=FieldHint.feature, ) - - rdvz_master_port: int | None = Field( - default=None, + port: int = Field( + default=23456, desc="Master port for the external NCCL process group.", hint=FieldHint.feature, ) - - world_size: int | None = Field( - default=None, + external_world_size: int = Field( + default=1, desc="World size of the external NCCL process group.", hint=FieldHint.feature, ) - - rank: int | None = Field( - default=None, - desc="Rank of this process in the external NCCL process group.", + backend: DistributedBackend = Field( + default=DistributedBackend.nccl, + desc="Backend for the external NCCL process group.", hint=FieldHint.feature, ) -@config_class() -class TrainingFinishedEventConfig(TrainerEvent): - """ - Event sent to indicate that training has completed. - """ - - training_finished_message_type: str = Field( - default="training_finished", - desc="Message indicating that weights the training starting/ continuing from.", - hint=FieldHint.feature, - ) - - -@config_class() -class TrainerEventsConfig(RedisConfig): +@config_class(dynamic_type={TrainerCallbackConfig: "streaming"}) +class StreamingTrainerCallbackConfig(TrainerCallbackConfig, RedisConfig): """ Aggregates all trainer-side Redis-based event configurations. """ - weights_broadcast: WeightsBroadcastEventConfig = Field( - default=None, + broadcast: WeightsBroadcastConfig = Field( desc="Configuration for signaling weight-ready events via Redis.", - hint=FieldHint.feature, + hint=FieldHint.core, ) - training_finished: TrainingFinishedEventConfig = Field( - default=None, - desc="Configuration for signaling training-finished events via Redis.", - hint=FieldHint.feature, + export: CheckpointStateSaveConfigBase = Field( + desc="Configuration for exporting checkpoints before broadcasting them.", + hint=FieldHint.core, ) + def get_callback(self, model: "FastLLMModel") -> "StreamingTrainerCallback": + from fast_llm.engine.training.streaming import StreamingTrainerCallback + + return StreamingTrainerCallback(self, model) + + def setup(self, config: "TrainerConfig") -> None: + self.export.setup(config.model) + @config_class(registry=True, dynamic_type={RunnableConfig: "train"}) class TrainerConfig(PretrainedFastLLMModelConfig, ExperimentConfig): @@ -448,14 +416,16 @@ class TrainerConfig(PretrainedFastLLMModelConfig, ExperimentConfig): hint=FieldHint.feature, ) - events: TrainerEventsConfig = Field( - default=None, - desc="Optional Trainer event configurations (weight broadcast, training finished, etc.).", + callbacks: dict[str, TrainerCallbackConfig] = Field( + default_factory=dict, + desc="Configuration for training callbacks.", hint=FieldHint.feature, ) def _validate(self) -> None: self.training.export.setup(self.model) + for callback in self.callbacks.values(): + callback.setup(self) for reference_model in self.reference_models.values(): self._add_reference_distributed_to_pretrained(reference_model) super()._validate() @@ -505,3 +475,21 @@ def new_setup(): old_setup() object.__setattr__(pretrained, "_setup", new_setup) + + +class TrainerCallback[ConfigType: TrainerCallbackConfig](Configurable[ConfigType]): + # TODO: Make a more exhaustive set of events and arguments. 
+ def run_begin(self, step: int): + pass + + def step_end( + self, + step: int, + reduced_losses: dict[str, float | int], + update_successful: bool, + train_metrics: dict[str, typing.Any] | None, + ): + pass + + def train_end(self, step: int): + pass diff --git a/fast_llm/engine/training/streaming.py b/fast_llm/engine/training/streaming.py new file mode 100644 index 000000000..2009e9476 --- /dev/null +++ b/fast_llm/engine/training/streaming.py @@ -0,0 +1,70 @@ +import json +import logging +import typing + +import torch.distributed + +from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel +from fast_llm.engine.training.config import StreamingTrainerCallbackConfig, TrainerCallback + +logger = logging.getLogger(__name__) + + +REDIS_TRAINING_KEY = "fast_llm_events" + + +class StreamingTrainerCallback[ConfigType: StreamingTrainerCallbackConfig](TrainerCallback[ConfigType]): + def __init__(self, config: ConfigType, model: "FastLLMModel"): + super().__init__(config) + self._model = model + self._do_broadcast = self._model.config.distributed.rank == 0 + if self._do_broadcast: + self._client = self._config.get_client() + init_method = f"tcp://{config.broadcast.host}:{config.broadcast.port}" + logger.info(f"Waiting for weights broadcast rendezvous at {init_method} ...") + # TODO: Create a custom process group instead. + self._process_group = torch.distributed.init_process_group( + backend="nccl", + init_method=init_method, + world_size=config.broadcast.external_world_size + 1, + rank=0, + ) + logger.info(f"Weights broadcast rendezvous at {init_method} connected") + + def run_begin(self, step: int): + # TODO: ====== Send a train / run begin signal? ====== + self._broadcast_weights(step) + + def step_end( + self, + step: int, + reduced_losses: dict[str, float | int], + update_successful: bool, + train_metrics: dict[str, typing.Any] | None, + ): + if update_successful: + self._broadcast_weights(step) + + def train_end(self, step: int): + # TODO: ====== Send something on unsuccessful ends? 
====== + if self._do_broadcast: + self._client.xadd(REDIS_TRAINING_KEY, {"event": json.dumps({"type": "training_finished"})}) + torch.distributed.destroy_process_group() + + def __del__(self): + if self._do_broadcast: + torch.distributed.destroy_process_group() + + def _broadcast_weights(self, step: int): + if self._do_broadcast: + self._client.xadd(REDIS_TRAINING_KEY, {"event": json.dumps({"type": "weights_ready", "step": step})}) + for shard_name, layer_name, tensor in self._model.iter_checkpoint(self._config.get_save_config("", 10), {}): + if self._do_broadcast: + # TODO: ====== Broadcast metadata in advance ======= + meta = [(shard_name, layer_name, tensor.shape, tensor.dtype)] + torch.distributed.broadcast_object_list(meta, group=self._process_group, group_src=0) + torch.distributed.broadcast(tensor, group=self._process_group, group_src=0) + # Broadcast end of weights broadcast + if self._do_broadcast: + meta = [None] + torch.distributed.broadcast_object_list(meta, group=self._process_group, group_src=0) diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index a2f98c05f..5ef625316 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -36,7 +36,6 @@ TrainingCheckpointConfig, TrainingEvaluatorConfig, ) -from fast_llm.engine.training.trainer_events import TrainerEvents from fast_llm.engine.training.wandb import Wandb from fast_llm.logging import format_metrics, log_memory_usage from fast_llm.utils import Assert, Interrupter, get_and_reset_memory_usage_mib @@ -132,8 +131,6 @@ def __init__(self, config: TrainerConfig): self._is_evaluation_only = config.training.train_iters == 0 - self.trainer_events = TrainerEvents(config.events) - self._data = self._get_data() log_main_rank("Creating model...") self._multi_stage = self._config.model.get_model_class()( @@ -154,6 +151,9 @@ def __init__(self, config: TrainerConfig): distributed_config=self._config.model.distributed, ) self._loss_definitions = self._multi_stage.base_model.get_loss_definitions() + self._callbacks = { + name: config.get_callback(self._multi_stage) for name, config in self._config.callbacks.items() + } if not self._is_evaluation_only: steps_per_split = { @@ -289,7 +289,8 @@ def run(self) -> None: assert self._is_setup with self._wandb: self._run_training() - self.trainer_events.send_training_finished() + for callback in self._callbacks.values(): + callback.train_end(self._completed_steps) def _run_training(self) -> None: self._prepare_training_state() @@ -363,9 +364,8 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: # TODO: Synchronization is probably unnecessary. safe_barrier(self._distributed.world_group, "train begin") - self.trainer_events.send_initial_weights_step( - self._completed_steps, self._multi_stage, self._config.training.export - ) + for callback in self._callbacks.values(): + callback.run_begin(self._completed_steps) torch.cuda.synchronize() start_time = time.perf_counter() @@ -393,13 +393,12 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: advanced_iters += 1 for name, value in reduced_losses.items(): total_losses[name] += value - self.trainer_events.send_weights( - self._completed_steps, self._multi_stage, self._config.training.export - ) else: skipped_iters += 1 nan_iters += not all(math.isfinite(loss) for loss in reduced_losses.values()) + for callback in self._callbacks.values(): + callback.step_end(self._completed_steps, reduced_losses, update_successful, train_metrics) # Logging. 
metrics = {} if is_logging: diff --git a/fast_llm/engine/training/trainer_events.py b/fast_llm/engine/training/trainer_events.py deleted file mode 100644 index 0937999fc..000000000 --- a/fast_llm/engine/training/trainer_events.py +++ /dev/null @@ -1,109 +0,0 @@ -import json -import logging - -import redis -import torch.distributed - -from fast_llm.data.dataset.config import RedisConfig -from fast_llm.engine.config_utils.run import is_main_rank -from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel -from fast_llm.engine.training.config import TrainerEventsConfig, TrainingExportConfig - -logger = logging.getLogger(__name__) - - -REDIS_TRAINING_KEY = "fast_llm_events" - - -class RedisEventSender: - def __init__(self, config: RedisConfig): - self.config = config - self.client = None - - if is_main_rank(): - self.client = redis.Redis( - host=config.host, - port=config.port, - ) - - def send(self, msg_type: str, payload: dict | None = None): - if not is_main_rank(): - return - - if not payload: - payload = {} - payload.update({"type": msg_type}) - - self.client.xadd(REDIS_TRAINING_KEY, {"event": json.dumps(payload)}) - - -class TrainerEvents: - """ - Main helper class holding all event channels. - Each event may have its own RedisConfig. - - Usage: - events = TrainerEvents(cfg.events) - events.weights_broadcast.send({"step": 100}) - events.training_finished.send() - """ - - def __init__(self, config: TrainerEventsConfig): - self.config = config - - if config.weights_broadcast.enabled or config.training_finished.enabled: - self.sender = RedisEventSender(config.redis) - else: - self.sender = None - - if config.weights_broadcast.enabled and is_main_rank(): - init_method = ( - f"tcp://{config.weights_broadcast.rdvz_master_address}:{config.weights_broadcast.rdvz_master_port}" - ) - logger.info(f"Waiting for weights broadcast rendezvous at {init_method} ...") - self.weights_pg = torch.distributed.init_process_group( - backend="nccl", - init_method=init_method, - world_size=config.weights_broadcast.world_size, - rank=config.weights_broadcast.rank, - ) - logger.info(f"Weights broadcast rendezvous at {init_method} connected") - else: - self.weights_pg = None - - def send_initial_weights_step(self, step: int, model: FastLLMModel, export_config: TrainingExportConfig): - if self.config.weights_broadcast.enabled: - self.sender.send( - msg_type=self.config.weights_broadcast.initial_weights_step_message_type, payload={"step": step} - ) - if self.config.weights_broadcast.initial_weights_step_message_includes_weights: - self._broadcast_weights(model, export_config) - - def send_weights(self, step: int, model: FastLLMModel, export_config: TrainingExportConfig): - if self.config.weights_broadcast.enabled: - self.sender.send(msg_type=self.config.weights_broadcast.weights_ready_message_type, payload={"step": step}) - self._broadcast_weights(model, export_config) - - def send_training_finished(self): - if self.config.training_finished.enabled: - self.sender.send(msg_type=self.config.training_finished.training_finished_message_type) - - if is_main_rank() and self.config.weights_broadcast.enabled: - torch.distributed.destroy_process_group() - - def _broadcast_weights(self, model: FastLLMModel, export_config: TrainingExportConfig): - for shard_name, layer_name, tensor in model.iter_checkpoint(export_config.get_save_config("", 10), {}): - if is_main_rank(): - meta = [(shard_name, layer_name, tensor.shape, tensor.dtype)] - torch.distributed.broadcast_object_list( - meta, group=self.weights_pg, 
group_src=self.config.weights_broadcast.rank - ) - torch.distributed.broadcast( - tensor, group=self.weights_pg, group_src=self.config.weights_broadcast.rank - ) - # Broadcast end of weights broadcast - if is_main_rank(): - meta = [None] - torch.distributed.broadcast_object_list( - meta, group=self.weights_pg, group_src=self.config.weights_broadcast.rank - ) diff --git a/tests/data/test_streaming.py b/tests/data/test_streaming.py index e16583e8f..d8953488f 100644 --- a/tests/data/test_streaming.py +++ b/tests/data/test_streaming.py @@ -20,7 +20,8 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.models.gpt.config import GPTBatchConfig from fast_llm.utils import Assert -from tests.utils.redis import find_free_port, make_sampling, push_msg, redis_batch_producer +from tests.conftest import WorkerResources +from tests.utils.redis import make_sampling, push_msg, redis_batch_producer from tests.utils.subtest import DistributedTestContext from tests.utils.utils import requires_cuda @@ -49,9 +50,10 @@ def fake_redis(monkeypatch): def test_streaming_dataset( fake_redis: fakeredis.FakeRedis, messages: tuple[list[int], ...], + worker_resources: WorkerResources, ): """StreamingDataset should read a message and convert it into LanguageModelSample.""" - stream_config = StreamingDatasetConfig(port=find_free_port()) + stream_config = StreamingDatasetConfig(port=worker_resources.torchrun_port) dataset_iterator = iter(RedisStreamingDataset(stream_config, DistributedConfig())) for message in messages: push_msg(fake_redis, list(message)) @@ -89,9 +91,10 @@ def test_streaming_sampled_dataset( messages: tuple[list[int], ...], expected_samples: tuple[list[int], ...], expected_lengths: tuple[int, ...], + worker_resources: WorkerResources, ): """StreamingDataset should read a message and convert it into LanguageModelSample.""" - stream_config = StreamingDatasetConfig(port=find_free_port()) + stream_config = StreamingDatasetConfig(port=worker_resources.torchrun_port) distributed = Distributed(DistributedConfig(), use_cpu=True) dataset_iterator = iter( RedisStreamingDataset(stream_config, distributed.config).sample(make_sampling(5, 1, distributed)) @@ -216,6 +219,7 @@ def test_data_streaming(result_path, worker_resources): @requires_cuda +@pytest.mark.slow @pytest.mark.depends_on(on=["test_data_streaming"]) def test_run_data_streaming_distributed(run_parallel_script, result_path, worker_resources): if torch.cuda.device_count() < 2: @@ -228,9 +232,10 @@ def test_run_data_streaming_distributed(run_parallel_script, result_path, worker @requires_cuda +@pytest.mark.slow @pytest.mark.depends_on(on=["test_data_streaming"]) @pytest.mark.parametrize(("name", "num_gpus", "distributed_config_dict"), _DISTRIBUTED_TESTING_CONFIGS) -def test_run_streaming_distributed(result_path, name, num_gpus, distributed_config_dict, report_subtest): +def test_data_streaming_distributed(result_path, name, num_gpus, distributed_config_dict, report_subtest): report_subtest(path := result_path / f"data_streaming/{name}", num_gpus) distributed_config, batch_config = _get_distributed_and_batch_config(distributed_config_dict, num_gpus) check_data_streaming_results(path, distributed_config, batch_config) diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 6f164a33e..41d0952c6 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -155,7 +155,7 @@ def test_conversion(model_testing_config, run_conversion, get_convert_path): ) -def _compare_safetensor_files( 
+def compare_safetensor_files( reference: pathlib.Path | dict[str, torch.Tensor], *other_paths: pathlib.Path, expected_keys: set[str] | None = None, @@ -180,24 +180,24 @@ def _compare_safetensor_files( def test_converted_round_trip(model_testing_config, get_convert_path): # Test that the various possible conversion paths yield identical results. if model_testing_config.checkpoint_format is None: - _compare_safetensor_files( + compare_safetensor_files( get_convert_path() / "rank_0.safetensors", get_convert_path(DistributedCheckpointFormat, FastLLMCheckpointFormat) / "rank_0.safetensors", expected_keys={_WEIGHT_SHARD_SAVE_NAME}, ) else: - _compare_safetensor_files( + compare_safetensor_files( get_convert_path() / "rank_0.safetensors", get_convert_path(DistributedCheckpointFormat, FastLLMCheckpointFormat) / "rank_0.safetensors", get_convert_path(DistributedCheckpointFormat, model_testing_config.checkpoint_format) / "rank_0.safetensors", expected_keys={_WEIGHT_SHARD_SAVE_NAME}, ) - _compare_safetensor_files( + compare_safetensor_files( get_convert_path(FastLLMCheckpointFormat, DistributedCheckpointFormat) / "model_0.safetensors", get_convert_path(FastLLMCheckpointFormat, model_testing_config.checkpoint_format) / "model_0.safetensors", ) - _compare_safetensor_files( + compare_safetensor_files( get_convert_path(model_testing_config.checkpoint_format, DistributedCheckpointFormat) / "model_0.safetensors", get_convert_path(model_testing_config.checkpoint_format, FastLLMCheckpointFormat) / "model_0.safetensors", @@ -499,7 +499,7 @@ def test_parallel_checkpoint_consistency(model_testing_config, run_test_script_b # Compare Distributed checkpoints for config in ("dp2", "tp2", "stp2", "pp2"): for rank in range(2): - _compare_safetensor_files( + compare_safetensor_files( *[ DISTRIBUTED_SAVE_LOAD_CONFIGS[f"load_{format_}_in_{config}"] .resolve(base_path=run_test_script_base_path, model_testing_config=model_testing_config) @@ -537,7 +537,7 @@ def test_multi_gpu_fast_llm_checkpoint( base_path=run_test_script_base_path, model_testing_config=model_testing_config ) - _compare_safetensor_files( + compare_safetensor_files( reference_fast_llm_shard, distributed_save_load_config_non_pp.save_path / f"{FastLLMCheckpointFormat.name}/model_0.safetensors", ) diff --git a/tests/models/test_events.py b/tests/models/test_events.py new file mode 100644 index 000000000..b1e538a05 --- /dev/null +++ b/tests/models/test_events.py @@ -0,0 +1,210 @@ +import contextlib +import dataclasses +import functools +import json +import logging +import pathlib + +import pytest +import safetensors +import torch + +from fast_llm.engine.training.config import StreamingTrainerCallbackConfig +from fast_llm.engine.training.streaming import REDIS_TRAINING_KEY +from fast_llm.models.gpt.config import GPTBatchConfig +from fast_llm.utils import Assert +from tests.conftest import WorkerResources +from tests.models.test_checkpoint import compare_safetensor_files +from tests.utils.distributed_configs import DistributedTestingConfig +from tests.utils.model_configs import ( + MODEL_CONFIGS, + ModelTestingConfig, + ModelTestingGroup, + update_and_add_testing_config, +) +from tests.utils.redis import redis_batch_producer +from tests.utils.run_test_script import do_run_test_script_for_all_models +from tests.utils.subtest import DistributedTestContext +from tests.utils.utils import requires_cuda + + +@dataclasses.dataclass(kw_only=True) +class StreamingDistributedTestingConfig(DistributedTestingConfig): + consumer_count: int = (1,) + + 
@functools.cached_property + def total_gpus(self) -> int: + return self.num_gpus + self.consumer_count + + +_DISTRIBUTED_STREAMING_CONFIGS = [ + StreamingDistributedTestingConfig(name="streaming_simple", config_args=[], num_gpus=1, consumer_count=1), + StreamingDistributedTestingConfig(name="streaming_dp2", config_args=[], num_gpus=2, consumer_count=1), + StreamingDistributedTestingConfig( + name="streaming_sdp2_c2", + config_args=["model.distributed.sequence_data_parallel=2"], + num_gpus=2, + consumer_count=2, + ), + StreamingDistributedTestingConfig( + name="streaming_tp2", config_args=["model.distributed.tensor_parallel=2"], num_gpus=2, consumer_count=2 + ), + StreamingDistributedTestingConfig( + name="streaming_stp2_c2", + config_args=[ + "model.distributed.tensor_parallel=2", + "model.distributed.sequence_tensor_parallel=true", + "callbacks.streaming.broadcast.external_world_size=2", + ], + num_gpus=2, + consumer_count=2, + ), + StreamingDistributedTestingConfig( + name="streaming_pp2s2_bf4", + config_args=[ + "model.distributed.pipeline_parallel=2", + "model.multi_stage.layers_per_stage=2", + "batch.breadth_first_micro_batches=4", + ], + num_gpus=2, + consumer_count=2, + ), +] + + +def _run_event_consumer( + streaming_config: StreamingTrainerCallbackConfig, consumer_index: int, base_path: pathlib.Path +) -> None: + client = streaming_config.get_client() + init_method = f"tcp://{streaming_config.broadcast.host}:{streaming_config.broadcast.port}" + logging.info(f"Waiting for weights broadcast rendezvous at {init_method} ...") + path = base_path / "streaming" + path.mkdir(parents=True, exist_ok=True) + # TODO: Create a custom process group instead. + try: + process_group = torch.distributed.init_process_group( + backend="nccl", + init_method=init_method, + world_size=streaming_config.broadcast.external_world_size + 1, + rank=consumer_index + 1, + ) + last_id = "0-0" + while True: + result = client.xread( + streams={REDIS_TRAINING_KEY: last_id}, + count=1, + block=10000, + ) + if not result: + raise TimeoutError("No message received after 10000 ms...") + + for _, (event_id, message) in result[0]: + last_id = event_id + message = json.loads(message.decode()) + logging.info(f"Received: {message}") + Assert.eq(message.keys(), {"event"}) + message = message["event"] + if message["type"] == "training_finished": + return + elif message["type"] == "weights_ready": + weights = {} + while True: + meta = [None] + torch.distributed.broadcast_object_list(meta, group=process_group, group_src=0) + if meta[0] is None: + print(f"Weight broadcast finished") + break + logging.info(f"receiving {meta[0]}") + shard_name, layer_name, tensor_size, tensor_type = meta[0] + tensor = torch.zeros(tuple(tensor_size), dtype=tensor_type, device="cuda") + torch.distributed.broadcast(tensor, group=process_group, group_src=0) + if shard_name == "weights": + weights[layer_name] = tensor + safetensors.torch.save_file( + weights, path / f"rank_{consumer_index}_step_{message["step"]}.safetensors" + ) + + finally: + torch.distributed.destroy_process_group() + + +def _run_model_streaming_configs( + test_context: DistributedTestContext, base_path: pathlib.Path, model_testing_config: ModelTestingConfig, port: int +) -> None: + # Import all dynamic classes. 
+ import fast_llm.cli # noqa + + for config in _DISTRIBUTED_STREAMING_CONFIGS: + model_testing_config = update_and_add_testing_config( + model_testing_config, + None, + updates={ + ("data", "datasets"): {"training": {"port": port}}, + "callbacks": { + "streaming": { + "type": "streaming", + "port": port, + "broadcast": { + "port": port + 1000, + "external_world_size": config.consumer_count, + }, + "export": {"format": MODEL_CONFIGS["mistral"].checkpoint_format.name}, + } + }, + }, + groups={}, + ) + with test_context.subtest(base_path, config.name, config.total_gpus) as subtest: + if subtest.do_run: + if test_context.rank < config.num_gpus: + do_run_test_script_for_all_models(config, model_testing_config, base_path) + elif test_context.rank < config.total_gpus: + streaming_config = StreamingTrainerCallbackConfig.from_dict( + model_testing_config.config_dict["callbacks"]["streaming"] + ) + batch_config = GPTBatchConfig.from_dict(model_testing_config.config_dict["batch"]) + with ( + redis_batch_producer(streaming_config, batch_config) + if test_context.rank == config.num_gpus + else contextlib.nullcontext() + ): + _run_event_consumer(streaming_config, test_context.rank - config.num_gpus, base_path) + + +@requires_cuda +@pytest.mark.slow +@pytest.mark.model_testing_group(ModelTestingGroup.streaming, ModelTestingGroup.distributed) +def test_run_model_distributed_streaming( + run_parallel_script, model_testing_config, run_test_script_base_path, worker_resources +): + if torch.cuda.device_count() < 2: + pytest.skip(f"Not enough GPUs") + run_parallel_script( + _run_model_streaming_configs, + (run_test_script_base_path, model_testing_config, worker_resources.torchrun_port), + world_size=torch.cuda.device_count(), + backend=model_testing_config.distributed_backend, + ) + + +@pytest.mark.slow +@requires_cuda +@pytest.mark.model_testing_group(ModelTestingGroup.streaming, ModelTestingGroup.distributed) +@pytest.mark.parametrize("config", _DISTRIBUTED_STREAMING_CONFIGS) +def test_model_distributed_streaming( + config: StreamingDistributedTestingConfig, + run_distributed_script, + model_testing_config, + run_test_script_base_path, + worker_resources: WorkerResources, + report_subtest, +): + report_subtest(path := run_test_script_base_path / config.name, config.total_gpus) + compare_safetensor_files( + path / "export" / model_testing_config.checkpoint_format.name / "1", + *( + path / "streaming" / f"rank_{consumer_index}_step_{step}.safetensors" + for consumer_index in range(config.consumer_count) + for step in range(3) + ), + ) diff --git a/tests/trainer/events_fake_consumer.py b/tests/trainer/events_fake_consumer.py deleted file mode 100644 index 9692134db..000000000 --- a/tests/trainer/events_fake_consumer.py +++ /dev/null @@ -1,105 +0,0 @@ -import json -import sys -from pathlib import Path - -import redis -import safetensors.torch -import torch.distributed -import yaml - - -def main(): - if len(sys.argv) != 2: - print("Usage: python -m tests.trainer.events_fake_consumer ") - sys.exit(1) - - config_path = Path(sys.argv[1]) - if not config_path.exists(): - print(f"Config file {config_path} does not exist") - sys.exit(1) - - with config_path.open("rt") as f: - config = yaml.safe_load(f) - - consumer_cfg = config["consumer"] - world_size = consumer_cfg["world_size"] - rank = consumer_cfg["rank"] - results_path = Path(consumer_cfg["results_path"]) - results_path.mkdir(parents=True, exist_ok=True) - - consumer_id = f"[Consumer {rank}/{world_size}]" - - print(f"{consumer_id} Started with config:") - 
print(yaml.safe_dump(config)) - - assert config["events"]["weights_broadcast"]["enabled"] - assert config["events"]["training_finished"]["enabled"] - - redis_client = redis.Redis(host=config["events"]["redis"]["host"], port=config["events"]["redis"]["port"]) - - print(f"{consumer_id} waiting for pg rendezvous...") - weights_pg = torch.distributed.init_process_group( - backend="nccl", - init_method=f'tcp://{config["events"]["weights_broadcast"]["rdvz_master_address"]}:' - f'{config["events"]["weights_broadcast"]["rdvz_master_port"]}', - world_size=world_size, - rank=rank, - ) - broadcast_source_rank = config["events"]["weights_broadcast"]["rank"] - - last_id = "0-0" - msg_key = config["events"]["redis"]["payload_key"].encode() - stream_key = config["events"]["redis"]["stream_key"] - - print(f"{consumer_id} waiting for messages...") - while True: - result = redis_client.xread( - streams={stream_key: last_id}, - count=1, - block=200, - ) - - if not result: - continue - - _, events = result[0] - - for event_id, msg in events: - last_id = event_id - assert msg_key in msg - msg = json.loads(msg[msg_key].decode()) - print(f"{consumer_id} msg received: {msg}") - if msg["type"] == config["events"]["weights_broadcast"]["weights_ready_message_type"] or ( - msg["type"] == config["events"]["weights_broadcast"]["initial_weights_step_message_type"] - and config["events"]["weights_broadcast"]["initial_weights_step_message_includes_weights"] - ): - weights = {} - while True: - meta = [None] - torch.distributed.broadcast_object_list(meta, group=weights_pg, group_src=broadcast_source_rank) - meta = meta[0] - if meta is None: - print(f"{consumer_id} weight broadcast finished") - break - shard_name, layer_name, tensor_size, tensor_type = meta - tensor = torch.zeros( - tuple(tensor_size), dtype=tensor_type, device="cuda" - ) # so far consumer is single gpu only - torch.distributed.broadcast(tensor, group=weights_pg, group_src=broadcast_source_rank) - print(f"{consumer_id} {shard_name} layer {layer_name} {tensor_size} {tensor_type} received") - if shard_name == "weights": - weights[layer_name] = tensor - safetensors.torch.save_file(weights, results_path / f"{msg["step"]}.safetensors") - - elif msg["type"] == config["events"]["training_finished"]["training_finished_message_type"]: - torch.distributed.destroy_process_group() - (results_path / "training_finished").touch() - return - else: - raise RuntimeError(f"{consumer_id} Received unknown message type {msg}") - if msg["type"] == config["events"]["weights_broadcast"]["initial_weights_step_message_type"]: - (results_path / "initial_weights_step").touch() - - -if __name__ == "__main__": - main() diff --git a/tests/trainer/test_events.py b/tests/trainer/test_events.py deleted file mode 100644 index baa0526e3..000000000 --- a/tests/trainer/test_events.py +++ /dev/null @@ -1,407 +0,0 @@ -import contextlib -import copy -import os -import pathlib -import subprocess -import time -import typing - -import pytest -import safetensors -import torch -import yaml - -from fast_llm.data.dataset.config import StreamingDatasetConfig -from tests.utils.model_configs import MODEL_CONFIGS -from tests.utils.redis import redis_batch_producer -from tests.utils.utils import requires_cuda - - -@contextlib.contextmanager -def run_fake_events_consumers( - model_config: dict, - test_result_path: pathlib.Path, - broadcast_world_size: int, - fake_consumers_broadcast_ranks: list[int], - assigned_gpus: list[str], - timeout: float = 30.0, # seconds -): - """ - Context manager to run fake event 
consumer subprocesses for testing. - - Each subprocess gets a separate config and CUDA_VISIBLE_DEVICES. - - After exiting the context, all subprocesses are ensured to terminate. - Raises RuntimeError if any subprocess exits with non-zero code. - """ - import tests.trainer.events_fake_consumer - - assert len(assigned_gpus) > 0 - assert len(assigned_gpus) == len(fake_consumers_broadcast_ranks) - - processes = [] - - try: - for i, gpu in enumerate(assigned_gpus): - consumer_path = test_result_path / str(i) - consumer_path.mkdir(parents=True, exist_ok=True) - - # Deep copy config and update per consumer - this_config = copy.deepcopy(model_config) - this_config["consumer"] = { - "idx": i, - "results_path": consumer_path / "results", - "world_size": broadcast_world_size, - "rank": fake_consumers_broadcast_ranks[i], - } - this_config_path = consumer_path / "config.yaml" - - # Save config as YAML - with open(this_config_path, "w") as f: - yaml.safe_dump(convert_paths(this_config), f) - - # Build subprocess command - script = [ - "python", - "-m", - tests.trainer.events_fake_consumer.__name__, - str(this_config_path), - ] - env = os.environ.copy() - env["CUDA_VISIBLE_DEVICES"] = str(gpu) - - # Start subprocess - proc = subprocess.Popen(script, env=env) - processes.append(proc) - - # Yield control to the caller while subprocesses run - yield - - finally: - # Wait for processes to exit or kill after timeout - start_time = time.time() - for proc in processes: - try: - remaining = max(0, timeout - (time.time() - start_time)) - proc.wait(timeout=remaining) - except subprocess.TimeoutExpired: - proc.kill() - - # Check exit codes - errors = [(i, p.returncode) for i, p in enumerate(processes) if p.returncode != 0] - if errors: - raise RuntimeError(f"Some fake consumer subprocesses failed: {errors}") - - -def run_fast_llm_training(model_config, run_distributed_script, assigned_gpus): - import fast_llm.cli - - config_path = model_config["run"]["experiment_dir"] / "load_config.yaml" - config_path.parent.mkdir(parents=True, exist_ok=True) - with config_path.open("wt") as f: - yaml.safe_dump(convert_paths(model_config), f) - - script = [ - "-m", - fast_llm.cli.__name__, - "train", - "gpt", - "--config", - str(config_path), - ] - - env = os.environ.copy() - env["PYTHONHASHSEED"] = "42" - env["CUDA_VISIBLE_DEVICES"] = ",".join(str(gpu) for gpu in assigned_gpus) - run_distributed_script(script, num_gpus=len(assigned_gpus), env=env) - - -def compare_test_tensors_to_checkpoint(test_safetensor_path: str, checkpoint_dir: str): - """ - Compare a test-saved safetensor file (a dict of tensors) - to all safetensors in a checkpoint directory. - - Checks: - - tensor names must match - - shapes must match - - dtypes must match - - values must match (exact) - """ - - # ------------------------- - # Load test tensor file - # ------------------------- - test_tensors = {} - with safetensors.safe_open(test_safetensor_path, framework="pt", device="cpu") as f: - for key in f.keys(): - test_tensors[key] = f.get_tensor(key) - - assert len(test_tensors) > 0, f"No tensors found in {test_safetensor_path}." 
- - # ------------------------- - # Load checkpoint tensors - # ------------------------- - checkpoint_tensors = {} - - for file in os.listdir(checkpoint_dir): - if file.endswith(".safetensors"): - path = os.path.join(checkpoint_dir, file) - with safetensors.safe_open(path, framework="pt", device="cpu") as f: - for key in f.keys(): - if key in checkpoint_tensors: - raise AssertionError( - f"Duplicate tensor name '{key}' across checkpoint {checkpoint_dir} files." - ) - checkpoint_tensors[key] = f.get_tensor(key) - - assert len(checkpoint_tensors) > 0, f"No safetensors found in checkpoint directory: {checkpoint_dir}" - - # ------------------------- - # Compare tensor sets - # ------------------------- - test_names = set(test_tensors.keys()) - ckpt_names = set(checkpoint_tensors.keys()) - - unexpected_in_test = test_names - ckpt_names - missing_in_test = ckpt_names - test_names - - assert not missing_in_test, "Tensors missing in {test_safetensor_path}:\n" + "\n".join(sorted(missing_in_test)) - assert not unexpected_in_test, "Unexpected tensors in {test_safetensor_path}:\n" + "\n".join( - sorted(unexpected_in_test) - ) - - # ------------------------- - # Compare individual tensors - # ------------------------- - for name in sorted(test_names): - t_test = test_tensors[name] - t_ckpt = checkpoint_tensors[name] - - # dtype - assert t_test.dtype == t_ckpt.dtype, f"Mismatch in dtype for '{name}': " f"{t_test.dtype} != {t_ckpt.dtype}" - - # shape - assert t_test.shape == t_ckpt.shape, ( - f"Mismatch in shape for '{name}': " f"{tuple(t_test.shape)} != {tuple(t_ckpt.shape)}" - ) - - # values - if not torch.equal(t_test, t_ckpt): - diff = (t_test - t_ckpt).abs() - max_diff = diff.max().item() - idx = (diff > 0).nonzero(as_tuple=False) - example = idx[0].tolist() if idx.numel() > 0 else "unknown" - - raise AssertionError( - f"Tensor content mismatch for '{name}'.\n" - f"Max difference: {max_diff}\n" - f"Example differing index: {example}" - ) - - # If we reached here → all is good - return True - - -def check_events_results( - test_results_path_fast_llm, - test_results_path_consumers, - consumer_count, - training_steps, - model_checkpoint_format, -): - for consumer_idx in range(consumer_count): - consumer_test_results_path = test_results_path_consumers / str(consumer_idx) / "results" - assert (consumer_test_results_path / "training_finished").is_file() - assert (consumer_test_results_path / "initial_weights_step").is_file() - # NOTE: We do not test the initial weights broadcast result when enabled, - # because it is identical to subsequent broadcasts. - for training_step in range(1, training_steps + 1): - compare_test_tensors_to_checkpoint( - consumer_test_results_path / f"{training_step}.safetensors", - test_results_path_fast_llm / "export" / model_checkpoint_format / str(training_step), - ) - - -def convert_paths(obj): - if isinstance(obj, dict): - return {k: convert_paths(v) for k, v in obj.items()} - elif isinstance(obj, list): - return [convert_paths(v) for v in obj] - elif isinstance(obj, tuple): - return tuple(convert_paths(v) for v in obj) - elif isinstance(obj, pathlib.Path): - return str(obj) - else: - return obj - - -def parallelism_variants(num_gpus: int) -> list[dict[str, int]]: - if num_gpus == 1: - return [{"tp": 1, "pp": 1, "sp": 1}] - - if num_gpus == 2: - return [ - # NOTE: Streaming dataset is currently not compatible with pipeline parallelism. 
- {"tp": 2, "pp": 1, "sp": 1}, - # {"tp": 1, "pp": 2, "sp": 1}, - {"tp": 1, "pp": 1, "sp": 2}, - ] - - if num_gpus == 4: - return [ - # NOTE: Streaming dataset is currently not compatible with pipeline parallelism. - {"tp": 4, "pp": 1, "sp": 1}, - # {"tp": 1, "pp": 4, "sp": 1}, - {"tp": 1, "pp": 1, "sp": 4}, - # {"tp": 2, "pp": 2, "sp": 1}, - # {"tp": 1, "pp": 2, "sp": 2}, - {"tp": 2, "pp": 1, "sp": 2}, - ] - - raise ValueError(f"Invalid gpu count for fast_llm parallelism {num_gpus}") - - -def consumer_counts(num_gpus: int) -> int: - if num_gpus == 2: - return 1 - if num_gpus == 3: - return 1 - if num_gpus == 4: - return 2 - if num_gpus == 5: - return 1 - if num_gpus == 6: - return 2 - if num_gpus == 7: - return 3 - if num_gpus >= 8: - return 4 - - -def generate_variants(num_gpus: int) -> list[dict[str, typing.Any]]: - """ - Generate all (consumer_count, tp/pp/sp) variants for given GPU count. - """ - results = [] - - if num_gpus < 2: - return results - if num_gpus == 2: - num_gpus = [2] - elif num_gpus <= 4: - num_gpus = [2, num_gpus] - else: - num_gpus = [2, 4, min(num_gpus, 8)] - - for gpus in num_gpus: - consumers = consumer_counts(gpus) - remaining = gpus - consumers - par_vars = parallelism_variants(remaining) - for pv in par_vars: - results.append( - { - "total_gpus": gpus, - "consumers_gpu_count": consumers, - "fast_llm_gpus_count": remaining, - "consumers_gpus": list(range(consumers)), - "fast_llm_gpus": list(range(consumers, gpus)), - "tensor_parallel": pv["tp"], - "pipeline_parallel": pv["pp"], - "sequence_data_parallel": pv["sp"], - } - ) - - return results - - -variants = generate_variants(torch.cuda.device_count()) - - -@pytest.mark.slow -@requires_cuda -@pytest.mark.parametrize( - "variant", - variants, - ids=[ - f"gpu{v['total_gpus']}_cgpus{v['consumers_gpu_count']}_fgpus{v['fast_llm_gpus_count']}" - f"_tp{v['tensor_parallel']}_pp{v['pipeline_parallel']}_sp{v['sequence_data_parallel']}" - for v in variants - ], -) -def test_trainer_events_with_streaming(variant, run_distributed_script, result_path, request): - stream_config = StreamingDatasetConfig(port=port) - test_result_path = result_path / request.node.name - test_result_path_fast_llm = test_result_path / "fast_llm" - test_result_path_consumers = test_result_path / "consumers" - - broadcast_world_size = variant["consumers_gpu_count"] + 1 - fake_consumers_broadcast_ranks = list(range(variant["consumers_gpu_count"])) - fake_consumers_assigned_gpus = variant["consumers_gpus"] - fast_llm_broadcast_rank = variant["consumers_gpu_count"] - fast_llm_assigned_gpus = variant["fast_llm_gpus"] - train_iters = 2 - - model_config = copy.deepcopy(MODEL_CONFIGS["mistral"].config_dict) - model_config["data"]["datasets"] = {"training": stream_config.to_dict()} - model_config["data"]["sampling"] = {"shuffle": "disabled"} - model_config["training"]["train_iters"] = train_iters - model_config["training"]["export"] = {"interval": 1, "format": MODEL_CONFIGS["mistral"].checkpoint_format.name} - model_config["batch"]["micro_batch_size"] = 1 - model_config["batch"]["truncate_documents"] = False - model_config["run"]["experiment_dir"] = test_result_path_fast_llm - model_config["model"]["distributed"]["tensor_parallel"] = variant["tensor_parallel"] - model_config["model"]["distributed"]["pipeline_parallel"] = variant["pipeline_parallel"] - model_config["model"]["distributed"]["sequence_data_parallel"] = variant["sequence_data_parallel"] - - # We use same stream for messages in the test. 
Also make all fields explicit, - # so fake consumers can read them as well from this dict config - model_config["events"] = { - "redis": { - "host": stream_config.host, - "port": stream_config.port, - "stream_key": "fast_llm_events", - "payload_key": "event", - }, - "weights_broadcast": { - "enabled": True, - "initial_weights_step_message_type": "initial_weights_step", - "initial_weights_step_message_includes_weights": True, - "weights_ready_message_type": "weights_ready", - "rdvz_master_address": "127.0.0.1", - "rdvz_master_port": 19999, - "world_size": broadcast_world_size, - "rank": fast_llm_broadcast_rank, - }, - "training_finished": { - "enabled": True, - "training_finished_message_type": "training_finished", - }, - } - - batch_size = model_config["batch"]["batch_size"] - sequence_length = model_config["batch"]["sequence_length"] - with redis_batch_producer( - redis_client=fake_redis_client, - fake_redis_server_killer=fake_redis_server_killer, - batch_size=batch_size, - sequence_length=sequence_length, - ): - with run_fake_events_consumers( - model_config=model_config, - test_result_path=test_result_path_consumers, - broadcast_world_size=broadcast_world_size, - fake_consumers_broadcast_ranks=fake_consumers_broadcast_ranks, - assigned_gpus=fake_consumers_assigned_gpus, - ): - run_fast_llm_training( - model_config=model_config, - run_distributed_script=run_distributed_script, - assigned_gpus=fast_llm_assigned_gpus, - ) - check_events_results( - test_results_path_fast_llm=test_result_path_fast_llm, - test_results_path_consumers=test_result_path_consumers, - consumer_count=len(fake_consumers_assigned_gpus), - training_steps=train_iters, - model_checkpoint_format=MODEL_CONFIGS["mistral"].checkpoint_format.name, - ) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index f14472194..2834b0728 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -52,6 +52,7 @@ class ModelTestingGroup(enum.StrEnum): generate = "generate" megatron = "megatron" distributed = "distributed" + streaming = "streaming" class ModelTestingGroupAction(enum.StrEnum): @@ -156,9 +157,9 @@ def should_skip(self, distributed_config: DistributedTestingConfig) -> bool: return any(re.search(pattern, distributed_config.name) for pattern in self.skip_tests) -def _update_and_add_testing_config( - old_name: str, - new_name: str, +def update_and_add_testing_config( + old_name: str | ModelTestingConfig, + new_name: str | None, *, model_type: str | None = None, updates: dict[str | tuple[str, ...], typing.Any] | None = None, @@ -167,7 +168,7 @@ def _update_and_add_testing_config( **kwargs, ) -> ModelTestingConfig: - config = MODEL_CONFIGS[old_name] + config = old_name if isinstance(old_name, ModelTestingConfig) else MODEL_CONFIGS[old_name] config_dict = copy.deepcopy(config.config_dict) if updates is not None: for keys, update in updates.items(): @@ -179,14 +180,15 @@ def _update_and_add_testing_config( megatron_args = config.megatron_args + megatron_args new_config = dataclasses.replace( config, - name=new_name, + name=config.name if new_name is None else new_name, model_type=config.model_type if model_type is None else model_type, groups=groups, config_dict=config_dict, megatron_args=megatron_args, **kwargs, ) - MODEL_CONFIGS[new_name] = new_config + if new_name is not None: + MODEL_CONFIGS[new_name] = new_config return new_config @@ -310,7 +312,7 @@ def _update_and_add_testing_config( }, ) -_update_and_add_testing_config( +update_and_add_testing_config( # Tests MQA. 
"gpt_2", "starcoder", @@ -329,7 +331,7 @@ def _update_and_add_testing_config( }, ) -_update_and_add_testing_config( +update_and_add_testing_config( # Tests intermediate between gpt2 and llama, closest converter to gpt2. "gpt_2", "starcoder_2", @@ -358,7 +360,7 @@ def _update_and_add_testing_config( del MODEL_CONFIGS["starcoder_2"].config_dict["model"]["base_model"]["embeddings"]["num_position_embeddings"] -_update_and_add_testing_config( +update_and_add_testing_config( # Main tested model. "starcoder_2", "llama", @@ -387,10 +389,11 @@ def _update_and_add_testing_config( ModelTestingGroup.generate: ModelTestingGroupAction.broken, ModelTestingGroup.megatron: ModelTestingGroupAction.normal, ModelTestingGroup.distributed: ModelTestingGroupAction.normal, + ModelTestingGroup.streaming: ModelTestingGroupAction.normal, }, ) -_update_and_add_testing_config( +update_and_add_testing_config( # Tests llama3-style rotary embeddings. "llama", "llama_3", @@ -410,7 +413,7 @@ def _update_and_add_testing_config( }, ) -_update_and_add_testing_config( +update_and_add_testing_config( # Tests yarn-style rotary embeddings. "llama", "llama_yarn", @@ -430,7 +433,7 @@ def _update_and_add_testing_config( }, ) -_update_and_add_testing_config( +update_and_add_testing_config( # Tests diffusion llama converter. "llama_yarn", "diffusion_llama", @@ -455,7 +458,7 @@ def _update_and_add_testing_config( _llama_block = MODEL_CONFIGS["llama"].config_dict["model"]["base_model"]["decoder"]["block"] -_update_and_add_testing_config( +update_and_add_testing_config( # Tests multi-token prediction, custom HF model and converter. "llama", "mtp_llama", @@ -485,7 +488,7 @@ def _update_and_add_testing_config( skip_tests=(r"ce4", r"ms"), ) -_update_and_add_testing_config( +update_and_add_testing_config( # Tests partial linear biases, Qwen2 converter. "llama", "qwen_2", @@ -507,7 +510,7 @@ def _update_and_add_testing_config( }, ) -_update_and_add_testing_config( +update_and_add_testing_config( # Tests diffusion dream converter. "qwen_2", "dream", @@ -529,7 +532,7 @@ def _update_and_add_testing_config( auto_model_class=transformers.AutoModel, ) -_update_and_add_testing_config( +update_and_add_testing_config( # Tests sliding window attention, mistral converter. "llama", "mistral", @@ -552,7 +555,7 @@ def _update_and_add_testing_config( _mistral_base_model = MODEL_CONFIGS["mistral"].config_dict["model"]["base_model"] -_update_and_add_testing_config( +update_and_add_testing_config( # Tests logit distillation. "mistral", "mistral_distill_logits", @@ -580,7 +583,7 @@ def _update_and_add_testing_config( skip_tests=("ms", "pp2s1_bf4", "pp2s2_bf4", "sdp2"), ) -_update_and_add_testing_config( +update_and_add_testing_config( "mistral_distill_logits", "mistral_reverse_kl", updates={ @@ -601,7 +604,7 @@ def _update_and_add_testing_config( skip_tests=("sdp", "ms", "pp"), ) -_update_and_add_testing_config( +update_and_add_testing_config( "mistral_distill_logits", "mistral_distill_activations", updates={ @@ -632,7 +635,7 @@ def _update_and_add_testing_config( skip_tests=("sdp", "ms", "pp", "tp", GRAD_ACC, "fp16"), ) -_update_and_add_testing_config( +update_and_add_testing_config( # Tests mixture of experts, mixtral converter. "llama", "mixtral", @@ -660,7 +663,7 @@ def _update_and_add_testing_config( ) -_update_and_add_testing_config( +update_and_add_testing_config( # Tests hybrid Mamba 2. 
"llama", "hybrid_mamba", @@ -701,7 +704,7 @@ def _update_and_add_testing_config( skip_tests=("sdp", "ms"), ) -_update_and_add_testing_config( +update_and_add_testing_config( # Tests vision multimodal. "llama", "llava", @@ -744,7 +747,7 @@ def _update_and_add_testing_config( ) -_update_and_add_testing_config( +update_and_add_testing_config( # Tests hybrid with attention + gated delta net mixer. "llama", "apriel2_text_gdn_hybrid", @@ -795,7 +798,7 @@ def _update_and_add_testing_config( skip_tests=("sdp", "ms", TP_NO_STP), ) -_update_and_add_testing_config( +update_and_add_testing_config( # Tests apriel2 format with pattern decoder mixing all mixer types. # This comprehensive test exercises: attention, mamba, stochastic mixer, sliding window attention, gdn. "llama", @@ -918,7 +921,7 @@ def _update_and_add_testing_config( ) -_update_and_add_testing_config( +update_and_add_testing_config( # Tests apriel2 multimodal format combining pattern decoder with vision encoder. # Uses the same decoder as apriel2_text_all_hybrid but adds vision capabilities. "apriel2_text_all_hybrid", @@ -961,7 +964,7 @@ def _update_and_add_testing_config( ) -_update_and_add_testing_config( +update_and_add_testing_config( # Tests hybrid with KDA mixer. "llama", "hybrid_kda", @@ -1015,7 +1018,7 @@ def testing_group_enabled(item: pytest.Function, skip_slow: bool, skip_extra_slo model_testing_config = item.callspec.params["model_testing_config"] model_config: ModelTestingConfig = MODEL_CONFIGS[model_testing_config] for group in groups: - action = model_config.groups[group] + action = model_config.groups.get(group, ModelTestingGroupAction.unimportant) if action == ModelTestingGroupAction.main: pass elif action == ModelTestingGroupAction.normal and not skip_slow: diff --git a/tests/utils/run_test_script.py b/tests/utils/run_test_script.py index d43855789..0b8232cf7 100644 --- a/tests/utils/run_test_script.py +++ b/tests/utils/run_test_script.py @@ -1,4 +1,3 @@ -import argparse import functools import os import pathlib @@ -12,7 +11,7 @@ from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.utils import Assert from tests.utils.distributed_configs import DistributedTestingConfig -from tests.utils.model_configs import MODEL_CONFIGS, ModelTestingConfig +from tests.utils.model_configs import ModelTestingConfig if typing.TYPE_CHECKING: from tests.conftest import WorkerResources @@ -140,16 +139,6 @@ def run_test_script_for_all_models( ) -def parse_run_distributed_script(args: list[str] | None = None): - parser = argparse.ArgumentParser() - parser.add_argument("base_path", type=pathlib.Path) - parser.add_argument("model_testing_config", type=str) - parser.add_argument("--no-distributed-capture", dest="distributed_capture", action="store_false") - - parsed = parser.parse_args(args) - return parsed.base_path, MODEL_CONFIGS[parsed.model_testing_config], parsed.distributed_capture - - @pytest.fixture(scope="session") def compare_results_for_all_models( worker_resources: "WorkerResources", diff --git a/tests/utils/subtest.py b/tests/utils/subtest.py index 9d7d319a1..4fea1fbba 100644 --- a/tests/utils/subtest.py +++ b/tests/utils/subtest.py @@ -67,6 +67,14 @@ def subtest(self, base_path: pathlib.Path, name: str, num_gpus: int): def _configure_logging(self): configure_logging(rank=self._rank, world_size=self._world_size) + @property + def rank(self) -> int: + return self._rank + + @property + def world_size(self) -> int: + return self._world_size + class DistributedSubtestContext: def __init__( self, test_context: 
"DistributedTestContext", base_path: pathlib.Path, name: str, num_gpus: int From 057aff2a023778143f27b170c281ef1655b60758 Mon Sep 17 00:00:00 2001 From: Oleksiy Ostapenko Date: Tue, 6 Jan 2026 06:02:30 -0800 Subject: [PATCH 11/12] Bug fixing (#434) --- fast_llm/data/dataset/gpt/config.py | 1 - fast_llm/engine/checkpoint/distributed.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 41a2fe7ff..5e978ac2b 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -65,7 +65,6 @@ def build(self, preprocessing: PreprocessingConfig) -> SamplableDataset[SampleTy def _load_config(self) -> SampledDatasetConfig[SampleType]: assert self.path.is_file(), f"File {self.path} does not exist." config = yaml.safe_load(self.path.open("r")) - Assert.eq(config.keys(), {"config", "metadata"}) if config.keys() == {"config", "metadata"}: # Newer format with metadata config = config["config"] diff --git a/fast_llm/engine/checkpoint/distributed.py b/fast_llm/engine/checkpoint/distributed.py index d953ea35d..fecc35ef7 100644 --- a/fast_llm/engine/checkpoint/distributed.py +++ b/fast_llm/engine/checkpoint/distributed.py @@ -123,7 +123,7 @@ def _copy_shard_overlaps(self, loaded_model, loaded_shards, context): for loaded_stage, loaded_fsdp, loaded_fsdp_shards in loaded_model.split_shards_by_fsdp(loaded_shards): # Skip tied weight copies to avoid duplicate loads. # We can't call `loaded_stage.is_tied_weight_copy` because the loaded model isn't setup. - if not loaded_stage.index not in loaded_model.stages_owned: + if loaded_stage.index in loaded_model.stages_owned: for self_stage, self_fsdp, self_fsdp_shards in self._model.split_shards_by_fsdp(self_shards): counter = self_fsdp.copy_shard_overlaps( loaded_fsdp, From c9d66dd03ffb92e626802b0278051ba5d1a9f481 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 6 Jan 2026 22:39:55 -0500 Subject: [PATCH 12/12] fixes --- fast_llm/data/dataset/config.py | 10 +++ fast_llm/data/dataset/streaming.py | 18 ++--- fast_llm/engine/checkpoint/state_dict.py | 3 +- fast_llm/engine/multi_stage/fast_llm_model.py | 4 +- fast_llm/engine/training/config.py | 5 +- fast_llm/engine/training/streaming.py | 21 ++++-- fast_llm/utils.py | 6 +- tests/models/test_checkpoint.py | 5 +- .../{test_events.py => test_streaming.py} | 69 ++++++++++--------- tests/utils/redis.py | 10 +-- 10 files changed, 87 insertions(+), 64 deletions(-) rename tests/models/{test_events.py => test_streaming.py} (75%) diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 003b1dfb0..b94f0e5f4 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -301,8 +301,18 @@ def build(self, preprocessing: PreprocessingConfig) -> "IndexedDataset[SampleTyp raise FileNotFoundError(self.path) +REDIS_DATA_STREAM = "fast_llm_streaming" +REDIS_FIELD = "data" +REDIS_GROUP_NAME = "fast_llm_group" + + @config_class() class RedisConfig(Config): + REDIS_FIELD: typing.ClassVar[str] = "data" + REDIS_FIELD_B: typing.ClassVar[bytes] = REDIS_FIELD.encode() + REDIS_GROUP_NAME: typing.ClassVar[str] = "fast_llm_group" + REDIS_GROUP_NAME_B: typing.ClassVar[bytes] = REDIS_GROUP_NAME.encode() + # TODO: Move elsewhere? (Also used in trainer) Get it from the trainer in sampling config? 
host: str = Field( default="localhost", diff --git a/fast_llm/data/dataset/streaming.py b/fast_llm/data/dataset/streaming.py index cff028f62..9f47395a2 100644 --- a/fast_llm/data/dataset/streaming.py +++ b/fast_llm/data/dataset/streaming.py @@ -6,7 +6,7 @@ from fast_llm.config import Configurable from fast_llm.data.dataset.abstract import SamplableIterableDataset -from fast_llm.data.dataset.config import StreamingDatasetConfig +from fast_llm.data.dataset.config import REDIS_DATA_STREAM, REDIS_FIELD, REDIS_GROUP_NAME, StreamingDatasetConfig from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.data.sample.range import RangeSample from fast_llm.data.sample.token import TokenSample @@ -20,10 +20,6 @@ def dtype_from_string(name: str) -> torch.dtype: raise ValueError(f"Unknown torch dtype: {name}") -REDIS_DATA_KEY = "fast_llm_streaming" -REDIS_GROUP_NAME = "fast_llm_group" - - class RedisStreamingDataset[ConfigType: StreamingDatasetConfig, SampleType: LanguageModelSample]( Configurable[ConfigType], SamplableIterableDataset[SampleType] ): @@ -36,7 +32,7 @@ def __init__(self, config: ConfigType, distributed_config: DistributedConfig): # the training step. # raise NotImplementedError("Streaming dataset support is not implemented for pipeline-parallel training.") - self._name = f"redis[{config.host}:{config.port}]({REDIS_DATA_KEY}|{REDIS_GROUP_NAME})[data]" + self._name = f"redis[{config.host}:{config.port}]({REDIS_DATA_STREAM}|{REDIS_GROUP_NAME})[data]" self._config = config self._rank = distributed_config.batch_data_rank self.is_batch_data_group_leader = ( @@ -69,7 +65,7 @@ def __iter__(self) -> typing.Iterator[LanguageModelSample]: # Create the consumer group at the start of the stream ("0") # If the stream already exists, XGROUP CREATE will fail unless we add mkstream=True try: - client.xgroup_create(name=REDIS_DATA_KEY, groupname=REDIS_GROUP_NAME, id="0", mkstream=True) + client.xgroup_create(name=REDIS_DATA_STREAM, groupname=REDIS_GROUP_NAME, id="0", mkstream=True) except redis.exceptions.ResponseError as e: if "BUSYGROUP" in str(e): # Consumer group already exists @@ -86,7 +82,7 @@ def __iter__(self) -> typing.Iterator[LanguageModelSample]: groupname=REDIS_GROUP_NAME, consumername=f"fast_llm_consumer_{self._rank}", # ">" reads only new messages that have not been delivered to any consumer - streams={REDIS_DATA_KEY: ">"}, + streams={REDIS_DATA_STREAM: ">"}, count=1, block=1000, # No explicit ACK: messages are processed immediately; on rank failure the job restarts, @@ -95,14 +91,14 @@ def __iter__(self) -> typing.Iterator[LanguageModelSample]: ) if messages: for stream_key, msgs in messages: - assert stream_key == REDIS_DATA_KEY.encode() + assert stream_key == REDIS_DATA_STREAM.encode() for msg_id, msg_data in msgs: processed += 1 # TODO: or do it after processing all received messaged then count > 1? 
if processed % self._config.acknowledge_interval == 0: - client.hset(f"{REDIS_DATA_KEY}:ack", str(self._rank), msg_id) + client.hset(f"{REDIS_DATA_STREAM}:ack", str(self._rank), msg_id) - yield self._read_document(json.loads(msg_data[b"data"])) + yield self._read_document(json.loads(msg_data[REDIS_FIELD.encode()])) def _read_document(self, data: dict) -> LanguageModelSample: tokens = torch.tensor(data["tokens"], dtype=dtype_from_string(data["tokens_dtype"])) diff --git a/fast_llm/engine/checkpoint/state_dict.py b/fast_llm/engine/checkpoint/state_dict.py index bbb0fa34b..32eea2db6 100644 --- a/fast_llm/engine/checkpoint/state_dict.py +++ b/fast_llm/engine/checkpoint/state_dict.py @@ -15,6 +15,7 @@ CheckpointLoadMetadataConfig, CheckpointSaveConfig, CheckpointSaveMetadataConfig, + CheckpointStateSaveConfigBase, FastLLMCheckpointFormat, export_safetensors_metadata, ) @@ -72,7 +73,7 @@ def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> No self._save_serialized_metadata(config, serialized_metadata, index) def iter_tensors( - self, config: CheckpointSaveConfig, metadata: "CheckpointMetadata" + self, config: CheckpointStateSaveConfigBase, metadata: "CheckpointMetadata" ) -> typing.Iterator[tuple[str, str, torch.Tensor]]: # The tensor mapping may not be one-to-one. `convert_state_dict` pops all tensors from # `state_dict` that are ready for conversion, diff --git a/fast_llm/engine/multi_stage/fast_llm_model.py b/fast_llm/engine/multi_stage/fast_llm_model.py index ed6835140..68fd41af5 100644 --- a/fast_llm/engine/multi_stage/fast_llm_model.py +++ b/fast_llm/engine/multi_stage/fast_llm_model.py @@ -5,7 +5,7 @@ from fast_llm.config import UpdateType from fast_llm.core.distributed import broadcast -from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, CheckpointSaveConfig +from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, CheckpointSaveConfig, CheckpointStateSaveConfigBase from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode from fast_llm.engine.multi_stage.multi_stage import MultiStageModel @@ -34,7 +34,7 @@ def save_checkpoint( def iter_checkpoint( self, - config: CheckpointSaveConfig, + config: CheckpointStateSaveConfigBase, extra_metadata: dict | None = None, ) -> typing.Iterator[tuple[str, str, torch.Tensor]]: # TODO: Handle barriers, ok file, mkdir, etc. here diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index ba52e6b5a..0b492703c 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -424,10 +424,11 @@ class TrainerConfig(PretrainedFastLLMModelConfig, ExperimentConfig): def _validate(self) -> None: self.training.export.setup(self.model) - for callback in self.callbacks.values(): - callback.setup(self) for reference_model in self.reference_models.values(): self._add_reference_distributed_to_pretrained(reference_model) + for callback in self.callbacks.values(): + # We don't know anything about the callbacks, so we forward `self` and let them handle their own setup. + callback.setup(self) super()._validate() if self.reference_models: # TODO: Add support. 
diff --git a/fast_llm/engine/training/streaming.py b/fast_llm/engine/training/streaming.py index 2009e9476..9a8bbc723 100644 --- a/fast_llm/engine/training/streaming.py +++ b/fast_llm/engine/training/streaming.py @@ -10,7 +10,8 @@ logger = logging.getLogger(__name__) -REDIS_TRAINING_KEY = "fast_llm_events" +REDIS_TRAINING_STREAM = "fast_llm_events" +REDIS_TRAINING_FIELD = "event" class StreamingTrainerCallback[ConfigType: StreamingTrainerCallbackConfig](TrainerCallback[ConfigType]): @@ -48,17 +49,23 @@ def step_end( def train_end(self, step: int): # TODO: ====== Send something on unsuccessful ends? ====== if self._do_broadcast: - self._client.xadd(REDIS_TRAINING_KEY, {"event": json.dumps({"type": "training_finished"})}) - torch.distributed.destroy_process_group() + self._client.xadd(REDIS_TRAINING_STREAM, {REDIS_TRAINING_FIELD: json.dumps({"type": "training_finished"})}) + self._clear() def __del__(self): - if self._do_broadcast: - torch.distributed.destroy_process_group() + self._clear() + + def _clear(self): + if hasattr(self, "_process_group"): + torch.distributed.destroy_process_group(self._process_group) + del self._process_group def _broadcast_weights(self, step: int): if self._do_broadcast: - self._client.xadd(REDIS_TRAINING_KEY, {"event": json.dumps({"type": "weights_ready", "step": step})}) - for shard_name, layer_name, tensor in self._model.iter_checkpoint(self._config.get_save_config("", 10), {}): + self._client.xadd( + REDIS_TRAINING_STREAM, {REDIS_TRAINING_FIELD: json.dumps({"type": "weights_ready", "step": step})} + ) + for shard_name, layer_name, tensor in self._model.iter_checkpoint(self._config.export, {}): if self._do_broadcast: # TODO: ====== Broadcast metadata in advance ======= meta = [(shard_name, layer_name, tensor.shape, tensor.dtype)] diff --git a/fast_llm/utils.py b/fast_llm/utils.py index 2ca61aa0e..fa4ea4c2f 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -167,7 +167,7 @@ def rms_close_relative(x, y, threshold, min_threshold=0, *, msg=None): ) @staticmethod - def all_equal(x, *args): + def all_equal(x, *args, msg=None): import torch # Make it work for lists and numpy arrays. 
@@ -181,7 +181,9 @@ def all_equal(x, *args): index = None if x.numel() == 1 else torch.where(neq) # noqa raise AssertionError( f"Tensors have {index[0].numel()} different entries out of " - f"{x.numel()}: {x[index]} != {arg[index]} at index {torch.stack(index, -1)}" + f"{x.numel()}: {x[index]} != {arg[index]} at index {torch.stack(index, -1)}" + "" + if msg is None + else f"| {msg}" ) @staticmethod diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 41d0952c6..9a3bc4345 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -169,9 +169,10 @@ def compare_safetensor_files( for other_path in other_paths: other = safetensors.torch.load_file(other_path) - Assert.eq(other.keys(), expected_keys) + if other.keys() != expected_keys: + raise ValueError(f"Expected keys {expected_keys} but got {other.keys()} in {other_path}") for key in expected_keys: - Assert.all_equal(reference[key], other[key]) + Assert.all_equal(reference[key], other[key], msg=f"tensor = {key}, path = {other_path}") @requires_cuda diff --git a/tests/models/test_events.py b/tests/models/test_streaming.py similarity index 75% rename from tests/models/test_events.py rename to tests/models/test_streaming.py index b1e538a05..f132b465d 100644 --- a/tests/models/test_events.py +++ b/tests/models/test_streaming.py @@ -10,18 +10,12 @@ import torch from fast_llm.engine.training.config import StreamingTrainerCallbackConfig -from fast_llm.engine.training.streaming import REDIS_TRAINING_KEY -from fast_llm.models.gpt.config import GPTBatchConfig +from fast_llm.engine.training.streaming import REDIS_TRAINING_FIELD, REDIS_TRAINING_STREAM from fast_llm.utils import Assert from tests.conftest import WorkerResources from tests.models.test_checkpoint import compare_safetensor_files from tests.utils.distributed_configs import DistributedTestingConfig -from tests.utils.model_configs import ( - MODEL_CONFIGS, - ModelTestingConfig, - ModelTestingGroup, - update_and_add_testing_config, -) +from tests.utils.model_configs import ModelTestingConfig, ModelTestingGroup, update_and_add_testing_config from tests.utils.redis import redis_batch_producer from tests.utils.run_test_script import do_run_test_script_for_all_models from tests.utils.subtest import DistributedTestContext @@ -59,16 +53,6 @@ def total_gpus(self) -> int: num_gpus=2, consumer_count=2, ), - StreamingDistributedTestingConfig( - name="streaming_pp2s2_bf4", - config_args=[ - "model.distributed.pipeline_parallel=2", - "model.multi_stage.layers_per_stage=2", - "batch.breadth_first_micro_batches=4", - ], - num_gpus=2, - consumer_count=2, - ), ] @@ -80,6 +64,7 @@ def _run_event_consumer( logging.info(f"Waiting for weights broadcast rendezvous at {init_method} ...") path = base_path / "streaming" path.mkdir(parents=True, exist_ok=True) + field = REDIS_TRAINING_FIELD.encode() # TODO: Create a custom process group instead. 
try: process_group = torch.distributed.init_process_group( @@ -91,19 +76,20 @@ def _run_event_consumer( last_id = "0-0" while True: result = client.xread( - streams={REDIS_TRAINING_KEY: last_id}, + streams={REDIS_TRAINING_STREAM: last_id}, count=1, block=10000, ) if not result: raise TimeoutError("No message received after 10000 ms...") - for _, (event_id, message) in result[0]: - last_id = event_id - message = json.loads(message.decode()) + ((stream, events),) = result + Assert.eq(stream.decode(), REDIS_TRAINING_STREAM) + Assert.eq(len(events), 1) + for last_id, message in events: + Assert.eq(message.keys(), {field}) + message = json.loads(message[field].decode()) logging.info(f"Received: {message}") - Assert.eq(message.keys(), {"event"}) - message = message["event"] if message["type"] == "training_finished": return elif message["type"] == "weights_ready": @@ -140,6 +126,7 @@ def _run_model_streaming_configs( None, updates={ ("data", "datasets"): {"training": {"port": port}}, + ("training", "export"): {"format": model_testing_config.checkpoint_format.name, "interval": 1}, "callbacks": { "streaming": { "type": "streaming", @@ -148,9 +135,12 @@ def _run_model_streaming_configs( "port": port + 1000, "external_world_size": config.consumer_count, }, - "export": {"format": MODEL_CONFIGS["mistral"].checkpoint_format.name}, + "export": {"format": model_testing_config.checkpoint_format.name}, } }, + # Disable tensor logging. + ("run", "tensor_logs"): {}, + ("model", "multi_stage"): {}, }, groups={}, ) @@ -159,20 +149,33 @@ def _run_model_streaming_configs( if test_context.rank < config.num_gpus: do_run_test_script_for_all_models(config, model_testing_config, base_path) elif test_context.rank < config.total_gpus: - streaming_config = StreamingTrainerCallbackConfig.from_dict( - model_testing_config.config_dict["callbacks"]["streaming"] + training_config = model_testing_config.trainer_config_class.from_dict( + model_testing_config.config_dict ) - batch_config = GPTBatchConfig.from_dict(model_testing_config.config_dict["batch"]) with ( - redis_batch_producer(streaming_config, batch_config) + redis_batch_producer(training_config.callbacks["streaming"], training_config.batch) if test_context.rank == config.num_gpus else contextlib.nullcontext() ): - _run_event_consumer(streaming_config, test_context.rank - config.num_gpus, base_path) + _run_event_consumer( + training_config.callbacks["streaming"], + test_context.rank - config.num_gpus, + base_path / config.name, + ) + + +@requires_cuda +@pytest.mark.slow +@pytest.mark.model_testing_group(ModelTestingGroup.streaming, ModelTestingGroup.distributed) +def test_model_streaming(run_parallel_script, model_testing_config, run_test_script_base_path, worker_resources): + # `test_run_model_distributed_streaming` and `test_model_distributed_streaming need a common dependency + # so they are placed in the same testing group and run in the same distributed process. 
+ pass @requires_cuda @pytest.mark.slow +@pytest.mark.depends_on(on=["test_model_streaming[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.streaming, ModelTestingGroup.distributed) def test_run_model_distributed_streaming( run_parallel_script, model_testing_config, run_test_script_base_path, worker_resources @@ -189,6 +192,7 @@ def test_run_model_distributed_streaming( @pytest.mark.slow @requires_cuda +@pytest.mark.depends_on(on=["test_model_streaming[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.streaming, ModelTestingGroup.distributed) @pytest.mark.parametrize("config", _DISTRIBUTED_STREAMING_CONFIGS) def test_model_distributed_streaming( @@ -201,10 +205,9 @@ def test_model_distributed_streaming( ): report_subtest(path := run_test_script_base_path / config.name, config.total_gpus) compare_safetensor_files( - path / "export" / model_testing_config.checkpoint_format.name / "1", + path / "export" / model_testing_config.checkpoint_format.name / f"1/model_0.safetensors", *( - path / "streaming" / f"rank_{consumer_index}_step_{step}.safetensors" + path / "streaming" / f"rank_{consumer_index}_step_1.safetensors" for consumer_index in range(config.consumer_count) - for step in range(3) ), ) diff --git a/tests/utils/redis.py b/tests/utils/redis.py index 7e8072aab..591ee74e6 100644 --- a/tests/utils/redis.py +++ b/tests/utils/redis.py @@ -9,13 +9,15 @@ import fakeredis from fast_llm.data.dataset.config import ( + REDIS_DATA_STREAM, + REDIS_FIELD, + REDIS_GROUP_NAME, RedisConfig, SamplingConfig, SamplingData, SamplingParameters, StreamingDatasetConfig, ) -from fast_llm.data.dataset.streaming import REDIS_DATA_KEY, REDIS_GROUP_NAME from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.models.gpt.config import GPTBatchConfig @@ -29,7 +31,7 @@ def find_free_port(): def push_msg(redis_client, tokens): """Push a message into FakeRedis stream.""" - redis_client.xadd(REDIS_DATA_KEY, {"data": json.dumps({"tokens": tokens, "tokens_dtype": "int64"})}) + redis_client.xadd(REDIS_DATA_STREAM, {REDIS_FIELD: json.dumps({"tokens": tokens, "tokens_dtype": "int64"})}) def wait_until_stream_empty( @@ -57,7 +59,7 @@ def wait_until_stream_empty( def get_consumer_count(redis_client, stop_event, config: StreamingDatasetConfig): while not stop_event.is_set(): - res = redis_client.hget(f"{REDIS_DATA_KEY}:consumer_count", "0") + res = redis_client.hget(f"{REDIS_DATA_STREAM}:consumer_count", "0") if res is None: time.sleep(0.05) continue @@ -76,7 +78,7 @@ def producer_loop(): break push_msg(client, [sample_index] * batch_config.sequence_length) if sample_index % 5 == 0: - wait_until_stream_empty(client, REDIS_DATA_KEY, REDIS_GROUP_NAME, stop_event) + wait_until_stream_empty(client, REDIS_DATA_STREAM, REDIS_GROUP_NAME, stop_event) thread = threading.Thread(target=producer_loop, daemon=True) thread.start()