Draft
84 commits
1a18929
Dataset interface
jlamypoirier Oct 15, 2025
fd63846
misc
jlamypoirier Oct 15, 2025
2486caf
fix
jlamypoirier Oct 15, 2025
92e93e8
Language model sample
jlamypoirier Oct 16, 2025
d6f6944
fix
jlamypoirier Oct 16, 2025
5c802fa
fixes
jlamypoirier Oct 16, 2025
95d1840
test
jlamypoirier Oct 16, 2025
eafd9cb
fixes
jlamypoirier Oct 17, 2025
c56df69
cleanup
jlamypoirier Oct 17, 2025
7f437e1
misc
jlamypoirier Oct 17, 2025
dfd27f5
misc
jlamypoirier Oct 17, 2025
90cd009
Memmap dataset
jlamypoirier Oct 18, 2025
acfd30e
fixes
jlamypoirier Oct 29, 2025
34939e9
fixes
jlamypoirier Oct 29, 2025
c5fa072
int64
jlamypoirier Oct 29, 2025
cd28676
Test and fix preparator
jlamypoirier Nov 5, 2025
435d214
fix
jlamypoirier Nov 5, 2025
f6bef55
fix
jlamypoirier Nov 6, 2025
e05d9a1
fix
jlamypoirier Nov 6, 2025
9ba8d1b
fix
jlamypoirier Nov 6, 2025
b35b297
fixes
jlamypoirier Nov 6, 2025
abe2357
misc
jlamypoirier Nov 11, 2025
1801d87
fix
jlamypoirier Nov 11, 2025
2223b85
fix right stage mode
bigximik Nov 13, 2025
a9a4ace
newer transformers fixes
bigximik Nov 13, 2025
97f2b60
fix distributed tests skip on single gpu
bigximik Nov 13, 2025
0fdc978
set mamba 2 style model conversions to broken
bigximik Nov 13, 2025
665deb5
Merge branch 'jlp/dataset_interface' of github.com:ServiceNow/Fast-LL…
bigximik Nov 17, 2025
4d03889
Merge branch 'jlp/lm_sample' of github.com:ServiceNow/Fast-LLM into d…
bigximik Nov 17, 2025
224c2ec
mamba2 enable conversion tests
bigximik Nov 17, 2025
f1afbf2
Merge branch 'jlp/memmap_dataset' of github.com:ServiceNow/Fast-LLM i…
bigximik Nov 17, 2025
00bba27
added model_and_sequence_data_group
bigximik Nov 23, 2025
5b20276
added Iterable dataset base classes
bigximik Nov 23, 2025
978a68f
added naive sampled iterable dataset
bigximik Nov 23, 2025
066a0bf
added iterable dataset configs, streaming dataset and PipelineRL samp…
bigximik Nov 23, 2025
68b3d65
added distributed data loader wrapper
bigximik Nov 23, 2025
2fbfe99
added iterable dataset to gpt data
bigximik Nov 23, 2025
0892523
appended comment
bigximik Nov 23, 2025
54fadb4
changed base classes for iterable dataset configs
bigximik Nov 24, 2025
4e11bf3
fix batch type
bigximik Nov 24, 2025
8428df8
fix added name property to the class
bigximik Nov 24, 2025
04ee4d7
add eof for tests
bigximik Nov 24, 2025
1217998
change base class to torch iterable
bigximik Nov 24, 2025
c542dac
added streaming dataset, sampling and base data tests
bigximik Nov 24, 2025
3999a8e
merge from main
bigximik Nov 24, 2025
c6ef780
merge from main
bigximik Nov 24, 2025
a1556f8
change import
bigximik Nov 24, 2025
63737b1
fix iterable sampler for spawn, add fake redis server to multi proces…
bigximik Nov 25, 2025
e843c8e
preparation for multi gpu tests
bigximik Nov 25, 2025
d5ce3f2
added multi gpu gptdata streaming test
bigximik Nov 26, 2025
c13c6df
added streaming dataset requirements
bigximik Nov 27, 2025
e6d8f49
added streaming dataset installation to tests
bigximik Nov 27, 2025
1e92dd4
removed checking for max samples
bigximik Nov 27, 2025
3ac4882
removed test eof, reduced timeout
bigximik Nov 27, 2025
46db991
changed tests to work without eof or max_samples_count
bigximik Nov 27, 2025
187055b
fix qwen2 converter to accept qkv biases properly
bigximik Nov 28, 2025
21833a0
fix import errors
rafapi Dec 4, 2025
2f5f848
changes to config
bigximik Dec 8, 2025
1e07bad
Merge branch 'denis/new_datasets' of github.com:ServiceNow/Fast-LLM i…
bigximik Dec 8, 2025
c8cb9fd
added tensor iterator
bigximik Dec 10, 2025
e367998
added trainer events
bigximik Dec 10, 2025
5230b74
update test for changed config
bigximik Dec 10, 2025
1a94de5
added 2 gpus trainer events test
bigximik Dec 10, 2025
6cfd445
fix for multiple gpus
bigximik Dec 10, 2025
333665d
updated test to multiple gpus
bigximik Dec 10, 2025
5d1f474
added not implemented for pp streaming
bigximik Dec 12, 2025
5f7cb29
removed PipelineRL sample and batch
bigximik Dec 12, 2025
d07a900
base redis and streaming dataset config class refactoring
bigximik Dec 12, 2025
3a7ba92
refactoring of redis config, trainer event config, corresponding tests
bigximik Dec 12, 2025
59f6f7d
removed eof message which is not supported
bigximik Dec 12, 2025
2c20ebd
added implementation for initial_weights_step_message_type event
bigximik Dec 12, 2025
f4107c3
removed explicit msg ack
bigximik Dec 16, 2025
c32ef89
fix of training finished event
bigximik Dec 16, 2025
f637649
alternative streaming implementations: one stream and n streams witho…
bigximik Dec 16, 2025
e43ce95
Merge remote-tracking branch 'origin/main' into denis/new_datasets
jlamypoirier Dec 16, 2025
5545598
merge from main
bigximik Dec 16, 2025
0d198ff
fix after merge added preprocessing empty configs
bigximik Dec 16, 2025
70ef5c4
fix for tests with no import
bigximik Dec 16, 2025
058c93c
fixes
jlamypoirier Dec 16, 2025
d34d39a
Merge remote-tracking branch 'origin/denis/new_datasets' into denis/n…
jlamypoirier Dec 16, 2025
ffb0a5f
Merge remote-tracking branch 'origin/main' into denis/new_datasets
jlamypoirier Dec 16, 2025
359231f
removed cloudpickle
bigximik Dec 16, 2025
ca9e94e
Simplify distributed
jlamypoirier Dec 16, 2025
ddd841d
Merge remote-tracking branch 'origin/main' into denis/new_datasets
jlamypoirier Dec 22, 2025
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
@@ -32,7 +32,7 @@ jobs:
pip install pybind11
FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \
MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \
pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,GENERATION,DEV,DOCS,VISION]"
pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,STREAMING,DEV,DOCS]"
- name: Run tests
run: pytest -v -ra .

2 changes: 1 addition & 1 deletion .github/workflows/docs.yaml
@@ -34,7 +34,7 @@ jobs:
pip install pybind11
FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \
MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \
pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,GENERATION,DEV,DOCS,VISION]"
pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV,DOCS]"
- name: Build the documentation
run: mkdocs build

2 changes: 1 addition & 1 deletion Dockerfile
@@ -39,7 +39,7 @@ COPY --chmod=777 ./fast_llm/__init__.py fast_llm/
COPY --chmod=777 ./fast_llm/csrc/ fast_llm/csrc/

# Install dependencies within the virtual environment.
RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV]" triton==3.5.1
RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,STREAMING,DEV]" triton==3.5.1

# Copy the remaining source code with universal write permissions.
COPY --chmod=777 ./Megatron-LM Megatron-LM
52 changes: 52 additions & 0 deletions fast_llm/data/data/data_loader_wrapper.py
@@ -0,0 +1,52 @@
import torch.distributed
import torch.utils.data.dataloader

from fast_llm.core.distributed import broadcast_object


class DistributedDataLoaderWrapper:
"""
Wraps a regular dataloader so that only the process group leader
loads data, and then broadcasts the batch to other ranks in the group.
"""

def __init__(
self,
dataloader: torch.utils.data.dataloader.DataLoader | None,
rank: int,
process_group: torch.distributed.ProcessGroup | None,
):
self.dataloader = dataloader
self.rank = rank
self.process_group = process_group

assert (self.rank == 0 and self.dataloader is not None) or (self.rank > 0 and self.dataloader is None)

def __iter__(self):
if self.rank == 0:
self.iterator = iter(self.dataloader)
if self.process_group is None:
return self.iterator
return self

def __next__(self):
# TODO:
# Instead of broadcasting a general object, make this iterator yield an actual Batch class.
# Implement `get_state_dict` and `from_state_dict` in the Batch class so that we can
# efficiently broadcast tensors directly. This avoids using `broadcast_object` on the
# entire Batch object, which is inefficient for tensors because it serializes
# (pickles) them before sending.

if self.rank == 0:
try:
data = next(self.iterator) # may raise StopIteration
except Exception as e:
data = e
data = broadcast_object(data, self.process_group, 0)
else:
data = broadcast_object(None, self.process_group, 0)

if isinstance(data, Exception):
raise data

return data
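
A minimal single-process usage sketch of the wrapper (the toy dataset and the degenerate `process_group=None` case are assumptions for illustration; during training, non-leader ranks construct the wrapper with `dataloader=None` and receive each batch via `broadcast_object`):

```python
import torch
import torch.utils.data

from fast_llm.data.data.data_loader_wrapper import DistributedDataLoaderWrapper


class ToyDataset(torch.utils.data.Dataset):
    """Stand-in for a real Fast-LLM sampled dataset."""

    def __init__(self, size: int) -> None:
        self._data = torch.arange(size)

    def __len__(self) -> int:
        return len(self._data)

    def __getitem__(self, index: int) -> torch.Tensor:
        return self._data[index]


# The group leader (rank 0) owns the real DataLoader; with no process group the
# wrapper simply forwards the underlying iterator.
loader = torch.utils.data.DataLoader(ToyDataset(8), batch_size=4)
wrapped = DistributedDataLoaderWrapper(loader, rank=0, process_group=None)

for batch in wrapped:
    print(batch)  # two tensors of shape (4,)
```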
42 changes: 37 additions & 5 deletions fast_llm/data/data/gpt/data.py
@@ -8,8 +8,10 @@

from fast_llm.core.distributed import safe_barrier
from fast_llm.data.data.abstract import Data
from fast_llm.data.data.data_loader_wrapper import DistributedDataLoaderWrapper
from fast_llm.data.data.gpt.config import GPTDataConfig
from fast_llm.data.dataset.abstract import SampledDataset
from fast_llm.data.dataset.abstract_iterable import SampledIterableDataset
from fast_llm.data.dataset.config import SamplingParameters
from fast_llm.data.dataset.gpt.config import GPTSamplingData
from fast_llm.data.dataset.monitor import DatasetMonitor
@@ -90,7 +92,12 @@ def setup(
dataset_name=dataset_name,
)
dataset = self._config.datasets[dataset_name].build_and_sample(sampling)
self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms)
if isinstance(dataset, SampledDataset):
self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms)
else:
                # Do not set a monitor for iterable datasets, as DatasetMonitor only works with map-style datasets
assert isinstance(dataset, SampledIterableDataset)
self._datasets[dataset_name] = dataset

safe_barrier(self._distributed.world_group, "data_preparation", timeout)
self._is_setup = True
@@ -116,9 +123,11 @@ def get_iterator(
Assert.in_range_incl(batch_config.sequence_length, 1, sampling_parameters.sequence_length)
log_main_rank(f"Initializing {dataset_name} dataset iterator from sample {consumed_samples}...")

return iter(
torch.utils.data.DataLoader(
self._datasets[dataset_name], # noqa
dataset = self._datasets[dataset_name]

if isinstance(dataset, SampledDataset):
data_loader = torch.utils.data.DataLoader(
dataset, # noqa
batch_sampler=SampledDatasetIterator(
total_samples=len(self._datasets[dataset_name]),
begin_index=consumed_samples,
@@ -132,4 +141,27 @@ def get_iterator(
collate_fn=LanguageModelBatch.from_samples,
multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None,
)
)

elif isinstance(dataset, SampledIterableDataset):
if (
self.distributed.model_and_sequence_data_group is None
or self.distributed.model_and_sequence_data_group.rank() == 0
):
rank = 0
data_loader = torch.utils.data.DataLoader(
dataset, # noqa
batch_size=batch_config.micro_batch_size,
num_workers=0 if num_workers == 0 else 1,
prefetch_factor=prefetch_factor,
pin_memory=True,
collate_fn=LanguageModelBatch.from_samples,
multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None,
)
else:
rank = self.distributed.model_and_sequence_data_group.rank()
data_loader = None
data_loader = DistributedDataLoaderWrapper(
data_loader, rank, self.distributed.model_and_sequence_data_group
)

return iter(data_loader)
1 change: 0 additions & 1 deletion fast_llm/data/dataset/abstract.py
@@ -44,7 +44,6 @@ def __len__(self) -> int:


class SamplableDataset[SampleType: Sample](Dataset[SampleType]):

@abc.abstractmethod
def sample(self, config: "SamplingData") -> SampledDataset[SampleType]:
pass
30 changes: 30 additions & 0 deletions fast_llm/data/dataset/abstract_iterable.py
@@ -0,0 +1,30 @@
import abc
import typing

import torch.utils.data

from fast_llm.data.sample.abstract import Sample

if typing.TYPE_CHECKING:
from fast_llm.data.dataset.config import SamplingData


# NOTE: We need to inherit from IterableDataset, otherwise the torch data loader cannot detect it properly
class SampledIterableDataset[SampleType: Sample](torch.utils.data.IterableDataset[SampleType]):
"""
A sampled dataset class that provides an iterator over samples.
"""

    # NOTE: We add `name` here so the class stays compatible with the Fast-LLM Dataset interface
@property
@abc.abstractmethod
def name(self) -> str:
"""
A name for the dataset to facilitate identification and debugging.
"""


class SamplableIterableDataset[SampleType: Sample](SampledIterableDataset[SampleType]):
@abc.abstractmethod
def sample(self, config: "SamplingData") -> SampledIterableDataset[SampleType]:
pass
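
For orientation, a hypothetical minimal concrete subclass of these base classes (the name `InMemoryIterableDataset` and its behaviour are invented for illustration; only the imported base classes come from this PR). It yields pre-built samples and returns itself from `sample()`, i.e. it applies no further sampling:

```python
import typing

from fast_llm.data.dataset.abstract_iterable import SamplableIterableDataset
from fast_llm.data.sample.abstract import Sample


class InMemoryIterableDataset[SampleType: Sample](SamplableIterableDataset[SampleType]):
    """Toy samplable iterable dataset backed by an in-memory list of samples."""

    def __init__(self, samples: list[SampleType], name: str = "in_memory"):
        self._samples = samples
        self._name = name

    def __iter__(self) -> typing.Iterator[SampleType]:
        yield from self._samples

    def sample(self, config) -> "InMemoryIterableDataset[SampleType]":
        # A real implementation would honor the sampling configuration, e.g. by
        # wrapping itself in a sampled iterable dataset; this sketch returns the
        # raw stream unchanged.
        return self

    @property
    def name(self) -> str:
        return self._name
```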
106 changes: 101 additions & 5 deletions fast_llm/data/dataset/config.py
@@ -7,13 +7,15 @@
import pathlib
import typing

from fast_llm.config import Config, Field, FieldHint, UpdateType, check_field, config_class
from fast_llm.config import Config, Field, FieldHint, FieldUpdate, UpdateType, check_field, config_class
from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset
from fast_llm.data.preprocessing.abstract import PreprocessingConfig
from fast_llm.data.sample.abstract import Sample
from fast_llm.redis.config import RedisConfig
from fast_llm.utils import Assert, normalize_probabilities

if typing.TYPE_CHECKING:
from fast_llm.data.dataset.abstract_iterable import SamplableIterableDataset, SampledIterableDataset
from fast_llm.data.dataset.indexed import ConcatenatedDataset, DatasetSlice, IndexedDataset
from fast_llm.engine.distributed.distributed import Distributed

@@ -106,19 +108,25 @@ class DatasetConfig[SampleType: Sample](Config):
@config_class(registry=True)
class SampledDatasetConfig[SampleType: Sample](DatasetConfig[SampleType]):
"""
A sampled dataset containing a prepared list of samples to be indexed sequentially (as-is) during training.
A sampled dataset containing a prepared list or iterable of samples to be indexed sequentially (as-is) during training.
"""

def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]:
def build_and_sample(
self, sampling: SamplingData
) -> "SampledDataset[SampleType] | SampledIterableDataset[SampleType]":
raise NotImplementedError()


@config_class()
class SamplableDatasetConfig[SampleType: Sample](SampledDatasetConfig[SampleType]):
def build(self, preprocessing: PreprocessingConfig) -> SamplableDataset[SampleType]:
def build(
self, preprocessing: PreprocessingConfig
) -> "SamplableDataset[SampleType] | SamplableIterableDataset[SampleType]":
raise NotImplementedError()

def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]:
def build_and_sample(
self, sampling: SamplingData
) -> "SampledDataset[SampleType] | SampledIterableDataset[SampleType]":
return self.build(sampling.preprocessing).sample(sampling)


@@ -298,3 +306,91 @@ def build(self, preprocessing: PreprocessingConfig) -> "IndexedDataset[SampleTyp
return LegacyMemmapDataset[SampleType](name, self.path, preprocessing)
else:
raise FileNotFoundError(self.path)


@config_class()
class StreamingDatasetRedisConfig(RedisConfig):
stream_key: str = FieldUpdate(default="fast_llm_streaming")

payload_key: str = FieldUpdate(
default="data",
)


class IngestionType(str, enum.Enum):
CONSUMER_GROUP = "consumer_group"
ONE_STREAM = "one_stream"
N_STREAMS = "n_streams"


class HashType(str, enum.Enum):
MESSAGE_INDEX = "message_index"
"""Use the index of the received message for hashing. Provides precise distribution but may not be well shuffled."""

MESSAGE_ID = "message_id"
"""Hash messages based on their unique message ID. Good for probabilistic distribution.
Redis message IDs are regenerated each time, so this is not reproducible.
"""

MESSAGE_BODY = "message_body"
"""Hash messages based on their payload content (bytes). Distributes messages roughly evenly.
Deterministic based on message content, but not perfectly balanced across ranks.
"""

PRODUCER_PROVIDED = "producer_provided"
"""Use the hash or index provided by the producer. Allows deterministic splitting and perfect balance."""


@config_class(dynamic_type={SampledDatasetConfig: "streaming"})
class StreamingDatasetConfig[SampleType: LanguageModelSample](SamplableDatasetConfig[SampleType]):
"""
Configuration for a streaming dataset that reads training data from a Redis stream.
"""

_abstract = False

redis: StreamingDatasetRedisConfig = Field(
desc="Redis connection and stream settings used to fetch incoming training data.",
hint=FieldHint.core,
)

group_name: str = Field(
default="fast_llm_dp_group",
desc="Name of the Redis consumer group used for data-parallel reading.",
hint=FieldHint.core,
)

consumer_name_prefix: str = Field(
default="fast_llm_dp_group_consumer",
desc="Prefix used to generate unique consumer names for each rank in Redis consumer group.",
hint=FieldHint.core,
)

ingestion_type: IngestionType = Field(
default=IngestionType.CONSUMER_GROUP,
desc="Strategy used to ingest data from Redis streams (consumer group, single stream, or multiple streams).",
hint=FieldHint.core,
)

hash_type: HashType = Field(
default=HashType.MESSAGE_ID,
desc="How to compute hash for assigning messages to ranks.",
hint=FieldHint.core,
)

hash_key: str = Field(
default="hash",
desc="Key in the message dict containing the hash or index provided by the producer.",
hint=FieldHint.core,
)

ack_period_per_consumer: int = Field(
default=10,
desc="Number of messages after which the consumer acknowledges received IDs back to the Redis hash.",
hint=FieldHint.core,
)

def build_and_sample(self, sampling: SamplingData) -> "SampledIterableDataset[SampleType]":
from fast_llm.data.dataset.streaming import StreamingDataset

return StreamingDataset[SampleType](self, sampling.distributed).sample(sampling)
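
To make the field semantics concrete, an illustrative producer-side sketch (not part of this PR) that pushes messages onto the stream this config points at, using the defaults above (`stream_key="fast_llm_streaming"`, `payload_key="data"`, `hash_key="hash"`); the payload encoding is an assumption and must match whatever `StreamingDataset` expects to deserialize:

```python
import redis  # redis-py client

client = redis.Redis(host="localhost", port=6379)

for index, payload in enumerate([b"sample-0", b"sample-1", b"sample-2"]):
    client.xadd(
        "fast_llm_streaming",  # stream_key
        {
            "data": payload,     # payload_key: serialized sample bytes (format assumed)
            "hash": str(index),  # hash_key: used when hash_type=producer_provided
        },
    )
```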
53 changes: 53 additions & 0 deletions fast_llm/data/dataset/sampled.py
@@ -9,6 +9,7 @@
import yaml

from fast_llm.data.dataset.abstract import SampledDataset
from fast_llm.data.dataset.abstract_iterable import SamplableIterableDataset, SampledIterableDataset
from fast_llm.data.dataset.config import SamplingData, ShufflingType
from fast_llm.data.dataset.indexed import IndexedDataset
from fast_llm.data.sample.abstract import Sample
@@ -429,3 +430,55 @@ def _load_yaml_data(self, data: dict[str, typing.Any]) -> None:

self._unshuffled_tokens = data["unshuffled_tokens"]
self._unshuffled_documents = data["unshuffled_epochs"] * self._documents_per_epoch


class NaiveSampledIterableDataset[SampleType: Sample](SampledIterableDataset[SampleType]):
def __init__(
self,
iterable_dataset: SamplableIterableDataset[SampleType],
sampling: SamplingData,
):
self._dataset = iterable_dataset
self._config = sampling.config
self._parameters = sampling.parameters

assert self._parameters.truncate_documents == False
assert self._config.shuffle == ShufflingType.disabled

def __iter__(self) -> typing.Iterator[SampleType]:
sample_length = self._parameters.sequence_length + self._parameters.extra_tokens
current_sample_length = 0
documents: list[SampleType] = []
for doc in self._dataset:
if len(doc) > sample_length:
                logging.warning(f"Dropping doc with length {len(doc)} longer than sample_length {sample_length}")
continue
if current_sample_length + len(doc) > sample_length:
padding_length = sample_length - current_sample_length
assert padding_length > 0
documents.append(documents[-1].get_padding(padding_length))

yield documents[0].from_documents(documents)

documents = [doc]
current_sample_length = len(doc)
else:
documents.append(doc)
current_sample_length += len(doc)

if current_sample_length == sample_length:
yield documents[0].from_documents(documents)

documents = []
current_sample_length = 0

if current_sample_length > 0:
padding_length = sample_length - current_sample_length
assert padding_length > 0
documents.append(documents[-1].get_padding(padding_length))

yield documents[0].from_documents(documents)

@property
def name(self) -> str:
return self._dataset.name
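
The packing policy above, summarized on document lengths alone: documents longer than the target sample length are dropped, a document that would overflow the current sample triggers padding and starts a new sample, and a trailing partial sample is padded. A small standalone sketch of that arithmetic (illustrative only; the real code operates on `Sample` objects and pads via `get_padding`):

```python
def pack_lengths(doc_lengths: list[int], sample_length: int) -> list[list[int]]:
    """Return packed samples as lists of lengths; a final entry may be padding."""
    samples: list[list[int]] = []
    current: list[int] = []
    total = 0
    for length in doc_lengths:
        if length > sample_length:
            continue  # dropped with a warning in NaiveSampledIterableDataset
        if total + length > sample_length:
            current.append(sample_length - total)  # pad out the current sample
            samples.append(current)
            current, total = [length], length
        else:
            current.append(length)
            total += length
        if total == sample_length:
            samples.append(current)
            current, total = [], 0
    if total > 0:
        current.append(sample_length - total)  # pad the final partial sample
        samples.append(current)
    return samples


# pack_lengths([3, 4, 6, 9, 2], sample_length=8) -> [[3, 4, 1], [6, 2]]
# (the 9-token document is dropped; the trailing 1 is padding)
```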