124 changes: 110 additions & 14 deletions fast_llm/data/preparator/gpt_memmap/config.py
@@ -15,30 +15,81 @@
from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator


@config_class(registry=True)
class LanguageModelSourceConfig(Config):
Contributor: Should loss_masking_spans be moved to TextSourceConfig, since it's not relevant for ConversationSourceConfig?

Contributor: And same question for chosen/rejected_span and images. Unless we plan to also support those in ConversationSourceConfig?

Collaborator (author): yeah, this is cleaner!

"""
A schema holding the name of each relevant column in the dataset.
Setting optional entries will enable the associated feature.
Abstract base class for data source schemas.

Use `type: document` (default) for documents with text, optional span annotations, and optional images.
Use `type: conversation` for structured chat/conversation datasets.
"""

@classmethod
def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self:
if cls is LanguageModelSourceConfig and cls.get_subclass(default.get("type")) is None:
# Default to DocumentSourceConfig when type is not specified
return DocumentSourceConfig._from_dict(default, strict)
return super()._from_dict(default, strict=strict)

@functools.cached_property
def columns(self) -> list[str]:
"""Columns to read from the dataset."""
raise NotImplementedError

@functools.cached_property
def has_loss_masking_span(self) -> bool:
return False

@functools.cached_property
def has_preference_spans(self) -> bool:
return False

@functools.cached_property
def has_images(self) -> bool:
return False


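For context, a minimal sketch of how the `type` key is expected to pick a schema subclass through the `_from_dict` hook above. This is illustrative only; the dict contents and the direct `_from_dict` call are assumptions made for the sketch, not part of this diff.

# Sketch only: the dispatch behaviour implied by the _from_dict override above.
doc_schema = LanguageModelSourceConfig._from_dict({"text": "content"})
assert isinstance(doc_schema, DocumentSourceConfig)  # no recognized "type" -> document default

chat_schema = LanguageModelSourceConfig._from_dict({"type": "conversation", "messages": "messages"})
assert isinstance(chat_schema, ConversationSourceConfig)  # resolved through the dynamic-type registry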
@config_class(dynamic_type={LanguageModelSourceConfig: "document"})
class DocumentSourceConfig(LanguageModelSourceConfig):
"""
Source schema for document datasets with text, optional span annotations, and optional images.

The dataset should have a text column containing the document text.
Optionally, it can have additional columns for:
- Loss masking spans: character ranges to mask from loss computation
- Preference spans: chosen/rejected text for DPO training
- Images: image data with character positions for multimodal training
"""

    text: str = Field(
        default="text",
        desc="Field containing the document text.",
        hint=FieldHint.optional,
    )
    loss_masking_spans: str | None = Field(
        default=None,
        desc="Field containing character spans to mask for loss computation.",
        hint=FieldHint.optional,
    )
    chosen_span: str | None = Field(
        default=None,
        desc="Field containing chosen text for preference optimization.",
        hint=FieldHint.optional,
    )
    rejected_span: str | None = Field(
        default=None,
        desc="Field containing rejected text for preference optimization.",
        hint=FieldHint.optional,
    )
    images: str | None = Field(
        default=None,
        desc="Field containing images.",
        hint=FieldHint.optional,
    )
    image_positions: str | None = Field(
        default=None,
        desc="Field containing image positions in the text.",
        hint=FieldHint.optional,
    )

@functools.cached_property
@@ -48,6 +99,8 @@ def columns(self) -> list[str]:
columns.append(self.loss_masking_spans)
if self.has_preference_spans:
columns.extend([self.chosen_span, self.rejected_span])
if self.has_images:
columns.extend([self.images, self.image_positions])
return columns

@functools.cached_property
@@ -67,7 +120,50 @@ def has_images(self) -> bool:
    def _validate(self):
        super()._validate()
        if self.has_preference_spans and self.has_loss_masking_span:
            raise ValueError("Cannot enable both loss masking and preference spans.")

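For illustration, a hypothetical dataset row matching the default document schema above. The column names are the field defaults; the span uses the (begin, last) character convention that the preparator converts to (begin, end) ranges.

# Hypothetical row, not taken from this diff; default column names of the document schema.
row = {
    "text": "Question: What is 2 + 2?\nAnswer: 4",
    "loss_masking_spans": [[0, 23]],  # (begin, last) characters covering the question text
}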

@config_class(dynamic_type={LanguageModelSourceConfig: "conversation"})
class ConversationSourceConfig(LanguageModelSourceConfig):
"""
Source schema for chat/conversation datasets (e.g., Tulu 3, ShareGPT, OpenAI format).

The dataset should have a messages column containing a list of message dicts,
where each message has 'role' and 'content' keys. Common roles include:
- 'system': System prompt
- 'user': User input
- 'assistant': Model response (trained on by default)
- 'tool': Tool/function results
- 'ipython': Code execution results

The conversation is formatted using the tokenizer's chat template, which must
contain {% generation %}...{% endgeneration %} markers to define which content
to train on. Loss masking spans are automatically computed from these markers.

Example dataset format:
{
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Hi there!"},
]
}
"""

messages: str = Field(
default="messages",
desc="Field containing the conversation messages list. Each message should have 'role' and 'content' keys.",
hint=FieldHint.core,
)

@functools.cached_property
def columns(self) -> list[str]:
return [self.messages]

@functools.cached_property
def has_loss_masking_span(self) -> bool:
# Conversation format always generates loss masking spans from chat template markers
return True

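Since this path relies on the tokenizer's chat template exposing generation markers, below is a deliberately minimal, hypothetical Jinja template showing what the markers look like. Real templates ship with the tokenizer and are more involved; this is not the template of any particular model.

# Illustrative only: a toy chat template with {% generation %} markers around assistant content.
TOY_CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "{% if message['role'] == 'assistant' %}"
    "<|assistant|>{% generation %}{{ message['content'] }}{% endgeneration %}\n"
    "{% else %}"
    "<|{{ message['role'] }}|>{{ message['content'] }}\n"
    "{% endif %}"
    "{% endfor %}"
)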

@config_class()
187 changes: 108 additions & 79 deletions fast_llm/data/preparator/gpt_memmap/prepare.py
@@ -28,7 +28,12 @@
)
from fast_llm.data.dataset.memmap import MemmapDataset
from fast_llm.data.preparator.config import DatasetPreparator
from fast_llm.data.preparator.gpt_memmap.config import (
    ConversationSourceConfig,
    DocumentSourceConfig,
    GPTMemmapDatasetPreparatorConfig,
    LanguageModelSourceConfig,
)
from fast_llm.data.preprocessing.abstract import NullPreprocessingConfig
from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig
from fast_llm.data.preprocessing.tokenizer import Tokenizer
Expand Down Expand Up @@ -132,6 +137,10 @@ def run(self) -> None:
# Load tokenizer
self._tokenizer = self._config.tokenizer.get_tokenizer()

# Validate chat template for conversation format
if isinstance(self._source_schema, ConversationSourceConfig):
self._tokenizer.validate_chat_template()

# Decide the datatype based on the tokenizer vocabulary size
self._data_type = (
get_unsigned_integer_type(self._tokenizer.vocab_size)
@@ -216,92 +225,110 @@ def _preprocessing_config(self) -> LanguageModelPreprocessingConfig:
)

def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample:
        token_spans_by_type = collections.defaultdict(list)
        image_patches = image_token_maps = image_position_ids = patch_counts = None

        if isinstance(self._source_schema, ConversationSourceConfig):
            # Conversation format: tokenize messages and get loss masking spans from chat template
            tokens, loss_masking_spans = self._tokenizer.tokenize_chat(
                sample[self._source_schema.messages],
                True,
                True,
                data_type=self._data_type,
            )
            token_spans_by_type[SpanType.loss_masking] = loss_masking_spans
        elif isinstance(self._source_schema, DocumentSourceConfig):
            # Document format: use the text-spans pipeline
            text = sample[self._source_schema.text]
            all_spans = []

            if self._source_schema.has_loss_masking_span:
                # Spans are typically stored in the (begin, last) format. We convert to (begin, end) range format.
                loss_masking_spans = _sort_spans(
                    (SpanType.loss_masking, (begin, last + 1))
                    for begin, last in np.array(sample[self._source_schema.loss_masking_spans], dtype=np.int32)
                    .reshape(-1, 2)
                    .tolist()
                )
                all_spans.extend(loss_masking_spans)

            if self._source_schema.has_preference_spans:
                full_chosen_text = text + sample[self._source_schema.chosen_span] + self._tokenizer.tokenizer.eos_token
                full_rejected_text = (
                    self._tokenizer.tokenizer.bos_token + text + sample[self._source_schema.rejected_span]
                )
                # compute chosen span
                chosen_spans = [(SpanType.chosen, (len(text), len(full_chosen_text)))]

                # compute rejected span
                rejected_span = [
                    (
                        SpanType.rejected,
                        (
                            len(full_chosen_text) + len(self._tokenizer.tokenizer.bos_token) + len(text),
                            len(full_chosen_text) + len(full_rejected_text),
                        ),
                    )
                ]
                # pack texts
                text = full_chosen_text + full_rejected_text
                all_spans.extend(chosen_spans + rejected_span)

            if self._source_schema.has_images:
                # Get the images and positions, sorted by position.
                images, image_positions = (
                    zip(
                        *sorted(
                            zip(
                                sample[self._source_schema.images],
                                sample[self._source_schema.image_positions],
                                strict=True,
                            ),
                            key=lambda x: x[1],
                        )
                    )
                    if len(sample[self._source_schema.images]) > 0
                    else ([], [])
                )
                # Get the image patches and associated data.
                image_patches, image_position_ids, image_token_maps, image_token_ids, patch_counts = (
                    self._config.image_patches.get_patches_from_images(images, self._data_type)
                )
                patch_count_cumsum = padded_cumsum(patch_counts).tolist()
                # Add an empty "span" at each image position so we know where to insert them in the tokenized sequence.
                all_spans.extend([(SpanType.image, (position, position)) for position in image_positions])

            # Sort the spans by location (begin), keeping track of their type.
            # Note: overlapping spans are not supported (explicit assertion in the tokenizer).
            span_types, spans = zip(*_sort_spans(all_spans)) if all_spans else ([], [])
            # Tokenize the text, and determine the span locations in the tokenized text.
            tokens, token_spans = self._tokenizer.tokenize_with_spans(
                text, True, True, text_spans=spans, data_type=self._data_type
            )

            # Gather token spans by type.
            if self._source_schema.has_images:
                # Insert the image token ids in the token sequence and shift the spans accordingly.
                tokens_shift = 0
                image_index = 0
                for span_type, (begin, end) in zip(span_types, token_spans, strict=True):
                    # Account for the tokens already inserted.
                    begin = begin + tokens_shift
                    end = end + tokens_shift
                    if span_type == SpanType.image:
                        # Shift the token map to the image location.
                        image_token_maps[patch_count_cumsum[image_index] : patch_count_cumsum[image_index + 1]] += begin
                        # Insert the placeholder and image break tokens.
                        tokens = torch.cat([tokens[:begin], image_token_ids[image_index], tokens[begin:]])
                        tokens_shift += len(image_token_ids[image_index])
                        image_index += 1
                    else:
                        token_spans_by_type[span_type].append((begin, end))
            else:
                for span_type, token_span in zip(span_types, token_spans, strict=True):
                    token_spans_by_type[span_type].append(token_span)
        else:
            raise NotImplementedError(f"Unsupported source schema type: {type(self._source_schema)}")

sample_size = len(tokens)

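To make the offset arithmetic in the preference branch of `_prepare_sample` easier to check, here is a toy worked example. The short strings stand in for the tokenizer's bos/eos tokens purely for the character counts; spans are in (begin, end) form.

# Toy values; "<s>" and "</s>" stand in for the tokenizer's bos/eos strings.
text, chosen, rejected, bos, eos = "Q:", "A", "B", "<s>", "</s>"
full_chosen = text + chosen + eos      # "Q:A</s>", length 7
full_rejected = bos + text + rejected  # "<s>Q:B", length 6
packed = full_chosen + full_rejected   # "Q:A</s><s>Q:B"
chosen_span = (len(text), len(full_chosen))  # (2, 7): packed[2:7] == "A</s>"
rejected_span = (
    len(full_chosen) + len(bos) + len(text),
    len(full_chosen) + len(full_rejected),
)  # (12, 13): packed[12:13] == "B"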
Expand Down Expand Up @@ -479,3 +506,5 @@ def _get_nearest_split(cumsum: np.ndarray, value: float) -> int:
if left == len(cumsum):
return left.item()
return left.item() + 1 if (value - cumsum[left]) / (cumsum[left + 1] - cumsum[left]) > 0.5 else left.item()

