diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index d9f89997..64ecbde1 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -32,7 +32,7 @@ jobs: pip install pybind11 FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \ MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \ - pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,GENERATION,DEV,DOCS,VISION]" + pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,STREAMING,DEV,DOCS]" - name: Run tests run: pytest -v -ra . diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 5f5e5928..0893de47 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -34,7 +34,7 @@ jobs: pip install pybind11 FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \ MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \ - pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,GENERATION,DEV,DOCS,VISION]" + pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV,DOCS]" - name: Build the documentation run: mkdocs build diff --git a/Dockerfile b/Dockerfile index 5804d0e4..7ff5d7a7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -39,7 +39,7 @@ COPY --chmod=777 ./fast_llm/__init__.py fast_llm/ COPY --chmod=777 ./fast_llm/csrc/ fast_llm/csrc/ # Install dependencies within the virtual environment. -RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV]" triton==3.5.1 +RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,STREAMING,DEV]" triton==3.5.1 # Copy the remaining source code with universal write permissions. COPY --chmod=777 ./Megatron-LM Megatron-LM diff --git a/fast_llm/data/data/data_loader_wrapper.py b/fast_llm/data/data/data_loader_wrapper.py new file mode 100644 index 00000000..f9e51724 --- /dev/null +++ b/fast_llm/data/data/data_loader_wrapper.py @@ -0,0 +1,52 @@ +import torch.distributed +import torch.utils.data.dataloader + +from fast_llm.core.distributed import broadcast_object + + +class DistributedDataLoaderWrapper: + """ + Wraps a regular dataloader so that only the process group leader + loads data, and then broadcasts the batch to other ranks in the group. + """ + + def __init__( + self, + dataloader: torch.utils.data.dataloader.DataLoader | None, + rank: int, + process_group: torch.distributed.ProcessGroup | None, + ): + self.dataloader = dataloader + self.rank = rank + self.process_group = process_group + + assert (self.rank == 0 and self.dataloader is not None) or (self.rank > 0 and self.dataloader is None) + + def __iter__(self): + if self.rank == 0: + self.iterator = iter(self.dataloader) + if self.process_group is None: + return self.iterator + return self + + def __next__(self): + # TODO: + # Instead of broadcasting a general object, make this iterator yield an actual Batch class. + # Implement `get_state_dict` and `from_state_dict` in the Batch class so that we can + # efficiently broadcast tensors directly. This avoids using `broadcast_object` on the + # entire Batch object, which is inefficient for tensors because it serializes + # (pickles) them before sending. 
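+        #
+        # Current protocol: the leader calls `next` on its local iterator and broadcasts the
+        # result; any exception (including `StopIteration`) is broadcast in place of the data
+        # and re-raised on every rank, so all ranks stop iterating together.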
+ + if self.rank == 0: + try: + data = next(self.iterator) # may raise StopIteration + except Exception as e: + data = e + data = broadcast_object(data, self.process_group, 0) + else: + data = broadcast_object(None, self.process_group, 0) + + if isinstance(data, Exception): + raise data + + return data diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index dbd77089..70966a05 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -8,8 +8,10 @@ from fast_llm.core.distributed import safe_barrier from fast_llm.data.data.abstract import Data +from fast_llm.data.data.data_loader_wrapper import DistributedDataLoaderWrapper from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.data.dataset.abstract import SampledDataset +from fast_llm.data.dataset.abstract_iterable import SampledIterableDataset from fast_llm.data.dataset.config import SamplingParameters from fast_llm.data.dataset.gpt.config import GPTSamplingData from fast_llm.data.dataset.monitor import DatasetMonitor @@ -90,7 +92,12 @@ def setup( dataset_name=dataset_name, ) dataset = self._config.datasets[dataset_name].build_and_sample(sampling) - self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms) + if isinstance(dataset, SampledDataset): + self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms) + else: + # Do not set monitor for iterable dataset as monitor only works with map style datasets + assert isinstance(dataset, SampledIterableDataset) + self._datasets[dataset_name] = dataset safe_barrier(self._distributed.world_group, "data_preparation", timeout) self._is_setup = True @@ -116,9 +123,11 @@ def get_iterator( Assert.in_range_incl(batch_config.sequence_length, 1, sampling_parameters.sequence_length) log_main_rank(f"Initializing {dataset_name} dataset iterator from sample {consumed_samples}...") - return iter( - torch.utils.data.DataLoader( - self._datasets[dataset_name], # noqa + dataset = self._datasets[dataset_name] + + if isinstance(dataset, SampledDataset): + data_loader = torch.utils.data.DataLoader( + dataset, # noqa batch_sampler=SampledDatasetIterator( total_samples=len(self._datasets[dataset_name]), begin_index=consumed_samples, @@ -132,4 +141,27 @@ def get_iterator( collate_fn=LanguageModelBatch.from_samples, multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None, ) - ) + + elif isinstance(dataset, SampledIterableDataset): + if ( + self.distributed.model_and_sequence_data_group is None + or self.distributed.model_and_sequence_data_group.rank() == 0 + ): + rank = 0 + data_loader = torch.utils.data.DataLoader( + dataset, # noqa + batch_size=batch_config.micro_batch_size, + num_workers=0 if num_workers == 0 else 1, + prefetch_factor=prefetch_factor, + pin_memory=True, + collate_fn=LanguageModelBatch.from_samples, + multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None, + ) + else: + rank = self.distributed.model_and_sequence_data_group.rank() + data_loader = None + data_loader = DistributedDataLoaderWrapper( + data_loader, rank, self.distributed.model_and_sequence_data_group + ) + + return iter(data_loader) diff --git a/fast_llm/data/dataset/abstract.py b/fast_llm/data/dataset/abstract.py index 33942708..2efdf384 100644 --- a/fast_llm/data/dataset/abstract.py +++ b/fast_llm/data/dataset/abstract.py @@ -44,7 +44,6 @@ def __len__(self) -> int: class SamplableDataset[SampleType: 
Sample](Dataset[SampleType]): - @abc.abstractmethod def sample(self, config: "SamplingData") -> SampledDataset[SampleType]: pass diff --git a/fast_llm/data/dataset/abstract_iterable.py b/fast_llm/data/dataset/abstract_iterable.py new file mode 100644 index 00000000..770f4f97 --- /dev/null +++ b/fast_llm/data/dataset/abstract_iterable.py @@ -0,0 +1,30 @@ +import abc +import typing + +import torch.utils.data + +from fast_llm.data.sample.abstract import Sample + +if typing.TYPE_CHECKING: + from fast_llm.data.dataset.config import SamplingData + + +# NOTE: We need to inherit from IterableDataset otherwise torch data loader can not detect it properly +class SampledIterableDataset[SampleType: Sample](torch.utils.data.IterableDataset[SampleType]): + """ + A sampled dataset class that provides an iterator over samples. + """ + + # NOTE: We add name here so it is compatible with Fast-LLM Dataset + @property + @abc.abstractmethod + def name(self) -> str: + """ + A name for the dataset to facilitate identification and debugging. + """ + + +class SamplableIterableDataset[SampleType: Sample](SampledIterableDataset[SampleType]): + @abc.abstractmethod + def sample(self, config: "SamplingData") -> SampledIterableDataset[SampleType]: + pass diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 2858d8d1..f2a24c48 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -7,13 +7,15 @@ import pathlib import typing -from fast_llm.config import Config, Field, FieldHint, UpdateType, check_field, config_class +from fast_llm.config import Config, Field, FieldHint, FieldUpdate, UpdateType, check_field, config_class from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset from fast_llm.data.preprocessing.abstract import PreprocessingConfig from fast_llm.data.sample.abstract import Sample +from fast_llm.redis.config import RedisConfig from fast_llm.utils import Assert, normalize_probabilities if typing.TYPE_CHECKING: + from fast_llm.data.dataset.abstract_iterable import SamplableIterableDataset, SampledIterableDataset from fast_llm.data.dataset.indexed import ConcatenatedDataset, DatasetSlice, IndexedDataset from fast_llm.engine.distributed.distributed import Distributed @@ -106,19 +108,25 @@ class DatasetConfig[SampleType: Sample](Config): @config_class(registry=True) class SampledDatasetConfig[SampleType: Sample](DatasetConfig[SampleType]): """ - A sampled dataset containing a prepared list of samples to be indexed sequentially (as-is) during training. + A sampled dataset containing a prepared list or iterable of samples to be indexed sequentially (as-is) during training. 
""" - def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]: + def build_and_sample( + self, sampling: SamplingData + ) -> "SampledDataset[SampleType] | SampledIterableDataset[SampleType]": raise NotImplementedError() @config_class() class SamplableDatasetConfig[SampleType: Sample](SampledDatasetConfig[SampleType]): - def build(self, preprocessing: PreprocessingConfig) -> SamplableDataset[SampleType]: + def build( + self, preprocessing: PreprocessingConfig + ) -> "SamplableDataset[SampleType] | SamplableIterableDataset[SampleType]": raise NotImplementedError() - def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]: + def build_and_sample( + self, sampling: SamplingData + ) -> "SampledDataset[SampleType] | SampledIterableDataset[SampleType]": return self.build(sampling.preprocessing).sample(sampling) @@ -298,3 +306,91 @@ def build(self, preprocessing: PreprocessingConfig) -> "IndexedDataset[SampleTyp return LegacyMemmapDataset[SampleType](name, self.path, preprocessing) else: raise FileNotFoundError(self.path) + + +@config_class() +class StreamingDatasetRedisConfig(RedisConfig): + stream_key: str = FieldUpdate(default="fast_llm_streaming") + + payload_key: str = FieldUpdate( + default="data", + ) + + +class IngestionType(str, enum.Enum): + CONSUMER_GROUP = "consumer_group" + ONE_STREAM = "one_stream" + N_STREAMS = "n_streams" + + +class HashType(str, enum.Enum): + MESSAGE_INDEX = "message_index" + """Use the index of the received message for hashing. Provides precise distribution but may not be well shuffled.""" + + MESSAGE_ID = "message_id" + """Hash messages based on their unique message ID. Good for probabilistic distribution. + Redis message IDs are regenerated each time, so this is not reproducible. + """ + + MESSAGE_BODY = "message_body" + """Hash messages based on their payload content (bytes). Distributes messages roughly evenly. + Deterministic based on message content, but not perfectly balanced across ranks. + """ + + PRODUCER_PROVIDED = "producer_provided" + """Use the hash or index provided by the producer. Allows deterministic splitting and perfect balance.""" + + +@config_class(dynamic_type={SampledDatasetConfig: "streaming"}) +class StreamingDatasetConfig[SampleType: LanguageModelSample](SamplableDatasetConfig[SampleType]): + """ + Configuration for a streaming dataset that reads training data from a Redis stream. 
+ """ + + _abstract = False + + redis: StreamingDatasetRedisConfig = Field( + desc="Redis connection and stream settings used to fetch incoming training data.", + hint=FieldHint.core, + ) + + group_name: str = Field( + default="fast_llm_dp_group", + desc="Name of the Redis consumer group used for data-parallel reading.", + hint=FieldHint.core, + ) + + consumer_name_prefix: str = Field( + default="fast_llm_dp_group_consumer", + desc="Prefix used to generate unique consumer names for each rank in Redis consumer group.", + hint=FieldHint.core, + ) + + ingestion_type: IngestionType = Field( + default=IngestionType.CONSUMER_GROUP, + desc="Strategy used to ingest data from Redis streams (consumer group, single stream, or multiple streams).", + hint=FieldHint.core, + ) + + hash_type: HashType = Field( + default=HashType.MESSAGE_ID, + desc="How to compute hash for assigning messages to ranks.", + hint=FieldHint.core, + ) + + hash_key: str = Field( + default="hash", + desc="Key in the message dict containing the hash or index provided by the producer.", + hint=FieldHint.core, + ) + + ack_period_per_consumer: int = Field( + default=10, + desc="Number of messages after which the consumer acknowledges received IDs back to the Redis hash.", + hint=FieldHint.core, + ) + + def build_and_sample(self, sampling: SamplingData) -> "SampledIterableDataset[SampleType]": + from fast_llm.data.dataset.streaming import StreamingDataset + + return StreamingDataset[SampleType](self, sampling.distributed).sample(sampling) diff --git a/fast_llm/data/dataset/sampled.py b/fast_llm/data/dataset/sampled.py index d51a6874..979fd7a6 100644 --- a/fast_llm/data/dataset/sampled.py +++ b/fast_llm/data/dataset/sampled.py @@ -9,6 +9,7 @@ import yaml from fast_llm.data.dataset.abstract import SampledDataset +from fast_llm.data.dataset.abstract_iterable import SamplableIterableDataset, SampledIterableDataset from fast_llm.data.dataset.config import SamplingData, ShufflingType from fast_llm.data.dataset.indexed import IndexedDataset from fast_llm.data.sample.abstract import Sample @@ -429,3 +430,55 @@ def _load_yaml_data(self, data: dict[str, typing.Any]) -> None: self._unshuffled_tokens = data["unshuffled_tokens"] self._unshuffled_documents = data["unshuffled_epochs"] * self._documents_per_epoch + + +class NaiveSampledIterableDataset[SampleType: Sample](SampledIterableDataset[SampleType]): + def __init__( + self, + iterable_dataset: SamplableIterableDataset[SampleType], + sampling: SamplingData, + ): + self._dataset = iterable_dataset + self._config = sampling.config + self._parameters = sampling.parameters + + assert self._parameters.truncate_documents == False + assert self._config.shuffle == ShufflingType.disabled + + def __iter__(self) -> typing.Iterator[SampleType]: + sample_length = self._parameters.sequence_length + self._parameters.extra_tokens + current_sample_length = 0 + documents: list[SampleType] = [] + for doc in self._dataset: + if len(doc) > sample_length: + logging.warning(f"Dropping doc with length {len(doc)} higher then sample_length {sample_length}") + continue + if current_sample_length + len(doc) > sample_length: + padding_length = sample_length - current_sample_length + assert padding_length > 0 + documents.append(documents[-1].get_padding(padding_length)) + + yield documents[0].from_documents(documents) + + documents = [doc] + current_sample_length = len(doc) + else: + documents.append(doc) + current_sample_length += len(doc) + + if current_sample_length == sample_length: + yield 
documents[0].from_documents(documents) + + documents = [] + current_sample_length = 0 + + if current_sample_length > 0: + padding_length = sample_length - current_sample_length + assert padding_length > 0 + documents.append(documents[-1].get_padding(padding_length)) + + yield documents[0].from_documents(documents) + + @property + def name(self) -> str: + return self._dataset.name diff --git a/fast_llm/data/dataset/streaming.py b/fast_llm/data/dataset/streaming.py new file mode 100644 index 00000000..1aabf60c --- /dev/null +++ b/fast_llm/data/dataset/streaming.py @@ -0,0 +1,259 @@ +import typing + +import torch.utils.data +import xxhash + +from fast_llm.data.dataset.abstract_iterable import SamplableIterableDataset, SampledIterableDataset +from fast_llm.data.dataset.config import HashType, IngestionType, SamplingData, StreamingDatasetConfig +from fast_llm.data.dataset.sampled import NaiveSampledIterableDataset +from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.data.sample.range import RangeSample +from fast_llm.data.sample.token import TokenSample +from fast_llm.engine.config_utils.run import is_main_rank +from fast_llm.engine.distributed.distributed import Distributed + +if typing.TYPE_CHECKING: + import redis + + +def dtype_from_string(name: str) -> torch.dtype: + try: + return getattr(torch, name) + except AttributeError: + raise ValueError(f"Unknown torch dtype: {name}") + + +class StreamingDataset[SampleType: LanguageModelSample](SamplableIterableDataset[SampleType]): + def __init__(self, config: StreamingDatasetConfig, distributed: Distributed): + super().__init__() + if distributed.config.pipeline_parallel > 1: + # NOTE: It is not yet clear whether the issue comes from the streaming dataset + # itself or from the distributed data-loader wrappers, but currently it + # interferes with pipeline-parallel training and causes a timeout during + # the training step. 
+ raise NotImplementedError("Streaming dataset support is not implemented for pipeline-parallel training.") + + self._name = f"redis[{config.redis.host}:{config.redis.port}]({config.redis.stream_key}|{config.group_name})[{config.redis.payload_key}]" + self._config = config + self.batch_data_rank = distributed.config.batch_data_rank + self.batch_data_parallel = distributed.config.batch_data_parallel + self.is_batch_data_group_leader = ( + distributed.model_and_sequence_data_group is None or distributed.model_and_sequence_data_group.rank() == 0 + ) + self.payload_key_b = self._config.redis.payload_key.encode() + self.hash_key_b = self._config.hash_key.encode() + + self._set_consumer_count() + + @property + def name(self) -> str: + return self._name + + def sample(self, config: SamplingData) -> SampledIterableDataset[LanguageModelSample]: + # TODO: actually sample the dataset and not return docs + return NaiveSampledIterableDataset(self, config) + + def _set_consumer_count(self): + import redis + + if is_main_rank(): + redis_client = redis.Redis(host=self._config.redis.host, port=self._config.redis.port) + redis_client.hset(f"{self._config.redis.stream_key}:consumer_count", "0", self.batch_data_parallel) + + def __getstate__(self) -> tuple[str, StreamingDatasetConfig, int, int, bool, bytes, bytes]: + return ( + self._name, + self._config, + self.batch_data_parallel, + self.batch_data_rank, + self.is_batch_data_group_leader, + self.payload_key_b, + self.hash_key_b, + ) + + def __setstate__(self, state: tuple[str, StreamingDatasetConfig, int, bool, bytes, bytes]): + name, config, batch_data_parallel, batch_data_rank, is_batch_data_group_leader, payload_key_b, hash_key_b = ( + state + ) + self._name = name + self._config = config + self.batch_data_parallel = batch_data_parallel + self.batch_data_rank = batch_data_rank + self.is_batch_data_group_leader = is_batch_data_group_leader + self.payload_key_b = payload_key_b + self.hash_key_b = hash_key_b + + def __iter__(self) -> typing.Iterator[LanguageModelSample]: + import orjson + import redis + + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None and worker_info.num_workers > 1: + raise RuntimeError("StreamingDataset can work only with one instance per rank") + + if not self.is_batch_data_group_leader: + raise RuntimeError("Must be only called on the batch data group leader") + + redis_client = redis.Redis(host=self._config.redis.host, port=self._config.redis.port) + + match self._config.ingestion_type: + case IngestionType.CONSUMER_GROUP: + messages_iter = self._iter_consumer_group + case IngestionType.ONE_STREAM: + messages_iter = self._iter_stream + case IngestionType.N_STREAMS: + messages_iter = self._iter_stream + case _: + raise ValueError(f"Unknown ingestion type {self._config.ingestion_type}") + + ack_hash = f"{self._config.redis.stream_key}:ack" + consumer_id = str(self.batch_data_rank) + # If one stream each groups receives all data otherwise each groups receives just its data + if self._config.ingestion_type == IngestionType.ONE_STREAM: + ack_period = self._config.ack_period_per_consumer * self.batch_data_parallel + else: + ack_period = self._config.ack_period_per_consumer + + for msg_data in messages_iter(redis_client, ack_hash=ack_hash, consumer_id=consumer_id, ack_period=ack_period): + data = orjson.loads(msg_data[self.payload_key_b]) + yield self._sample_from_msg_data(data) + + def _iter_consumer_group( + self, redis_client: "redis.Redis", ack_hash: str, consumer_id: str, ack_period: int + ) -> 
typing.Iterator[LanguageModelSample]: + import redis.exceptions + + # Create the consumer group at the start of the stream ("0") + # If the stream already exists, XGROUP CREATE will fail unless we add mkstream=True + try: + redis_client.xgroup_create( + name=self._config.redis.stream_key, groupname=self._config.group_name, id="0", mkstream=True + ) + except redis.exceptions.ResponseError as e: + if "BUSYGROUP" in str(e): + # Consumer group already exists + pass + else: + raise + + processed = 0 + while True: + # XREADGROUP reads from the consumer group + # COUNT: max number of messages to fetch at once + # BLOCK: wait for new messages (milliseconds) + messages = redis_client.xreadgroup( + groupname=self._config.group_name, + consumername=f"{self._config.consumer_name_prefix}_{self.batch_data_rank}", + # ">" reads only new messages that have not been delivered to any consumer + streams={self._config.redis.stream_key: ">"}, + count=1, + block=1000, + # No explicit ACK: messages are processed immediately; on rank failure the job restarts, + # so message loss is acceptable and simplifies coordination + noack=True, + ) + if messages: + for stream_key, msgs in messages: + assert stream_key == self._config.redis.stream_key.encode() + for msg_id, msg_data in msgs: + processed += 1 + # TODO: or do it after processing all received messaged then count > 1? + if processed % ack_period == 0: + redis_client.hset(ack_hash, consumer_id, msg_id) + + yield msg_data + + def _iter_stream( + self, redis_client: "redis.Redis", ack_hash: str, consumer_id: str, ack_period: int + ) -> typing.Iterator[LanguageModelSample]: + last_id = "0-0" + stream_key = self._config.redis.stream_key + if self._config.ingestion_type == IngestionType.N_STREAMS: + stream_key += f"_{self.batch_data_rank}" + stream_key_b = stream_key.encode() + processed = 0 + while True: + messages = redis_client.xread( + streams={stream_key: last_id}, + count=1, + block=1000, + ) + if not messages: + continue + for this_stream_key_b, msgs in messages: + assert this_stream_key_b == stream_key_b + for msg_id, msg_data in msgs: + last_id = msg_id + + processed += 1 + # TODO: or do it after processing all received messaged then count > 1? 
+ if processed % ack_period == 0: + redis_client.hset(ack_hash, consumer_id, last_id) + + if self._config.ingestion_type == IngestionType.N_STREAMS or self._is_for_this_rank( + msg_id, msg_data, processed - 1 + ): + yield msg_data + + def _is_for_this_rank(self, msg_id: bytes, msg_data: dict, msg_index: int) -> bool: + hash_type = self._config.hash_type + + if hash_type is HashType.MESSAGE_INDEX: + h = msg_index + elif hash_type is HashType.MESSAGE_ID: + h = xxhash.xxh64(msg_id).intdigest() + elif hash_type is HashType.MESSAGE_BODY: + h = xxhash.xxh64(msg_data[self.payload_key_b]).intdigest() + elif hash_type is HashType.PRODUCER_PROVIDED: + h = self._get_hash_key_value(msg_data[self.hash_key_b]) + else: + raise ValueError(f"Unknown hash_type: {hash_type}") + + return (h % self.batch_data_parallel) == self.batch_data_rank + + def _get_hash_key_value(self, value: bytes | int | str): + if isinstance(value, int): + # already an integer + msg_hash = value + elif isinstance(value, bytes): + try: + # try decoding as UTF-8 string and converting to int + msg_hash = int(value.decode("utf-8")) + except ValueError: + # not an integer, treat as a hash string + import xxhash + + msg_hash = xxhash.xxh64(value).intdigest() + elif isinstance(value, str): + try: + msg_hash = int(value) + except ValueError: + # not an integer, treat as a hash string + import xxhash + + msg_hash = xxhash.xxh64(value.encode("utf-8")).intdigest() + else: + raise TypeError(f"Unexpected type for hash key: {type(value)}") + return msg_hash + + def _sample_from_msg_data(self, data: dict) -> LanguageModelSample: + tokens = torch.tensor(data["tokens"], dtype=dtype_from_string(data["tokens_dtype"])) + sample_size = len(tokens) + if "loss_masking_spans" in data: + loss_masking_spans = [tuple(el) for el in data["loss_masking_spans"]] + else: + loss_masking_spans = None + if "chosen_spans" in data: + chosen_spans = [tuple(el) for el in data["chosen_spans"]] + else: + chosen_spans = None + if "rejected_spans" in data: + rejected_spans = [tuple(el) for el in data["rejected_spans"]] + else: + rejected_spans = None + return LanguageModelSample( + TokenSample(tokens, [sample_size]), + RangeSample(loss_masking_spans, sample_size) if loss_masking_spans is not None else None, + RangeSample(chosen_spans, sample_size) if chosen_spans is not None else None, + RangeSample(rejected_spans, sample_size) if rejected_spans is not None else None, + ) diff --git a/fast_llm/engine/checkpoint/state_dict.py b/fast_llm/engine/checkpoint/state_dict.py index 7a257a5f..bbb0fa34 100644 --- a/fast_llm/engine/checkpoint/state_dict.py +++ b/fast_llm/engine/checkpoint/state_dict.py @@ -71,6 +71,31 @@ def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> No if self._model.config.distributed.rank == 0: self._save_serialized_metadata(config, serialized_metadata, index) + def iter_tensors( + self, config: CheckpointSaveConfig, metadata: "CheckpointMetadata" + ) -> typing.Iterator[tuple[str, str, torch.Tensor]]: + # The tensor mapping may not be one-to-one. `convert_state_dict` pops all tensors from + # `state_dict` that are ready for conversion, + # and return a dict containing the converted tensors(s). + # If converting a tensor requires another one that is not yet available (e.g. for concatenation), + # it will remain in `state_dict` until that tensor is available. 
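+        # Tensors are buffered per shard below until `_convert_state_dict` can consume them, and
+        # the final loop asserts that nothing was left unconverted.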
+ state_dict = {} + for parameter_name, shard_name, tensor in self._model.get_state_tensor_iterator( + self.get_shard_names(config), config.data_type + ): + if shard_name not in state_dict: + state_dict[shard_name] = {} + shard_state_dict = state_dict[shard_name] + assert parameter_name not in shard_state_dict + shard_state_dict[parameter_name] = tensor + for exported_name, exported_tensor in self._convert_state_dict(shard_state_dict, True).items(): + yield shard_name, self._get_key(exported_name, shard_name), exported_tensor + + for shard_name, shard_state_dict in state_dict.items(): + assert ( + not shard_state_dict + ), f"Un-handled entries after conversion: {({k: list(v) for k, v in state_dict.items()})}" + @classmethod @abc.abstractmethod def _save_serialized_metadata(cls, config: CheckpointSaveMetadataConfig, metadata: dict, index: dict) -> None: diff --git a/fast_llm/engine/distributed/config.py b/fast_llm/engine/distributed/config.py index 7f4b7bc3..e63ef01f 100644 --- a/fast_llm/engine/distributed/config.py +++ b/fast_llm/engine/distributed/config.py @@ -97,6 +97,29 @@ def setup(self, group: "ProcessGroup|None"): def check_ranks_in_range(self, start, stop): check_ranks_in_range(self.global_ranks, start, stop) + @classmethod + def from_sizes_and_strides(cls, name: str, global_rank: int, *sizes_and_strides: tuple[int, int]) -> typing.Self: + start = global_rank + rank = 0 + world_size = 1 + for size, stride in sizes_and_strides: + rank_ = global_rank // stride % size + start -= rank_ * stride + rank += world_size * rank_ + world_size *= size + global_ranks = [start] + for size, stride in sizes_and_strides: + if size == 1: + continue + if len(global_ranks) == 1 or ( + isinstance(global_ranks, range) and stride == global_ranks.stop - global_ranks.start + ): + global_ranks = range(start, start + size * stride, sizes_and_strides[0][0]) + else: + global_ranks = (rank0 + rank1 for rank1 in range(0, size, stride) for rank0 in global_ranks) + global_ranks = global_ranks if isinstance(global_ranks, range) else list(global_ranks) + return DistributedDim(name=name, size=world_size, rank=rank, global_ranks=global_ranks) + def check_ranks_in_range(global_ranks, start, stop): Assert.geq(min(global_ranks), start) @@ -112,6 +135,7 @@ class DistributedDimNames: sequence_data = "sequence_data" batch_data = "batch_data" tensor_and_sequence_data = "tensor_and_sequence_data" + model_and_sequence_data = "model_and_sequence_data" tensor_and_data = "tensor_and_data" @@ -263,6 +287,8 @@ class DistributedConfig(Config): ) def _validate(self) -> None: + super()._validate() + if self.world_size is None: self.world_size = self.default_world_size if self.rank is None: @@ -300,88 +326,59 @@ def _validate(self) -> None: else: self.distributed_dims = {} - data_stride = self.tensor_parallel * (self.pipeline_parallel if self.pipeline_first else 1) + tensor_stride = 1 + sequence_data_stride = self.tensor_parallel * (self.pipeline_parallel if self.pipeline_first else 1) + batch_data_stride = sequence_data_stride * self.sequence_data_parallel pipeline_stride = self.tensor_parallel * (1 if self.pipeline_first else self.data_parallel) - self._add_distributed_dim( - DistributedDim( - name=DistributedDimNames.world, - size=self.world_size, - rank=self.rank, - global_ranks=range(self.world_size), - ) + self._add_distributed_dim_from_sizes_and_strides( + DistributedDimNames.world, + (self.world_size, 1), + ) + self._add_distributed_dim_from_sizes_and_strides( + DistributedDimNames.data, + (self.sequence_data_parallel, 
sequence_data_stride), + (self.batch_data_parallel, batch_data_stride), ) - self._add_distributed_dim( - DistributedDim( - name=DistributedDimNames.data, - size=self.data_parallel, - rank=self.data_rank, - global_ranks=self._get_global_ranks(self.data_parallel, data_stride), - ) + self._add_distributed_dim_from_sizes_and_strides( + DistributedDimNames.pipeline, (self.pipeline_parallel, pipeline_stride) ) - self._add_distributed_dim( - DistributedDim( - name=DistributedDimNames.pipeline, - size=self.pipeline_parallel, - rank=self.pipeline_rank, - global_ranks=self._get_global_ranks(self.pipeline_parallel, pipeline_stride), - ) + self._add_distributed_dim_from_sizes_and_strides( + DistributedDimNames.tensor, (self.tensor_parallel, tensor_stride) ) - self._add_distributed_dim( - DistributedDim( - name=DistributedDimNames.tensor, - size=self.tensor_parallel, - rank=self.tensor_rank, - global_ranks=self._get_global_ranks(self.tensor_parallel, 1), - ) + self._add_distributed_dim_from_sizes_and_strides( + DistributedDimNames.sequence_data, + (self.sequence_data_parallel, sequence_data_stride), ) - self._add_distributed_dim( - DistributedDim( - name=DistributedDimNames.sequence_data, - size=self.sequence_data_parallel, - rank=self.sequence_data_rank, - global_ranks=self._get_global_ranks(self.sequence_data_parallel, data_stride), - ) + self._add_distributed_dim_from_sizes_and_strides( + DistributedDimNames.batch_data, (self.batch_data_parallel, batch_data_stride) ) - self._add_distributed_dim( - DistributedDim( - name=DistributedDimNames.batch_data, - size=self.batch_data_parallel, - rank=self.batch_data_rank, - global_ranks=self._get_global_ranks( - self.batch_data_parallel, data_stride * self.sequence_data_parallel - ), - ) + self._add_distributed_dim_from_sizes_and_strides( + DistributedDimNames.tensor_and_sequence_data, + (self.tensor_parallel, tensor_stride), + (self.sequence_data_parallel, sequence_data_stride), + ) + self._add_distributed_dim_from_sizes_and_strides( + DistributedDimNames.tensor_and_data, + (self.tensor_parallel, tensor_stride), + (self.sequence_data_parallel, sequence_data_stride), + (self.batch_data_parallel, batch_data_stride), ) - # Global ranks wrong with pipeline first, so we hide the dims as a safety check. 
-        if not self.pipeline_first:
-            self._add_distributed_dim(
-                DistributedDim(
-                    name=DistributedDimNames.tensor_and_sequence_data,
-                    size=self.sequence_data_parallel * self.tensor_parallel,
-                    rank=self.tensor_rank + self.sequence_data_rank * self.tensor_parallel,
-                    global_ranks=self._get_global_ranks(self.sequence_data_parallel * self.tensor_parallel, 1),
-                )
-            )
-            self._add_distributed_dim(
-                DistributedDim(
-                    name=DistributedDimNames.tensor_and_data,
-                    size=self.data_parallel * self.tensor_parallel,
-                    rank=self.tensor_rank + self.data_rank * self.tensor_parallel,
-                    global_ranks=self._get_global_ranks(self.data_parallel * self.tensor_parallel, 1),
-                )
-            )
-        super()._validate()
+        self._add_distributed_dim_from_sizes_and_strides(
+            DistributedDimNames.model_and_sequence_data,
+            (self.tensor_parallel, tensor_stride),
+            (self.sequence_data_parallel, sequence_data_stride),
+            (self.pipeline_parallel, pipeline_stride),
+        )
 
         if self.reference_config is not None:
             self.compare(self.reference_config, ValueError)
         Assert.in_range(self.rank, 0, self.world_size)
         Assert.in_range(self.local_rank, 0, self.local_world_size)
 
-    def _get_global_ranks(self, size: int, stride: int) -> range:
-        start = self.rank // (size * stride) * size * stride + self.rank % stride
-        return range(start, start + size * stride, stride)
+    def _add_distributed_dim_from_sizes_and_strides(self, name: str, *sizes_and_strides: tuple[int, int]) -> None:
+        self._add_distributed_dim(DistributedDim.from_sizes_and_strides(name, self.rank, *sizes_and_strides))
 
     def _add_distributed_dim(self, distributed_dim: DistributedDim) -> None:
         Assert.eq(distributed_dim.global_ranks[distributed_dim.rank], self.rank, msg=distributed_dim)
diff --git a/fast_llm/engine/distributed/distributed.py b/fast_llm/engine/distributed/distributed.py
index aa2be6ce..eb4a0929 100644
--- a/fast_llm/engine/distributed/distributed.py
+++ b/fast_llm/engine/distributed/distributed.py
@@ -180,6 +180,7 @@ def __init__(self, config: DistributedConfig, use_cpu: bool = False):
         self.tensor_group = self.add_group(self._config.distributed_dims[DistributedDimNames.tensor])
         self.sequence_data_group = self.add_group(self._config.distributed_dims[DistributedDimNames.sequence_data])
         self.batch_data_group = self.add_group(self._config.distributed_dims[DistributedDimNames.batch_data])
+        # Global ranks wrong with pipeline first, so we hide the dims as a safety check.
if not self._config.pipeline_first: self.tensor_and_sequence_data_group = self.add_group( @@ -189,6 +190,10 @@ def __init__(self, config: DistributedConfig, use_cpu: bool = False): self._config.distributed_dims[DistributedDimNames.tensor_and_data] ) + self.model_and_sequence_data_group = self.add_group( + self._config.distributed_dims[DistributedDimNames.model_and_sequence_data] + ) + self._config.log_first_rank(f"Setting random seeds...") dp_shift = self._config.dp_seed_shift * self._config.data_rank diff --git a/fast_llm/engine/multi_stage/fast_llm_model.py b/fast_llm/engine/multi_stage/fast_llm_model.py index 6a6223cb..ed683514 100644 --- a/fast_llm/engine/multi_stage/fast_llm_model.py +++ b/fast_llm/engine/multi_stage/fast_llm_model.py @@ -1,6 +1,8 @@ import logging import typing +import torch + from fast_llm.config import UpdateType from fast_llm.core.distributed import broadcast from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, CheckpointSaveConfig @@ -30,6 +32,20 @@ def save_checkpoint( ) converter.save(config, fast_llm_metadata) + def iter_checkpoint( + self, + config: CheckpointSaveConfig, + extra_metadata: dict | None = None, + ) -> typing.Iterator[tuple[str, str, torch.Tensor]]: + # TODO: Handle barriers, ok file, mkdir, etc. here + converter = config.format.get_handler_class()(self) + fast_llm_metadata = self._config.to_metadata( + config, + shards=converter.get_shard_names(config), + metadata={} if extra_metadata is None else extra_metadata, + ) + yield from converter.iter_tensors(config, fast_llm_metadata) + def load_checkpoint(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None: # TODO: Simplify branching. # TODO: Test with more distributed configs. diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 867cca98..7624c72c 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -29,6 +29,7 @@ from fast_llm.engine.optimizer.config import OptimizerConfig from fast_llm.engine.schedule.config import BatchConfig, ScheduleConfig from fast_llm.profile import ProfilingConfig +from fast_llm.redis.config import RedisConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -321,6 +322,113 @@ def _validate(self) -> None: self.wandb.alert.assert_sub_interval(self.logs) +@config_class() +class TrainerEventsRedisConfig(RedisConfig): + stream_key: str = FieldUpdate(default="fast_llm_events") + + payload_key: str = FieldUpdate(default="event") + + +@config_class() +class TrainerEvent(Config): + enabled: bool = Field( + default=False, + desc="Flag indicating whether this event is enabled. If False, the event will be skipped.", + hint=FieldHint.feature, + ) + + +@config_class() +class WeightsBroadcastEventConfig(TrainerEvent): + """ + Event sent to indicate that updated weights are ready for broadcast. + """ + + initial_weights_step_message_type: str = Field( + default="initial_weights_step", + desc="Message indicating that weights the training starting/ continuing from.", + hint=FieldHint.feature, + ) + + initial_weights_step_message_includes_weights: bool = Field( + default=False, + desc=( + "Whether to include the loaded model weights in the initial event message. " + "Useful when training restarts from an internal checkpoint format that " + "which does not have an exported checkpoint for that step." 
+ ), + hint=FieldHint.feature, + ) + + weights_ready_message_type: str = Field( + default="weights_ready", + desc="Message indicating that weights are ready to be broadcast.", + hint=FieldHint.feature, + ) + + # NCCL rendezvous details + rdvz_master_address: str | None = Field( + default=None, + desc="Master address for the external NCCL process group.", + hint=FieldHint.feature, + ) + + rdvz_master_port: int | None = Field( + default=None, + desc="Master port for the external NCCL process group.", + hint=FieldHint.feature, + ) + + world_size: int | None = Field( + default=None, + desc="World size of the external NCCL process group.", + hint=FieldHint.feature, + ) + + rank: int | None = Field( + default=None, + desc="Rank of this process in the external NCCL process group.", + hint=FieldHint.feature, + ) + + +@config_class() +class TrainingFinishedEventConfig(TrainerEvent): + """ + Event sent to indicate that training has completed. + """ + + training_finished_message_type: str = Field( + default="training_finished", + desc="Message indicating that weights the training starting/ continuing from.", + hint=FieldHint.feature, + ) + + +@config_class() +class TrainerEventsConfig(Config): + """ + Aggregates all trainer-side Redis-based event configurations. + """ + + redis: TrainerEventsRedisConfig = Field( + desc="Redis connection and stream settings used to fetch incoming training data.", + hint=FieldHint.core, + ) + + weights_broadcast: WeightsBroadcastEventConfig = Field( + default=None, + desc="Configuration for signaling weight-ready events via Redis.", + hint=FieldHint.feature, + ) + + training_finished: TrainingFinishedEventConfig = Field( + default=None, + desc="Configuration for signaling training-finished events via Redis.", + hint=FieldHint.feature, + ) + + @config_class(registry=True, dynamic_type={RunnableConfig: "train"}) class TrainerConfig(PretrainedFastLLMModelConfig, ExperimentConfig): _abstract = True @@ -352,6 +460,12 @@ class TrainerConfig(PretrainedFastLLMModelConfig, ExperimentConfig): hint=FieldHint.feature, ) + events: TrainerEventsConfig = Field( + default=None, + desc="Optional Trainer event configurations (weight broadcast, training finished, etc.).", + hint=FieldHint.feature, + ) + def _validate(self) -> None: self.training.export.setup(self.model) for reference_model in self.reference_models.values(): diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index 7225ed20..a2f98c05 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -36,6 +36,7 @@ TrainingCheckpointConfig, TrainingEvaluatorConfig, ) +from fast_llm.engine.training.trainer_events import TrainerEvents from fast_llm.engine.training.wandb import Wandb from fast_llm.logging import format_metrics, log_memory_usage from fast_llm.utils import Assert, Interrupter, get_and_reset_memory_usage_mib @@ -131,6 +132,8 @@ def __init__(self, config: TrainerConfig): self._is_evaluation_only = config.training.train_iters == 0 + self.trainer_events = TrainerEvents(config.events) + self._data = self._get_data() log_main_rank("Creating model...") self._multi_stage = self._config.model.get_model_class()( @@ -286,6 +289,7 @@ def run(self) -> None: assert self._is_setup with self._wandb: self._run_training() + self.trainer_events.send_training_finished() def _run_training(self) -> None: self._prepare_training_state() @@ -358,6 +362,11 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: # TODO: Synchronization is probably 
unnecessary. safe_barrier(self._distributed.world_group, "train begin") + + self.trainer_events.send_initial_weights_step( + self._completed_steps, self._multi_stage, self._config.training.export + ) + torch.cuda.synchronize() start_time = time.perf_counter() last_time = start_time @@ -384,6 +393,9 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: advanced_iters += 1 for name, value in reduced_losses.items(): total_losses[name] += value + self.trainer_events.send_weights( + self._completed_steps, self._multi_stage, self._config.training.export + ) else: skipped_iters += 1 nan_iters += not all(math.isfinite(loss) for loss in reduced_losses.values()) diff --git a/fast_llm/engine/training/trainer_events.py b/fast_llm/engine/training/trainer_events.py new file mode 100644 index 00000000..8bce3e6d --- /dev/null +++ b/fast_llm/engine/training/trainer_events.py @@ -0,0 +1,105 @@ +import logging + +import orjson +import redis +import torch.distributed + +from fast_llm.engine.config_utils.run import is_main_rank +from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel +from fast_llm.engine.training.config import TrainerEventsConfig, TrainerEventsRedisConfig, TrainingExportConfig + +logger = logging.getLogger(__name__) + + +class RedisEventSender: + def __init__(self, config: TrainerEventsRedisConfig): + self.config = config + self.client = None + + if is_main_rank(): + self.client = redis.Redis( + host=config.host, + port=config.port, + ) + + def send(self, msg_type: str, payload: dict | None = None): + if not is_main_rank(): + return + + if not payload: + payload = {} + payload.update({"type": msg_type}) + + self.client.xadd(self.config.stream_key, {self.config.payload_key: orjson.dumps(payload)}) + + +class TrainerEvents: + """ + Main helper class holding all event channels. + Each event may have its own RedisConfig. 
+ + Usage: + events = TrainerEvents(cfg.events) + events.weights_broadcast.send({"step": 100}) + events.training_finished.send() + """ + + def __init__(self, config: TrainerEventsConfig): + self.config = config + + if config.weights_broadcast.enabled or config.training_finished.enabled: + self.sender = RedisEventSender(config.redis) + else: + self.sender = None + + if config.weights_broadcast.enabled and is_main_rank(): + init_method = ( + f"tcp://{config.weights_broadcast.rdvz_master_address}:{config.weights_broadcast.rdvz_master_port}" + ) + logger.info(f"Waiting for weights broadcast rendezvous at {init_method} ...") + self.weights_pg = torch.distributed.init_process_group( + backend="nccl", + init_method=init_method, + world_size=config.weights_broadcast.world_size, + rank=config.weights_broadcast.rank, + ) + logger.info(f"Weights broadcast rendezvous at {init_method} connected") + else: + self.weights_pg = None + + def send_initial_weights_step(self, step: int, model: FastLLMModel, export_config: TrainingExportConfig): + if self.config.weights_broadcast.enabled: + self.sender.send( + msg_type=self.config.weights_broadcast.initial_weights_step_message_type, payload={"step": step} + ) + if self.config.weights_broadcast.initial_weights_step_message_includes_weights: + self._broadcast_weights(model, export_config) + + def send_weights(self, step: int, model: FastLLMModel, export_config: TrainingExportConfig): + if self.config.weights_broadcast.enabled: + self.sender.send(msg_type=self.config.weights_broadcast.weights_ready_message_type, payload={"step": step}) + self._broadcast_weights(model, export_config) + + def send_training_finished(self): + if self.config.training_finished.enabled: + self.sender.send(msg_type=self.config.training_finished.training_finished_message_type) + + if is_main_rank() and self.config.weights_broadcast.enabled: + torch.distributed.destroy_process_group() + + def _broadcast_weights(self, model: FastLLMModel, export_config: TrainingExportConfig): + for shard_name, layer_name, tensor in model.iter_checkpoint(export_config.get_save_config("", 10), {}): + if is_main_rank(): + meta = [(shard_name, layer_name, tensor.shape, tensor.dtype)] + torch.distributed.broadcast_object_list( + meta, group=self.weights_pg, group_src=self.config.weights_broadcast.rank + ) + torch.distributed.broadcast( + tensor, group=self.weights_pg, group_src=self.config.weights_broadcast.rank + ) + # Broadcast end of weights broadcast + if is_main_rank(): + meta = [None] + torch.distributed.broadcast_object_list( + meta, group=self.weights_pg, group_src=self.config.weights_broadcast.rank + ) diff --git a/fast_llm/models/gpt/conversion/qwen2.py b/fast_llm/models/gpt/conversion/qwen2.py index a8bc3345..57c9614b 100644 --- a/fast_llm/models/gpt/conversion/qwen2.py +++ b/fast_llm/models/gpt/conversion/qwen2.py @@ -2,6 +2,7 @@ from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.layers.attention.config import AttentionConfig +from fast_llm.layers.decoder.mlp.config import MLPConfig from fast_llm.models.gpt.conversion.config import Qwen2CheckpointFormat from fast_llm.models.gpt.conversion.llama import ( LlamaAttentionConverter, @@ -10,6 +11,7 @@ LlamaDecoderConverter, LlamaHeadConverter, LlamaHuggingfaceCheckpointHandler, + LlamaMLPConverter, ) from fast_llm.utils import Assert @@ -17,6 +19,22 @@ class Qwen2AttentionConverter(LlamaAttentionConverter): # TODO: Support sliding window with max_window_layers (need 2 kinds of block?) 
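+    # Qwen2 checkpoints use biases on the query/key/value projections but not on the dense
+    # projection or the MLP, so the converters below pin these flags instead of reading them
+    # from the Hugging Face config.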
+    @classmethod
+    def import_config(cls, config: dict) -> dict:
+        config["attention_bias"] = True
+        out = super().import_config(config)
+        out["query_layer"] = {"bias": {"enabled": True}}
+        out["key_layer"] = {"bias": {"enabled": True}}
+        out["value_layer"] = {"bias": {"enabled": True}}
+        out["dense_layer"] = {"bias": {"enabled": False}}
+        return out
+
+    @classmethod
+    def export_config(cls, config: AttentionConfig) -> dict:
+        out = super().export_config(config)
+        del out["attention_bias"]
+        return out
+
     @classmethod
     def _check_config(cls, config: AttentionConfig) -> None:
         Assert.is_(type(config), AttentionConfig)
@@ -33,8 +51,22 @@ def _check_config(cls, config: AttentionConfig) -> None:
         Assert.incl(config.dense_layer.bias.enabled, (None, False))
 
 
+class Qwen2MLPConverter(LlamaMLPConverter):
+    @classmethod
+    def import_config(cls, config: dict) -> dict:
+        config["mlp_bias"] = False
+        return super().import_config(config)
+
+    @classmethod
+    def export_config(cls, config: MLPConfig) -> dict:
+        out = super().export_config(config)
+        del out["mlp_bias"]
+        return out
+
+
 class Qwen2BlockConverter(LlamaBlockConverter):
     mixer_converter_class: typing.ClassVar[type[Qwen2AttentionConverter]] = Qwen2AttentionConverter
+    mlp_converter_class: typing.ClassVar[type[Qwen2MLPConverter]] = Qwen2MLPConverter
 
 
 class Qwen2DecoderConverter(LlamaDecoderConverter):
diff --git a/fast_llm/redis/config.py b/fast_llm/redis/config.py
new file mode 100644
index 00000000..5b6bfbdd
--- /dev/null
+++ b/fast_llm/redis/config.py
@@ -0,0 +1,28 @@
+from fast_llm.config import Config, Field, FieldHint, config_class
+
+
+@config_class()
+class RedisConfig(Config):
+    host: str = Field(
+        default="localhost",
+        desc="Hostname or IP address of the Redis server.",
+        hint=FieldHint.core,
+    )
+
+    port: int = Field(
+        default=6379,
+        desc="Port number on which the Redis server is running.",
+        hint=FieldHint.core,
+    )
+
+    stream_key: str = Field(
+        default=None,
+        desc="Name of the Redis stream to read data from.",
+        hint=FieldHint.core,
+    )
+
+    payload_key: str = Field(
+        default=None,
+        desc="Key under which the message data is stored inside the Redis payload dict.",
+        hint=FieldHint.core,
+    )
diff --git a/setup.cfg b/setup.cfg
index 005ae5a8..34995ce9 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -60,6 +60,9 @@ SSM =
 GENERATION =
     lm_eval>=0.4.9
 
+STREAMING =
+    redis>=7.1.0
+    orjson>=3.11.5
 
 # Required for supporting vision inputs
 VISION =
@@ -78,6 +81,7 @@ DEV =
     setuptools>=80.9.0
     # Dependency manager needs colorama to show colors.
colorama>=0.4.6 + fakeredis>=2.32.1 # Required for building the documentation DOCS = diff --git a/tests/conftest.py b/tests/conftest.py index ba2927c6..33e70f6a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -27,9 +27,11 @@ from tests.utils.run_test_script import ( # isort: skip compare_results_for_all_models, run_distributed_script, + run_distributed_script_lean, run_test_script_base_path, run_test_script_for_all_models, ) +from tests.utils.redis import fake_redis_server, stream_config # isort: skip from tests.utils.model_configs import model_testing_config, ModelTestingConfig, testing_group_enabled # isort: skip from tests.utils.utils import result_path, format_resource_report, report_subtest # isort: skip diff --git a/tests/data/gptdata_streaming_test.py b/tests/data/gptdata_streaming_test.py new file mode 100644 index 00000000..3e388cc4 --- /dev/null +++ b/tests/data/gptdata_streaming_test.py @@ -0,0 +1,115 @@ +import argparse +import pathlib +import pickle + +from fast_llm.config import NoAutoValidate +from fast_llm.data.data.gpt.config import GPTDataConfig +from fast_llm.data.data.gpt.data import GPTData +from fast_llm.data.dataset.config import IngestionType +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.models.gpt.config import GPTBatchConfig +from tests.utils.redis import get_stream_config, make_sampling + + +def distributed_gptdata_streaming_test( + sequence_length, + micro_batch_size, + batch_size, + tensor_parallel, + pipeline_parallel, + sequence_data_parallel, + total_gpus, + redis_port, + result_path, + ingestion_type, +): + stream_config = get_stream_config() + stream_config = stream_config.from_dict( + stream_config.to_dict(), {("redis", "port"): redis_port, ("ingestion_type"): ingestion_type} + ) + + distributed = Distributed( + DistributedConfig( + tensor_parallel=tensor_parallel, + pipeline_parallel=pipeline_parallel, + sequence_data_parallel=sequence_data_parallel, + ), + use_cpu=total_gpus == 0, + ) + sampling_data = make_sampling(sequence_length, 0, micro_batch_size, distributed) + + data_config = {"datasets": {"streaming1": stream_config.to_dict()}, "sampling": {"shuffle": "disabled"}} + data_config = GPTDataConfig.from_dict(data_config) + + data = GPTData(data_config, distributed.config) + + data.setup( + distributed=distributed, + sampling_parameters={"streaming1": sampling_data.parameters}, + preprocessing={}, + cache_directory="/tmp", + ) + + with NoAutoValidate(): + batch_config = GPTBatchConfig( + micro_batch_size=micro_batch_size, batch_size=batch_size, sequence_length=sequence_length + ) + batch_config.setup(distributed_config=distributed.config) + batch_config.validate() + + data_iter = data.get_iterator(batch_config, "streaming1", consumed_samples=0, num_workers=1, prefetch_factor=1) + + batch = next(data_iter) + # TODO: save result per batch_data_group and rank + assert batch.tokens.tokens.shape == (micro_batch_size, sequence_length) + + result_path = ( + pathlib.Path(result_path) + / ( + f"{distributed.config.batch_data_rank}_" + f"{distributed.model_and_sequence_data_group.rank() if distributed.model_and_sequence_data_group is not None else 0}" + ) + / "batch.pkl" + ) + result_path.parent.mkdir(exist_ok=True, parents=True) + with result_path.open("wb") as f: + pickle.dump(batch, f) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Run distributed GPT data streaming test.") + + parser.add_argument("--sequence-length", 
type=int, required=True, help="Sequence length of the model input.") + parser.add_argument("--micro-batch-size", type=int, required=True, help="Micro batch size.") + parser.add_argument("--batch-size", type=int, required=True, help="Global batch size.") + parser.add_argument("--tensor-parallel", type=int, required=True, help="Tensor parallel degree.") + parser.add_argument("--pipeline-parallel", type=int, required=True, help="Pipeline parallel degree.") + parser.add_argument("--sequence-data-parallel", type=int, required=True, help="Sequence data parallel degree.") + parser.add_argument("--total-gpus", type=int, required=True, help="Total number of GPUs available.") + parser.add_argument("--redis-port", type=int, required=True, help="Redis port to connect to.") + parser.add_argument("--result-path", type=str, required=True, help="Path to save test results.") + parser.add_argument("--ingestion-type", type=str, required=True, help="Ingestion type used in streaming dataset.") + + return parser.parse_args() + + +def main(): + args = parse_args() + + distributed_gptdata_streaming_test( + sequence_length=args.sequence_length, + micro_batch_size=args.micro_batch_size, + batch_size=args.batch_size, + tensor_parallel=args.tensor_parallel, + pipeline_parallel=args.pipeline_parallel, + sequence_data_parallel=args.sequence_data_parallel, + total_gpus=args.total_gpus, + redis_port=args.redis_port, + result_path=args.result_path, + ingestion_type=IngestionType(args.ingestion_type), + ) + + +if __name__ == "__main__": + main() diff --git a/tests/data/test_streaming.py b/tests/data/test_streaming.py new file mode 100644 index 00000000..a0bfae31 --- /dev/null +++ b/tests/data/test_streaming.py @@ -0,0 +1,399 @@ +import logging +import os +import pickle + +import fakeredis +import pytest +import torch + +from fast_llm.data.dataset.config import IngestionType +from fast_llm.data.dataset.streaming import StreamingDataset +from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.distributed import Distributed +from tests.utils.redis import make_sampling, push_msg, redis_batch_producer +from tests.utils.utils import requires_cuda + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------- + + +@pytest.fixture +def fake_redis(): + """Return a FakeRedis instance.""" + return fakeredis.FakeRedis() + + +@pytest.fixture +def monkeypatched_redis(monkeypatch, fake_redis): + """Monkeypatch redis.Redis globally (works even for imports inside functions).""" + import redis + + monkeypatch.setattr(redis, "Redis", lambda *args, **kwargs: fake_redis) + return fake_redis + + +# --------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------- + + +def generate_parallelism_variants(total_gpus: int): + """ + Generate all valid variants of (data_groups, tensor_parallel, pipeline_parallel, sequence_parallel) + for a number of GPUs up to the total_gpus. + If total_gpus is odd and > 1, fallback to nearest lower even number for decomposable parallelism. 
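+    For example, with 4 GPUs this includes pure data parallelism (4 data groups of 1 GPU) as well
+    as splits such as tensor_parallel=2 with 2 data groups, subject to DistributedConfig accepting
+    the combination.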
+ """ + if total_gpus > 1 and total_gpus % 2 == 1: + total_gpus = total_gpus - 1 + + if total_gpus < 2: + # No gpu and one gpu tests are the same, + # so no need of creation of variant for a single gpu + return [] + + variants = [] + + for gpus in range(2, total_gpus + 1, 2): + # try all possible numbers of data groups (1..total_gpus) + for data_groups in range(1, gpus + 1): + if gpus % data_groups != 0: + continue # cannot evenly split + + gpus_per_group = gpus // data_groups + + # now find all decompositions of gpus_per_group into tp*pp*sp + for tp in range(1, gpus_per_group + 1): + if gpus_per_group % tp != 0: + continue + rem_after_tp = gpus_per_group // tp + # TODO: currently streaming dataset does not support pipeline parallel setup + # for pp in range(1, rem_after_tp + 1): + for pp in range(1, 2): + if rem_after_tp % pp != 0: + continue + sp = rem_after_tp // pp + try: + # instead of repeating all safeguards here just try to + # instantiate distributed config to check if combination is valid + dist_config = DistributedConfig( + tensor_parallel=tp, + pipeline_parallel=pp, + sequence_data_parallel=sp, + world_size=gpus, + # TODO: works only on one node + local_world_size=gpus, + rank=0, + ) + except Exception: + continue + + variants.append( + { + "data_groups": data_groups, + "batch_data_parallel": dist_config.batch_data_parallel, + "tensor_parallel": tp, + "pipeline_parallel": pp, + "sequence_data_parallel": sp, + "total_gpus": gpus, + } + ) + return variants + + +def run_distributed_gptdata_streaming_test( + fake_redis_server, + variant, + run_distributed_script, + result_path, + request, + ingestion_type: IngestionType, +): + import tests.data.gptdata_streaming_test + + stream_config, fake_redis, fake_redis_server_killer = fake_redis_server + stream_config = stream_config.from_dict(stream_config.to_dict(), {("ingestion_type"): ingestion_type}) + + sequence_length = 10 + micro_batch_size = 2 + batch_size = micro_batch_size * variant["batch_data_parallel"] + tensor_parallel = variant["tensor_parallel"] + pipeline_parallel = variant["pipeline_parallel"] + sequence_data_parallel = variant["sequence_data_parallel"] + total_gpus = variant["total_gpus"] + redis_port = stream_config.redis.port + + result_path = result_path / "distributed_gptdata_streaming_test" / request.node.name + + with redis_batch_producer( + redis_client=fake_redis, + fake_redis_server_killer=fake_redis_server_killer, + stream_config=stream_config, + batch_size=batch_size, + sequence_length=10, + ): + if total_gpus > 0: + script = [ + "-m", + tests.data.gptdata_streaming_test.__name__, + "--sequence-length", + str(sequence_length), + "--micro-batch-size", + str(micro_batch_size), + "--batch-size", + str(batch_size), + "--tensor-parallel", + str(tensor_parallel), + "--pipeline-parallel", + str(pipeline_parallel), + "--sequence-data-parallel", + str(sequence_data_parallel), + "--total-gpus", + str(total_gpus), + "--result-path", + str(result_path), + "--redis-port", + str(redis_port), + "--ingestion-type", + str(ingestion_type.value), + ] + # TODO: distributed_capture is ignored now inside the script + if request.config.getoption("distributed_capture"): + logger.warning( + "Capturing output and forwarding to associated tests. Run with `--no-distributed-capture` to disable." 
+ ) + else: + script.append("--no-distributed-capture") + + env = os.environ.copy() + env["PYTHONHASHSEED"] = "42" + run_distributed_script(script, num_gpus=total_gpus, env=env) + else: + tests.data.gptdata_streaming_test.distributed_gptdata_streaming_test( + sequence_length=sequence_length, + micro_batch_size=micro_batch_size, + batch_size=batch_size, + tensor_parallel=tensor_parallel, + pipeline_parallel=pipeline_parallel, + sequence_data_parallel=sequence_data_parallel, + total_gpus=total_gpus, + redis_port=redis_port, + result_path=result_path, + ingestion_type=ingestion_type, + ) + + check_distributed_gptdata_streaming_test_results( + result_path=result_path, + micro_batch_size=micro_batch_size, + batch_data_parallel=variant["batch_data_parallel"], + total_gpu=variant["total_gpus"], + ) + + +def check_distributed_gptdata_streaming_test_results( + result_path, + micro_batch_size, + batch_data_parallel, + total_gpu, +): + batch_data_parallel_size = total_gpu // batch_data_parallel if total_gpu > 0 else 1 + sample_idx = set() + for i in range(batch_data_parallel): + ref_batch = None + for j in range(batch_data_parallel_size): + with (result_path / f"{i}_{j}" / "batch.pkl").open("rb") as f: + batch = pickle.load(f) + if ref_batch is None: + ref_batch = batch + else: + # batches for the same batch-data-parallel group must be equal on all of its ranks + assert torch.equal(batch.tokens.tokens, ref_batch.tokens.tokens) + for j in range(micro_batch_size): + val = ref_batch.tokens.tokens[j, 0].item() + # samples must be unique across groups and within each batch + assert val not in sample_idx + sample_idx.add(val) + # the number of unique samples must equal the global batch size + assert len(sample_idx) == micro_batch_size * batch_data_parallel + + +# --------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------- + + +def test_streaming_dataset_reads_single_message(monkeypatched_redis, stream_config): + """StreamingDataset should read a message and convert it into LanguageModelSample.""" + fake_redis = monkeypatched_redis + + distributed = Distributed(DistributedConfig(), use_cpu=True) + dataset = StreamingDataset(stream_config, distributed) + + # Insert a message + push_msg(fake_redis, stream_config, [1, 2, 3]) + + it = iter(dataset) + sample = next(it) + + assert isinstance(sample, LanguageModelSample) + assert torch.equal(sample.tokens.tokens, torch.tensor([1, 2, 3], dtype=torch.int64)) + assert sample.tokens.lengths == [3] + assert sample.loss_masking_spans is None + assert sample.chosen_spans is None + assert sample.rejected_spans is None + + +def test_streaming_dataset_reads_multiple_messages(monkeypatched_redis, stream_config): + """StreamingDataset should read multiple messages and convert each into a LanguageModelSample.""" + fake_redis = monkeypatched_redis + + distributed = Distributed(DistributedConfig(), use_cpu=True) + dataset = StreamingDataset(stream_config, distributed) + + # Insert three messages + push_msg(fake_redis, stream_config, [1, 2, 3]) + push_msg(fake_redis, stream_config, [1, 2, 3]) + push_msg(fake_redis, stream_config, [1, 2, 3]) + + it = iter(dataset) + for i in range(3): + sample = next(it) + + assert isinstance(sample, LanguageModelSample) + assert torch.equal(sample.tokens.tokens, torch.tensor([1, 2, 3], dtype=torch.int64)) + assert sample.tokens.lengths == [3] + assert sample.loss_masking_spans is None + assert sample.chosen_spans is None + assert sample.rejected_spans is None + + +def 
test_sampling_1_doc_exact_fit(monkeypatched_redis, stream_config): + """A single doc exactly fills one sample.""" + fake_redis = monkeypatched_redis + + push_msg(fake_redis, stream_config, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + + distributed = Distributed(DistributedConfig(), use_cpu=True) + sampler = StreamingDataset(stream_config, distributed).sample(make_sampling(10, 0, 1, distributed)) + + out = next(iter(sampler)) + + assert isinstance(out, LanguageModelSample) + assert len(out) == 10 + assert out.tokens.tokens.tolist() == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + + +def test_sampling_2_docs_exact_fit(monkeypatched_redis, stream_config): + """Two docs exactly fill one sample.""" + fake_redis = monkeypatched_redis + + # Two rollouts: lengths 4 and 6 -> exactly 10 + push_msg(fake_redis, stream_config, [1, 2, 3, 4]) + push_msg(fake_redis, stream_config, [5, 6, 7, 8, 9, 10]) + + distributed = Distributed(DistributedConfig(), use_cpu=True) + sampler = StreamingDataset(stream_config, distributed).sample(make_sampling(10, 0, 1, distributed)) + + out = next(iter(sampler)) + + assert isinstance(out, LanguageModelSample) + assert len(out) == 10 + assert out.tokens.tokens.tolist() == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + + +def test_sampling_skips_too_long_doc_and_padding_final(monkeypatched_redis, stream_config): + """Rollout longer than sample_length must be dropped.""" + fake_redis = monkeypatched_redis + + push_msg(fake_redis, stream_config, list(range(20))) # skip: too long + push_msg(fake_redis, stream_config, list(range(10))) # usable + + distributed = Distributed(DistributedConfig(), use_cpu=True) + sampler = StreamingDataset(stream_config, distributed).sample(make_sampling(10, 0, 1, distributed)) + + out = next(iter(sampler)) + + # the too-long message is skipped and the next message is returned instead + assert len(out) == 10 + assert out.tokens.tokens.tolist() == list(range(10)) + + +def test_sampling_overflow_creates_two(monkeypatched_redis, stream_config): + """A document overflowing the boundary triggers padding + next sample.""" + fake_redis = monkeypatched_redis + + push_msg(fake_redis, stream_config, list(range(6))) + push_msg(fake_redis, stream_config, list(range(10))) + + distributed = Distributed(DistributedConfig(), use_cpu=True) + sampler = StreamingDataset(stream_config, distributed).sample(make_sampling(10, 0, 2, distributed)) + + sampler_iter = iter(sampler) + out = [next(sampler_iter)] + out.append(next(sampler_iter)) + + # sample 1: 0..5 + pad(4) + assert out[0].tokens.tokens.tolist() == list(range(6)) + [-100, -100, -100, -100] + + # sample 2: 0..9 (exact fit, no padding) + assert out[1].tokens.tokens.tolist() == list(range(10)) + + +@pytest.mark.parametrize( + "ingestion_type", + [ + IngestionType.CONSUMER_GROUP, + # TODO: implement wait_until_stream_empty for these variants on the test side to enable their tests + # IngestionType.ONE_STREAM, + # IngestionType.N_STREAMS, + ], +) +def test_gptdata_streaming_single_consumer( + fake_redis_server, run_distributed_script_lean, ingestion_type, result_path, request +): + + run_distributed_gptdata_streaming_test( + fake_redis_server=fake_redis_server, + variant={ + "data_groups": 1, + "tensor_parallel": 1, + "pipeline_parallel": 1, + "sequence_data_parallel": 1, + "total_gpus": 0, + "batch_data_parallel": 1, + }, + run_distributed_script=run_distributed_script_lean, + result_path=result_path, + request=request, + ingestion_type=ingestion_type, + ) + + +variants = generate_parallelism_variants(torch.cuda.device_count()) + + +@pytest.mark.slow +@requires_cuda
+@pytest.mark.parametrize( + "variant", + variants, + ids=[ + f"dg{v['data_groups']}_tp{v['tensor_parallel']}_pp{v['pipeline_parallel']}_sp{v['sequence_data_parallel']}_gpu{v['total_gpus']}" + for v in variants + ], +) +def test_gptdata_streaming_gpus(fake_redis_server, variant, run_distributed_script_lean, result_path, request): + # TODO: run variants with the same number of GPUs as subtests, + # similar to how it is done in test_model, for speed + run_distributed_gptdata_streaming_test( + fake_redis_server=fake_redis_server, + variant=variant, + run_distributed_script=run_distributed_script_lean, + result_path=result_path, + request=request, + ingestion_type=IngestionType.CONSUMER_GROUP, + ) diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index bb53de29..53804d87 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -431,6 +431,7 @@ def reference_distributed_shard(get_convert_path) -> torch.Tensor | None: # We don't want to depend on `test_save_and_load_in_parallel` because we still want to run this in cas of failure. # This should still run after `test_save_and_load_in_parallel` @requires_cuda +# NOTE: Should it depend on test_model_distributed instead? @pytest.mark.depends_on(on=["test_load_pretrained[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) def test_load_parallel_checkpoint_in_single_gpu( diff --git a/tests/trainer/events_fake_consumer.py b/tests/trainer/events_fake_consumer.py new file mode 100644 index 00000000..4c2d3089 --- /dev/null +++ b/tests/trainer/events_fake_consumer.py @@ -0,0 +1,105 @@ +import sys +from pathlib import Path + +import orjson +import redis +import safetensors.torch +import torch.distributed +import yaml + + +def main(): + if len(sys.argv) != 2: + print("Usage: python -m tests.trainer.events_fake_consumer <config_path>") + sys.exit(1) + + config_path = Path(sys.argv[1]) + if not config_path.exists(): + print(f"Config file {config_path} does not exist") + sys.exit(1) + + with config_path.open("rt") as f: + config = yaml.safe_load(f) + + consumer_cfg = config["consumer"] + world_size = consumer_cfg["world_size"] + rank = consumer_cfg["rank"] + results_path = Path(consumer_cfg["results_path"]) + results_path.mkdir(parents=True, exist_ok=True) + + consumer_id = f"[Consumer {rank}/{world_size}]" + + print(f"{consumer_id} Started with config:") + print(yaml.safe_dump(config)) + + assert config["events"]["weights_broadcast"]["enabled"] + assert config["events"]["training_finished"]["enabled"] + + redis_client = redis.Redis(host=config["events"]["redis"]["host"], port=config["events"]["redis"]["port"]) + + print(f"{consumer_id} waiting for pg rendezvous...") + weights_pg = torch.distributed.init_process_group( + backend="nccl", + init_method=f'tcp://{config["events"]["weights_broadcast"]["rdvz_master_address"]}:' + f'{config["events"]["weights_broadcast"]["rdvz_master_port"]}', + world_size=world_size, + rank=rank, + ) + broadcast_source_rank = config["events"]["weights_broadcast"]["rank"] + + last_id = "0-0" + msg_key = config["events"]["redis"]["payload_key"].encode() + stream_key = config["events"]["redis"]["stream_key"] + + print(f"{consumer_id} waiting for messages...") + while True: + result = redis_client.xread( + streams={stream_key: last_id}, + count=1, + block=200, + ) + + if not result: + continue + + _, events = result[0] + + for event_id, msg in events: + last_id = event_id + assert msg_key in msg + msg = 
orjson.loads(msg[msg_key].decode()) + print(f"{consumer_id} msg received: {msg}") + if msg["type"] == config["events"]["weights_broadcast"]["weights_ready_message_type"] or ( + msg["type"] == config["events"]["weights_broadcast"]["initial_weights_step_message_type"] + and config["events"]["weights_broadcast"]["initial_weights_step_message_includes_weights"] + ): + weights = {} + while True: + meta = [None] + torch.distributed.broadcast_object_list(meta, group=weights_pg, group_src=broadcast_source_rank) + meta = meta[0] + if meta is None: + print(f"{consumer_id} weight broadcast finished") + break + shard_name, layer_name, tensor_size, tensor_type = meta + tensor = torch.zeros( + tuple(tensor_size), dtype=tensor_type, device="cuda" + ) # so far the consumer is single-GPU only + torch.distributed.broadcast(tensor, group=weights_pg, group_src=broadcast_source_rank) + print(f"{consumer_id} {shard_name} layer {layer_name} {tensor_size} {tensor_type} received") + if shard_name == "weights": + weights[layer_name] = tensor + safetensors.torch.save_file(weights, results_path / f"{msg['step']}.safetensors") + + elif msg["type"] == config["events"]["training_finished"]["training_finished_message_type"]: + torch.distributed.destroy_process_group() + (results_path / "training_finished").touch() + return + else: + raise RuntimeError(f"{consumer_id} Received unknown message type {msg}") + if msg["type"] == config["events"]["weights_broadcast"]["initial_weights_step_message_type"]: + (results_path / "initial_weights_step").touch() + + +if __name__ == "__main__": + main() diff --git a/tests/trainer/test_events.py b/tests/trainer/test_events.py new file mode 100644 index 00000000..4894a022 --- /dev/null +++ b/tests/trainer/test_events.py @@ -0,0 +1,407 @@ +import contextlib +import copy +import os +import pathlib +import subprocess +import time +import typing + +import pytest +import safetensors +import torch +import yaml + +from tests.utils.model_configs import MODEL_CONFIGS +from tests.utils.redis import redis_batch_producer +from tests.utils.utils import requires_cuda + + +@contextlib.contextmanager +def run_fake_events_consumers( + model_config: dict, + test_result_path: pathlib.Path, + broadcast_world_size: int, + fake_consumers_broadcast_ranks: list[int], + assigned_gpus: list[int], + timeout: float = 30.0, # seconds +): + """ + Context manager to run fake event consumer subprocesses for testing. + + Each subprocess gets a separate config and CUDA_VISIBLE_DEVICES. + + On exiting the context, all subprocesses are waited for and killed if they exceed the timeout. + Raises RuntimeError if any subprocess exits with a non-zero code.
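+ + Usage sketch (mirrors test_trainer_events_with_streaming below; paths, ranks and GPU indices are illustrative only): + + with run_fake_events_consumers( + model_config=model_config, + test_result_path=result_path / "consumers", + broadcast_world_size=2, + fake_consumers_broadcast_ranks=[0], + assigned_gpus=[0], + ): + # training runs while the fake consumers listen for events + run_fast_llm_training(model_config, run_distributed_script, assigned_gpus=[1])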
+ """ + import tests.trainer.events_fake_consumer + + assert len(assigned_gpus) > 0 + assert len(assigned_gpus) == len(fake_consumers_broadcast_ranks) + + processes = [] + + try: + for i, gpu in enumerate(assigned_gpus): + consumer_path = test_result_path / str(i) + consumer_path.mkdir(parents=True, exist_ok=True) + + # Deep copy config and update per consumer + this_config = copy.deepcopy(model_config) + this_config["consumer"] = { + "idx": i, + "results_path": consumer_path / "results", + "world_size": broadcast_world_size, + "rank": fake_consumers_broadcast_ranks[i], + } + this_config_path = consumer_path / "config.yaml" + + # Save config as YAML + with open(this_config_path, "w") as f: + yaml.safe_dump(convert_paths(this_config), f) + + # Build subprocess command + script = [ + "python", + "-m", + tests.trainer.events_fake_consumer.__name__, + str(this_config_path), + ] + env = os.environ.copy() + env["CUDA_VISIBLE_DEVICES"] = str(gpu) + + # Start subprocess + proc = subprocess.Popen(script, env=env) + processes.append(proc) + + # Yield control to the caller while subprocesses run + yield + + finally: + # Wait for processes to exit or kill after timeout + start_time = time.time() + for proc in processes: + try: + remaining = max(0, timeout - (time.time() - start_time)) + proc.wait(timeout=remaining) + except subprocess.TimeoutExpired: + proc.kill() + + # Check exit codes + errors = [(i, p.returncode) for i, p in enumerate(processes) if p.returncode != 0] + if errors: + raise RuntimeError(f"Some fake consumer subprocesses failed: {errors}") + + +def run_fast_llm_training(model_config, run_distributed_script, assigned_gpus): + import fast_llm.cli + + config_path = model_config["run"]["experiment_dir"] / "load_config.yaml" + config_path.parent.mkdir(parents=True, exist_ok=True) + with config_path.open("wt") as f: + yaml.safe_dump(convert_paths(model_config), f) + + script = [ + "-m", + fast_llm.cli.__name__, + "train", + "gpt", + "--config", + str(config_path), + ] + + env = os.environ.copy() + env["PYTHONHASHSEED"] = "42" + env["CUDA_VISIBLE_DEVICES"] = ",".join(str(gpu) for gpu in assigned_gpus) + run_distributed_script(script, num_gpus=len(assigned_gpus), env=env) + + +def compare_test_tensors_to_checkpoint(test_safetensor_path: str, checkpoint_dir: str): + """ + Compare a test-saved safetensor file (a dict of tensors) + to all safetensors in a checkpoint directory. + + Checks: + - tensor names must match + - shapes must match + - dtypes must match + - values must match (exact) + """ + + # ------------------------- + # Load test tensor file + # ------------------------- + test_tensors = {} + with safetensors.safe_open(test_safetensor_path, framework="pt", device="cpu") as f: + for key in f.keys(): + test_tensors[key] = f.get_tensor(key) + + assert len(test_tensors) > 0, f"No tensors found in {test_safetensor_path}." + + # ------------------------- + # Load checkpoint tensors + # ------------------------- + checkpoint_tensors = {} + + for file in os.listdir(checkpoint_dir): + if file.endswith(".safetensors"): + path = os.path.join(checkpoint_dir, file) + with safetensors.safe_open(path, framework="pt", device="cpu") as f: + for key in f.keys(): + if key in checkpoint_tensors: + raise AssertionError( + f"Duplicate tensor name '{key}' across checkpoint {checkpoint_dir} files." 
+ ) + checkpoint_tensors[key] = f.get_tensor(key) + + assert len(checkpoint_tensors) > 0, f"No safetensors found in checkpoint directory: {checkpoint_dir}" + + # ------------------------- + # Compare tensor sets + # ------------------------- + test_names = set(test_tensors.keys()) + ckpt_names = set(checkpoint_tensors.keys()) + + unexpected_in_test = test_names - ckpt_names + missing_in_test = ckpt_names - test_names + + assert not missing_in_test, f"Tensors missing in {test_safetensor_path}:\n" + "\n".join(sorted(missing_in_test)) + assert not unexpected_in_test, f"Unexpected tensors in {test_safetensor_path}:\n" + "\n".join( + sorted(unexpected_in_test) + ) + + # ------------------------- + # Compare individual tensors + # ------------------------- + for name in sorted(test_names): + t_test = test_tensors[name] + t_ckpt = checkpoint_tensors[name] + + # dtype + assert t_test.dtype == t_ckpt.dtype, f"Mismatch in dtype for '{name}': " f"{t_test.dtype} != {t_ckpt.dtype}" + + # shape + assert t_test.shape == t_ckpt.shape, ( + f"Mismatch in shape for '{name}': " f"{tuple(t_test.shape)} != {tuple(t_ckpt.shape)}" + ) + + # values + if not torch.equal(t_test, t_ckpt): + diff = (t_test - t_ckpt).abs() + max_diff = diff.max().item() + idx = (diff > 0).nonzero(as_tuple=False) + example = idx[0].tolist() if idx.numel() > 0 else "unknown" + + raise AssertionError( + f"Tensor content mismatch for '{name}'.\n" + f"Max difference: {max_diff}\n" + f"Example differing index: {example}" + ) + + # If we reached here, all tensors match + return True + + +def check_events_results( + test_results_path_fast_llm, + test_results_path_consumers, + consumer_count, + training_steps, + model_checkpoint_format, +): + for consumer_idx in range(consumer_count): + consumer_test_results_path = test_results_path_consumers / str(consumer_idx) / "results" + assert (consumer_test_results_path / "training_finished").is_file() + assert (consumer_test_results_path / "initial_weights_step").is_file() + # NOTE: We do not test the initial weights broadcast result when enabled, + # because it is identical to subsequent broadcasts. + for training_step in range(1, training_steps + 1): + compare_test_tensors_to_checkpoint( + consumer_test_results_path / f"{training_step}.safetensors", + test_results_path_fast_llm / "export" / model_checkpoint_format / str(training_step), + ) + + +def convert_paths(obj): + if isinstance(obj, dict): + return {k: convert_paths(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [convert_paths(v) for v in obj] + elif isinstance(obj, tuple): + return tuple(convert_paths(v) for v in obj) + elif isinstance(obj, pathlib.Path): + return str(obj) + else: + return obj + + +def parallelism_variants(num_gpus: int) -> list[dict[str, int]]: + if num_gpus == 1: + return [{"tp": 1, "pp": 1, "sp": 1}] + + if num_gpus == 2: + return [ + # NOTE: Streaming dataset is currently not compatible with pipeline parallelism. + {"tp": 2, "pp": 1, "sp": 1}, + # {"tp": 1, "pp": 2, "sp": 1}, + {"tp": 1, "pp": 1, "sp": 2}, + ] + + if num_gpus == 4: + return [ + # NOTE: Streaming dataset is currently not compatible with pipeline parallelism.
+ {"tp": 4, "pp": 1, "sp": 1}, + # {"tp": 1, "pp": 4, "sp": 1}, + {"tp": 1, "pp": 1, "sp": 4}, + # {"tp": 2, "pp": 2, "sp": 1}, + # {"tp": 1, "pp": 2, "sp": 2}, + {"tp": 2, "pp": 1, "sp": 2}, + ] + + raise ValueError(f"Invalid gpu count for fast_llm parallelism {num_gpus}") + + +def consumer_counts(num_gpus: int) -> int: + if num_gpus == 2: + return 1 + if num_gpus == 3: + return 1 + if num_gpus == 4: + return 2 + if num_gpus == 5: + return 1 + if num_gpus == 6: + return 2 + if num_gpus == 7: + return 3 + if num_gpus >= 8: + return 4 + + +def generate_variants(num_gpus: int) -> list[dict[str, typing.Any]]: + """ + Generate all (consumer_count, tp/pp/sp) variants for given GPU count. + """ + results = [] + + if num_gpus < 2: + return results + if num_gpus == 2: + num_gpus = [2] + elif num_gpus <= 4: + num_gpus = [2, num_gpus] + else: + num_gpus = [2, 4, min(num_gpus, 8)] + + for gpus in num_gpus: + consumers = consumer_counts(gpus) + remaining = gpus - consumers + par_vars = parallelism_variants(remaining) + for pv in par_vars: + results.append( + { + "total_gpus": gpus, + "consumers_gpu_count": consumers, + "fast_llm_gpus_count": remaining, + "consumers_gpus": list(range(consumers)), + "fast_llm_gpus": list(range(consumers, gpus)), + "tensor_parallel": pv["tp"], + "pipeline_parallel": pv["pp"], + "sequence_data_parallel": pv["sp"], + } + ) + + return results + + +variants = generate_variants(torch.cuda.device_count()) + + +@pytest.mark.slow +@requires_cuda +@pytest.mark.parametrize( + "variant", + variants, + ids=[ + f"gpu{v['total_gpus']}_cgpus{v['consumers_gpu_count']}_fgpus{v['fast_llm_gpus_count']}" + f"_tp{v['tensor_parallel']}_pp{v['pipeline_parallel']}_sp{v['sequence_data_parallel']}" + for v in variants + ], +) +def test_trainer_events_with_streaming(fake_redis_server, variant, run_distributed_script_lean, result_path, request): + stream_config, fake_redis_client, fake_redis_server_killer = fake_redis_server + test_result_path = result_path / request.node.name + test_result_path_fast_llm = test_result_path / "fast_llm" + test_result_path_consumers = test_result_path / "consumers" + + broadcast_world_size = variant["consumers_gpu_count"] + 1 + fake_consumers_broadcast_ranks = list(range(variant["consumers_gpu_count"])) + fake_consumers_assigned_gpus = variant["consumers_gpus"] + fast_llm_broadcast_rank = variant["consumers_gpu_count"] + fast_llm_assigned_gpus = variant["fast_llm_gpus"] + train_iters = 2 + + model_config = copy.deepcopy(MODEL_CONFIGS["mistral"].config_dict) + model_config["data"]["datasets"] = {"training": stream_config.to_dict()} + model_config["data"]["sampling"] = {"shuffle": "disabled"} + model_config["training"]["train_iters"] = train_iters + model_config["training"]["export"] = {"interval": 1, "format": MODEL_CONFIGS["mistral"].checkpoint_format.name} + model_config["batch"]["micro_batch_size"] = 1 + model_config["batch"]["truncate_documents"] = False + model_config["run"]["experiment_dir"] = test_result_path_fast_llm + model_config["model"]["distributed"]["tensor_parallel"] = variant["tensor_parallel"] + model_config["model"]["distributed"]["pipeline_parallel"] = variant["pipeline_parallel"] + model_config["model"]["distributed"]["sequence_data_parallel"] = variant["sequence_data_parallel"] + + # We use same stream for messages in the test. 
Also make all fields explicit, + # so fake consumers can read them as well from this dict config + model_config["events"] = { + "redis": { + "host": stream_config.redis.host, + "port": stream_config.redis.port, + "stream_key": "fast_llm_events", + "payload_key": "event", + }, + "weights_broadcast": { + "enabled": True, + "initial_weights_step_message_type": "initial_weights_step", + "initial_weights_step_message_includes_weights": True, + "weights_ready_message_type": "weights_ready", + "rdvz_master_address": "127.0.0.1", + "rdvz_master_port": 19999, + "world_size": broadcast_world_size, + "rank": fast_llm_broadcast_rank, + }, + "training_finished": { + "enabled": True, + "training_finished_message_type": "training_finished", + }, + } + + batch_size = model_config["batch"]["batch_size"] + sequence_length = model_config["batch"]["sequence_length"] + with redis_batch_producer( + redis_client=fake_redis_client, + fake_redis_server_killer=fake_redis_server_killer, + stream_config=stream_config, + batch_size=batch_size, + sequence_length=sequence_length, + ): + with run_fake_events_consumers( + model_config=model_config, + test_result_path=test_result_path_consumers, + broadcast_world_size=broadcast_world_size, + fake_consumers_broadcast_ranks=fake_consumers_broadcast_ranks, + assigned_gpus=fake_consumers_assigned_gpus, + ): + run_fast_llm_training( + model_config=model_config, + run_distributed_script=run_distributed_script_lean, + assigned_gpus=fast_llm_assigned_gpus, + ) + check_events_results( + test_results_path_fast_llm=test_result_path_fast_llm, + test_results_path_consumers=test_result_path_consumers, + consumer_count=len(fake_consumers_assigned_gpus), + training_steps=train_iters, + model_checkpoint_format=MODEL_CONFIGS["mistral"].checkpoint_format.name, + ) diff --git a/tests/utils/redis.py b/tests/utils/redis.py new file mode 100644 index 00000000..7e9d3b68 --- /dev/null +++ b/tests/utils/redis.py @@ -0,0 +1,235 @@ +import contextlib +import pathlib +import socket +import threading +import time + +import fakeredis +import orjson +import pytest + +from fast_llm.data.dataset.config import ( + IngestionType, + SamplingConfig, + SamplingData, + SamplingParameters, + ShufflingType, + StreamingDatasetConfig, + StreamingDatasetRedisConfig, +) +from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig + + +def get_stream_config(): + return StreamingDatasetConfig( + redis=StreamingDatasetRedisConfig( + host="localhost", + port=6379, + stream_key="test_stream", + payload_key="data", + ), + group_name="test_group", + consumer_name_prefix="consumer", + ) + + +def find_free_port(): + """Find a free TCP port and return it.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def push_msg(redis_client, config, tokens=None, stream_key_suffix=None): + """Push a message into FakeRedis stream.""" + msg = { + "tokens": tokens, + "tokens_dtype": "int64", + } + stream_key = config.redis.stream_key + if stream_key_suffix is not None: + stream_key += stream_key_suffix + redis_client.xadd(stream_key, {config.redis.payload_key: orjson.dumps(msg)}) + + +class FakeRedisServerKiller: + def __init__(self, server): + self._server = server + self._is_killed = False + self._lock = threading.Lock() + + def kill(self): + with self._lock: + if not self._is_killed: + try: + self._server.shutdown() + self._server.server_close() + finally: + self._is_killed = True + + +def wait_until_stream_empty( + redis_client, + stream_key, + 
consumer_group, + stop_event, + consumer_count, + ingestion_type: IngestionType, +): + if ingestion_type == IngestionType.CONSUMER_GROUP: + return wait_until_stream_empty_consumer_group(redis_client, stream_key, consumer_group, stop_event) + elif ingestion_type == IngestionType.ONE_STREAM: + raise NotImplementedError() + elif ingestion_type == IngestionType.N_STREAMS: + raise NotImplementedError() + else: + raise ValueError(f"Unknown ingestion type {ingestion_type.value}") + + +def wait_until_stream_empty_consumer_group(redis_client, stream_key, consumer_group, stop_event): + """ + Wait until lag == 0, meaning all messages have been delivered AND acknowledged. + Absence of group mean test has not started yet, so we wait + """ + consumer_group = consumer_group.encode() + while not stop_event.is_set(): + groups = redis_client.xinfo_groups(stream_key) + + g = next((g for g in groups if g["name"] == consumer_group), None) + if g is not None: + lag = g.get("lag", 0) + if lag == 0: + return + + time.sleep(0.05) + + +def get_consumer_count(redis_client, stop_event, config: StreamingDatasetConfig): + while not stop_event.is_set(): + res = redis_client.hget(f"{config.redis.stream_key}:consumer_count", "0") + if res is None: + time.sleep(0.05) + continue + return int(res) + + +@contextlib.contextmanager +def redis_batch_producer( + redis_client, fake_redis_server_killer, stream_config, batch_size, sequence_length, num_batches=None +): + stop_event = threading.Event() + thread_exc = [] + + def producer_loop(): + is_n_streams = stream_config.ingestion_type == IngestionType.N_STREAMS + try: + consumer_count = get_consumer_count(redis_client, stop_event, stream_config) + stream = stream_config.redis.stream_key + group = stream_config.group_name + batch_idx = 0 + while not stop_event.is_set(): + if num_batches is not None and batch_idx >= num_batches: + break + for i in range(batch_size): + if stop_event.is_set(): + return + push_msg( + redis_client, + stream_config, + [batch_idx * batch_size + i] * sequence_length, + stream_key_suffix=f"_{i % consumer_count}" if is_n_streams else None, + ) + wait_until_stream_empty( + redis_client, + stream, + group, + stop_event, + consumer_count=consumer_count, + ingestion_type=stream_config.ingestion_type, + ) + batch_idx += 1 + except Exception as e: + # if failed to push messages kill redis server so waiting side in the test would unlock + fake_redis_server_killer.kill() + thread_exc.append(e) + raise + + thread = threading.Thread(target=producer_loop, daemon=True) + thread.start() + + try: + yield + finally: + stop_event.set() + thread.join(timeout=10) + if thread_exc: + raise thread_exc[0] + + +def make_sampling(sequence_length, extra_tokens, num_samples, distributed): + return SamplingData( + parameters=SamplingParameters( + sequence_length=sequence_length, + extra_tokens=extra_tokens, + num_samples=num_samples, + truncate_documents=False, + ), + config=SamplingConfig(shuffle=ShufflingType.disabled), + distributed=distributed, + dataset_name="test", + cache_directory=pathlib.Path("/tmp"), + preprocessing=LanguageModelPreprocessingConfig(), + ) + + +@pytest.fixture +def stream_config(): + return get_stream_config() + + +@pytest.fixture +def fake_redis_server(stream_config): + # We search for free port as port from previous test can still be not free even after server shutdown + stream_config = stream_config.from_dict(stream_config.to_dict(), {("redis", "port"): find_free_port()}) + + server_address = (stream_config.redis.host, stream_config.redis.port) + + # 
----- Monkey-patch handler to suppress broken pipes ----- + orig_handle = fakeredis._tcp_server.TCPFakeRequestHandler.handle + + def safe_handle(self): + try: + orig_handle(self) + except (ConnectionResetError, BrokenPipeError): + # Client disconnected abruptly (e.g., when a PyTorch DataLoader iterator is deleted). + # These errors occur only with fake Redis and can be safely ignored. + pass + except Exception as e: + print(f"Unexpected exception in fake Redis handler: {e}") + + fakeredis._tcp_server.TCPFakeRequestHandler.handle = safe_handle + + server = fakeredis.TcpFakeServer(server_address, server_type="redis") + server_killer = FakeRedisServerKiller(server) + + # ----- Start server thread ----- + def serve(): + try: + server.serve_forever() + except Exception: + # Extra safety: catch anything from serve_forever + pass + + thread = threading.Thread(target=serve, daemon=True) + thread.start() + + # ----- Create a redis-py client pointing at the fake server ----- + import redis + + client = redis.Redis(host=server_address[0], port=server_address[1]) + + yield stream_config, client, server_killer + + # ----- Teardown ----- + server_killer.kill() + thread.join() diff --git a/tests/utils/run_test_script.py b/tests/utils/run_test_script.py index 5c07324c..e880e67e 100644 --- a/tests/utils/run_test_script.py +++ b/tests/utils/run_test_script.py @@ -59,6 +59,17 @@ def run_distributed_script( ) +@pytest.fixture(scope="session") +def run_distributed_script_lean( + worker_resources: "WorkerResources", +): + return functools.partial( + do_run_distributed_script, + rendezvous_port=worker_resources.rendezvous_port, + torchrun_port=worker_resources.torchrun_port, + ) + + @pytest.fixture(scope="session") def run_test_script_base_path(model_testing_config, result_path, request): return result_path / "models" / model_testing_config.name