20 changes: 10 additions & 10 deletions .github/ISSUE_TEMPLATE/feature_request.md
@@ -8,26 +8,26 @@ assignees: ''
---

# 🎯 **Goal (What & Why)**
> **Clearly state the purpose of this feature.**
> _(Example: Add FP8 support using torchao to improve training throughput by 1.5x.)_

# 🚀 **Execution Plan**
> _(This section may start as an incomplete draft but must be defined before implementation begins.)_

### **Step 1: What is the smallest working version?**
> _(Describe the simplest way to implement this feature with minimal effort.)_

### **Step 2: What additional optimizations are possible (but optional)?**
> _(List potential refinements that can be added in later PRs if needed.)_

# 📌 **Acceptance Criteria** (Must-Haves for Completion)
* The feature must be **functional and tested**.
* The implementation must be **documented in practical terms**.
* The PR must include a **performance/impact summary**.
* **No refactors unless directly necessary** for feature completion.

# 🛠️ **Project Management**
- [ ] **Assign the project to the Fast-LLM project.**
- [ ] **Set the `Estimate` field (in days) in the GitHub project.**
- [ ] **Use the `Size` field to categorize the PR size (Small/Medium/Large).**
- [ ] **Assign an owner when opening the issue.**
37 changes: 27 additions & 10 deletions fast_llm/core/distributed.py
@@ -29,10 +29,15 @@
logger = logging.getLogger(__name__)


def add_ephemeral_timeout(group: ProcessGroup, timeout: float | None = None) -> None:
@contextlib.contextmanager
def set_timeout(group: ProcessGroup | None, timeout: float | None = None):
if group is not None and timeout is not None:
# TODO: Only works for nccl?
group._add_ephemeral_timeout(datetime.timedelta(seconds=timeout))
timeout_ = group.options._timeout
group.set_timeout(datetime.timedelta(seconds=timeout))
yield
group.set_timeout(timeout_)
else:
yield


def broadcast(
@@ -43,8 +48,8 @@ def broadcast(
opts = torch.distributed.BroadcastOptions()
opts.rootRank = src
opts.rootTensor = 0
add_ephemeral_timeout(group, timeout)
work = group.broadcast([tensor], opts)
with set_timeout(group, timeout):
work = group.broadcast([tensor], opts)
if async_op:
return work
else:
@@ -55,7 +60,7 @@
def check_parallel_match(tensor: torch.Tensor, group: ProcessGroup | None, name: str) -> None:
# A utility function to check for tensor-parallel (or other) mismatches.
all_tensors = tensor.new_empty((group.size(),) + tensor.shape)
all_gather_into_tensor(all_tensors, tensor, group)
all_gather_into_tensor(all_tensors, tensor.unsqueeze(0), group)

mismatches = (all_tensors != tensor).any(dim=0)
num_mismatches = mismatches.sum().item()
@@ -84,8 +89,8 @@ def allreduce_scalar(
) -> float | int:
if group:
value = torch.full([1], value, dtype=dtype, device=torch.cuda.current_device())
add_ephemeral_timeout(group, timeout)
torch.distributed.all_reduce(value, op=op, group=group)
with set_timeout(group, timeout):
torch.distributed.all_reduce(value, op=op, group=group)
return value.item()
else:
return value
@@ -99,9 +104,9 @@ def all_gather_scalar(
):
if group:
value = torch.full([1], value, dtype=dtype, device=torch.cuda.current_device())
add_ephemeral_timeout(group, timeout)
output_tensor = value.new_empty((group.size(),))
torch.distributed.all_gather_into_tensor(output_tensor, value, group=group)
with set_timeout(group, timeout):
torch.distributed.all_gather_into_tensor(output_tensor, value, group=group)
return output_tensor.tolist()
else:
return value
@@ -147,6 +152,12 @@ def broadcast_object(input_object: typing.Any | None, group: ProcessGroup | None

def send(tensor: torch.Tensor, dst: int, group: ProcessGroup, async_op=False, tag: int = 0) -> Work | None:
assert group is not None
if isinstance(group, torch.distributed.ProcessGroupGloo) and tensor.device.type != "cpu":
# send not supported for gloo on GPU.
tensor_cpu = tensor.cpu()
group.send([tensor_cpu], dst, tag).wait()
tensor.copy_(tensor_cpu)
return None
work = group.send([tensor], dst, tag)
if async_op:
return work
@@ -157,6 +168,12 @@ def recv(tensor: torch.Tensor, src: int, group: ProcessGroup, async_op=False, ta

def recv(tensor: torch.Tensor, src: int, group: ProcessGroup, async_op=False, tag: int = 0) -> Work | None:
assert group is not None
if isinstance(group, torch.distributed.ProcessGroupGloo) and tensor.device.type != "cpu":
# recv not supported for gloo on GPU.
tensor_cpu = tensor.cpu()
group.recv([tensor_cpu], src, tag).wait()
tensor.copy_(tensor_cpu)
return None
work = group.recv([tensor], src, tag)
if async_op:
return work
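For illustration, a minimal sketch of the context-manager pattern introduced above (not part of the diff). It assumes the same `group.options._timeout` and `group.set_timeout` process-group API that the PR relies on, and adds a `try`/`finally` so the previous timeout is restored even if the wrapped collective raises:

```python
import contextlib
import datetime

from torch.distributed import ProcessGroup


@contextlib.contextmanager
def set_timeout_sketch(group: ProcessGroup | None, timeout: float | None = None):
    """Temporarily override a process group's collective timeout, then restore it."""
    if group is None or timeout is None:
        yield
        return
    previous = group.options._timeout  # the same private field the diff reads
    group.set_timeout(datetime.timedelta(seconds=timeout))
    try:
        yield
    finally:
        # Restore the original timeout even if the collective inside the block raises.
        group.set_timeout(previous)


# Intended usage mirrors the diff, e.g. in `broadcast`:
#     with set_timeout(group, timeout):
#         work = group.broadcast([tensor], opts)
```

The gloo-specific branches added to `send` and `recv` follow a similar defensive style: since gloo point-to-point ops do not accept GPU tensors, the tensor is staged through a CPU copy, the operation is waited on synchronously, and `None` is returned instead of a `Work` handle.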
72 changes: 72 additions & 0 deletions fast_llm/data/data/data_loader.py
@@ -0,0 +1,72 @@
import itertools
import typing

import torch.utils.data

from fast_llm.core.distributed import broadcast_object


class SampledDatasetIterator(torch.utils.data.Sampler):
"""
A distributed sampler generating indices (i.e., the natural numbers) for a `SampledDataset`.
To be used as the `batch_sampler` of a `torch.utils.data.DataLoader`.
"""

def __init__(self, total_samples, begin_index, micro_batch_size, data_rank, data_parallel):
super().__init__()
self._total_samples = total_samples
self._begin_index = begin_index
self._batch_size = micro_batch_size * data_parallel
self._start_idx = data_rank * micro_batch_size
self._end_idx = (data_rank + 1) * micro_batch_size

def __len__(self) -> int:
return self._total_samples

def __iter__(self) -> typing.Iterator[list[int]]:
for idx in range(self._begin_index, self._total_samples - self._batch_size + 1, self._batch_size):
yield list(range(idx + self._start_idx, idx + self._end_idx))


class DistributedDataLoaderWrapper:
"""
Wraps a regular dataloader so that only the process group leader
loads data, and then broadcasts the batch to other ranks in the group.
"""

def __init__(
self,
data_loader: torch.utils.data.dataloader.DataLoader,
process_group: torch.distributed.ProcessGroup | None,
):
self._data_loader = data_loader
self._rank = 0 if process_group is None else process_group.rank()
self._process_group = process_group

def __iter__(self):
if self._rank == 0:
self._iterator = iter(self._data_loader)
else:
self._iterator = itertools.repeat(None)
if self._process_group is None:
return self._iterator
return self

def __next__(self):
# TODO:
# Instead of broadcasting a general object, make this iterator yield an actual Batch class.
# Implement `get_state_dict` and `from_state_dict` in the Batch class so that we can
# efficiently broadcast tensors directly. This avoids using `broadcast_object` on the
# entire Batch object, which is inefficient for tensors because it serializes
# (pickles) them before sending.

try:
data = next(self._iterator) # may raise StopIteration
except Exception as e:
data = e
data = broadcast_object(data, self._process_group, 0)

if isinstance(data, Exception):
raise data

return data
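For illustration, a small usage sketch of how the two classes above compose (the toy dataset, the sizes, and the process-group handle are made up for this example):

```python
import torch
import torch.utils.data


class ToyDataset(torch.utils.data.Dataset):
    """Stand-in for a SampledDataset: item i is just the tensor [i]."""

    def __len__(self) -> int:
        return 64

    def __getitem__(self, index: int) -> torch.Tensor:
        return torch.tensor([index])


# Each batch-data-parallel rank gets a disjoint micro-batch slice of every global batch.
sampler = SampledDatasetIterator(
    total_samples=64,
    begin_index=0,       # resume point, e.g. the number of samples already consumed
    micro_batch_size=4,
    data_rank=1,         # this rank's position in the batch-data-parallel group
    data_parallel=2,     # number of batch-data-parallel ranks
)
loader = torch.utils.data.DataLoader(ToyDataset(), batch_sampler=sampler)

# The first global batch covers indices 0..7; rank 1 receives [4, 5, 6, 7].
print(next(iter(loader)).flatten().tolist())  # -> [4, 5, 6, 7]

# When only the group leader can load data, the loader is wrapped so that rank 0
# loads each batch and the remaining ranks receive it via broadcast_object:
# loader = DistributedDataLoaderWrapper(loader, process_group)
```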
52 changes: 0 additions & 52 deletions fast_llm/data/data/data_loader_wrapper.py

This file was deleted.

69 changes: 20 additions & 49 deletions fast_llm/data/data/gpt/data.py
@@ -8,14 +8,12 @@

from fast_llm.core.distributed import safe_barrier
from fast_llm.data.data.abstract import Data
from fast_llm.data.data.data_loader_wrapper import DistributedDataLoaderWrapper
from fast_llm.data.data.data_loader import DistributedDataLoaderWrapper, SampledDatasetIterator
from fast_llm.data.data.gpt.config import GPTDataConfig
from fast_llm.data.dataset.abstract import SampledDataset
from fast_llm.data.dataset.abstract_iterable import SampledIterableDataset
from fast_llm.data.dataset.config import SamplingParameters
from fast_llm.data.dataset.gpt.config import GPTSamplingData
from fast_llm.data.dataset.monitor import DatasetMonitor
from fast_llm.data.iterator import SampledDatasetIterator
from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig
from fast_llm.data.sample.language_model import LanguageModelBatch
from fast_llm.engine.config_utils.run import log_main_rank
@@ -92,12 +90,7 @@ def setup(
dataset_name=dataset_name,
)
dataset = self._config.datasets[dataset_name].build_and_sample(sampling)
if isinstance(dataset, SampledDataset):
self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms)
else:
# Do not set monitor for iterable dataset as monitor only works with map style datasets
assert isinstance(dataset, SampledIterableDataset)
self._datasets[dataset_name] = dataset
self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms)

safe_barrier(self._distributed.world_group, "data_preparation", timeout)
self._is_setup = True
@@ -123,45 +116,23 @@ def get_iterator(
Assert.in_range_incl(batch_config.sequence_length, 1, sampling_parameters.sequence_length)
log_main_rank(f"Initializing {dataset_name} dataset iterator from sample {consumed_samples}...")

dataset = self._datasets[dataset_name]

if isinstance(dataset, SampledDataset):
data_loader = torch.utils.data.DataLoader(
dataset, # noqa
batch_sampler=SampledDatasetIterator(
total_samples=len(self._datasets[dataset_name]),
begin_index=consumed_samples,
micro_batch_size=batch_config.micro_batch_size,
data_rank=self._distributed.config.batch_data_rank,
data_parallel=self._distributed.config.batch_data_parallel,
),
num_workers=num_workers,
prefetch_factor=prefetch_factor,
pin_memory=True,
collate_fn=LanguageModelBatch.from_samples,
multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None,
)

elif isinstance(dataset, SampledIterableDataset):
if (
self.distributed.model_and_sequence_data_group is None
or self.distributed.model_and_sequence_data_group.rank() == 0
):
rank = 0
data_loader = torch.utils.data.DataLoader(
dataset, # noqa
batch_size=batch_config.micro_batch_size,
num_workers=0 if num_workers == 0 else 1,
prefetch_factor=prefetch_factor,
pin_memory=True,
collate_fn=LanguageModelBatch.from_samples,
multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None,
)
else:
rank = self.distributed.model_and_sequence_data_group.rank()
data_loader = None
data_loader = DistributedDataLoaderWrapper(
data_loader, rank, self.distributed.model_and_sequence_data_group
)
data_loader = torch.utils.data.DataLoader(
self._datasets[dataset_name], # noqa
batch_sampler=SampledDatasetIterator(
total_samples=len(self._datasets[dataset_name]),
begin_index=consumed_samples,
micro_batch_size=batch_config.micro_batch_size,
data_rank=self._distributed.config.batch_data_rank,
data_parallel=self._distributed.config.batch_data_parallel,
),
num_workers=num_workers,
prefetch_factor=prefetch_factor,
pin_memory=True,
collate_fn=LanguageModelBatch.from_samples,
multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None,
)

if self._datasets[dataset_name].requires_broadcast:
data_loader = DistributedDataLoaderWrapper(data_loader, self.distributed.model_and_sequence_data_group)

return iter(data_loader)
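Since `setup()` now always wraps the sampled dataset in a `DatasetMonitor`, the `requires_broadcast` check above presumably reaches the underlying dataset through the monitor. A hypothetical sketch of that forwarding (the class below is illustrative only and is not the actual `DatasetMonitor` implementation):

```python
class MonitorForwardingSketch:
    """Hypothetical wrapper showing how a monitor can forward `requires_broadcast`."""

    def __init__(self, dataset, data_sample_warn_time_ms: int):
        self._dataset = dataset
        self._data_sample_warn_time_ms = data_sample_warn_time_ms

    @property
    def requires_broadcast(self) -> bool:
        # Delegate to the wrapped dataset so the check in `get_iterator` still works.
        return self._dataset.requires_broadcast

    def __len__(self) -> int:
        return len(self._dataset)

    def __getitem__(self, index):
        # The real monitor would also time this call and warn when it is slow.
        return self._dataset[index]
```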
21 changes: 21 additions & 0 deletions fast_llm/data/dataset/abstract.py
@@ -5,6 +5,7 @@

if typing.TYPE_CHECKING:
from fast_llm.data.dataset.config import SamplingData
from fast_llm.data.dataset.sampled import SampledIterableDataset


class Dataset[SampleType: Sample](abc.ABC):
@@ -27,6 +28,14 @@ def __getstate__(self):
del state["__orig_class__"]
return state

@property
def requires_broadcast(self) -> bool:
"""
Some dataset schemes load data only on the batch-data-parallel group leader,
then broadcast it to the other devices.
"""
return False


class SampledDataset[SampleType: Sample](Dataset[SampleType]):
"""
@@ -44,6 +53,18 @@ def __len__(self) -> int:


class SamplableDataset[SampleType: Sample](Dataset[SampleType]):

@abc.abstractmethod
def sample(self, config: "SamplingData") -> SampledDataset[SampleType]:
pass


class SamplableIterableDataset[SampleType: Sample](SamplableDataset[SampleType]):
@abc.abstractmethod
def __iter__(self) -> typing.Iterator[SampleType]:
pass

def sample(self, config: "SamplingData") -> "SampledIterableDataset[SampleType]":
from fast_llm.data.dataset.sampled import SampledIterableDataset

return SampledIterableDataset(self, config)
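To make the new hierarchy concrete, a hypothetical streaming dataset built on these abstractions (the class, its source iterable, and the `name` property are invented for illustration; a real subclass may need to satisfy additional abstract members of `Dataset`):

```python
import typing

from fast_llm.data.dataset.abstract import SamplableIterableDataset


class StreamingTextDataset(SamplableIterableDataset):
    """Hypothetical dataset whose samples can only be read on the group leader."""

    def __init__(self, name: str, stream: typing.Iterable):
        self._name = name
        self._stream = stream

    @property
    def name(self) -> str:
        return self._name

    @property
    def requires_broadcast(self) -> bool:
        # Only the batch-data-parallel group leader iterates the stream; the other
        # ranks receive batches through DistributedDataLoaderWrapper.
        return True

    def __iter__(self) -> typing.Iterator:
        yield from self._stream
```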