Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions maester/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class DatasetConfig(BaseSettings):
num_data_workers: int = 1
# col_name: str = "tokens"
# file_type: str = "arrow"
dataset_type: str = "parquet"


class SFTConfig(BaseSettings):
Expand Down
13 changes: 12 additions & 1 deletion maester/datasets/experimental_otf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1309,9 +1309,20 @@ def __init__(

# Build subdataset iterators
self.data = []
match self.cfg.dataset_type:
case "parquet":
dataset_class = ParquetDataset
case "jinx":
try:
from .jinx_dataset import JinxDataset
dataset_class = JinxDataset
except ImportError:
raise ImportError("JinxDataset requires the mldataforge package. Please install it with `pip install mldataforge`.")
case _:
raise ValueError(f"Unsupported dataset type {self.cfg.dataset_type}. Supported types are 'parquet' and 'jinx'.")
for i, d in enumerate(data_dirs):
self.data.append(
ParquetDataset(
dataset_class(
data_dir=d,
rank=rank,
worldsize=worldsize,
Expand Down
25 changes: 25 additions & 0 deletions maester/datasets/jinx_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from mldataforge.jinx import JinxDatasetReader

from .experimental_otf import logger
from .base import ParquetDataset

class JinxDataset(ParquetDataset):
def __init__(self, *args, **kwargs):
self._jinx_readers = {}
super(JinxDataset, self).__init__(*args, **kwargs)

def _gather_doc_count(self, file):
_, reader = self._get_reader(None, file, None)
return len(reader)

def _get_reader(self, path, newpath, reader):
reader = self._jinx_readers.get(newpath, None)
if reader is None:
if self.verbose:
logger.info(f"Worker {self.rank} opening new file {newpath}")
reader = JinxDatasetReader(newpath)
self._jinx_readers[newpath] = reader
return newpath, reader

def _read_specific_row(self, reader, row_index):
return reader[row_index]