From 2947feaf25911bc726e47ba669f01f948b6625f4 Mon Sep 17 00:00:00 2001 From: Nikolay Karpov Date: Thu, 26 Jun 2025 08:30:37 -0700 Subject: [PATCH 01/13] add backend Signed-off-by: Nikolay Karpov --- sdp/processors/base_processor.py | 129 +++++++++++++++++++------------ sdp/run_processors.py | 73 ++++++++++------- tests/test_curator.py | 98 +++++++++++++++++++++++ tests/test_tts_sdp_end_to_end.py | 32 ++++---- 4 files changed, 240 insertions(+), 92 deletions(-) create mode 100644 tests/test_curator.py diff --git a/sdp/processors/base_processor.py b/sdp/processors/base_processor.py index 6fc22ee8..a9ff16d3 100644 --- a/sdp/processors/base_processor.py +++ b/sdp/processors/base_processor.py @@ -22,6 +22,8 @@ from itertools import chain from typing import Any, Dict, List, Optional, Union +from ray_curator.stages.base import ProcessingStage +from ray_curator.tasks import _EmptyTask from tqdm import tqdm from tqdm.contrib.concurrent import process_map @@ -59,7 +61,6 @@ class BaseProcessor(ABC): """ def __init__(self, output_manifest_file: str, input_manifest_file: Optional[str] = None, **kwargs): - if output_manifest_file and input_manifest_file and (output_manifest_file == input_manifest_file): # we cannot have the same input and output manifest file specified because we need to be able to # read from the input_manifest_file and write to the output_manifest_file at the same time @@ -83,7 +84,8 @@ def test(self): There are not tests by default. """ -class BaseParallelProcessor(BaseProcessor): + +class BaseParallelProcessor(BaseProcessor, ProcessingStage[_EmptyTask, _EmptyTask]): """ A processor that performs per-entry processing in parallel (using Dask or multiprocessing). @@ -94,10 +96,10 @@ class BaseParallelProcessor(BaseProcessor): chunksize (int): Chunk size used for parallel routines. in_memory_chunksize (int): Maximum number of entries to load at once. test_cases (list[dict]): Optional list of test cases. - use_dask (bool): If True, use Dask for parallelization; otherwise, use multiprocessing. + use_backend (str): Use {None, dask, curator} for parallelization. Use None for multiprocessing. dask_client: (Optional) An existing Dask client. """ - + def __getstate__(self): state = self.__dict__.copy() # Remove the Dask client from state (it is not picklable) @@ -113,11 +115,11 @@ def __init__( chunksize: int = 100, in_memory_chunksize: int = 100000, test_cases: Optional[List[Dict]] = None, - use_dask: bool = True, + use_backend: bool = True, dask_client=None, **kwargs, ): - kwargs.pop("use_dask", None) # + kwargs.pop("use_backend", None) # super().__init__(input_manifest_file=input_manifest_file, output_manifest_file=output_manifest_file, **kwargs) if max_workers == -1: max_workers = os.cpu_count() @@ -128,33 +130,41 @@ def __init__( self.total_duration = 0 self.start_time = time.time() self.test_cases = test_cases or [] - self.use_dask = use_dask + self.use_backend = use_backend self.dask_client = dask_client - + def prepare(self): - """Can be used in derived classes to prepare the processing. - - """ + """Can be used in derived classes to prepare the processing.""" pass - def process(self): - """A fork in the road to pick dask or classic processing - - """ + def process(self, task: _EmptyTask) -> _EmptyTask: + """A fork in the road to pick dask or classic processing""" os.environ.setdefault("PATH", os.defpath) self.prepare() - + os.makedirs(os.path.dirname(self.output_manifest_file), exist_ok=True) metrics = [] - - #Ability to work sa legacy and as dask - if self.use_dask: + + # Ability to work sa legacy and as dask + if self.use_backend == "dask": self._process_with_dask(metrics) else: self._process_with_multiprocessing(metrics) self.finalize(metrics) + return _EmptyTask(task_id="empty", dataset_name="empty", data=None) + + def inputs(self) -> tuple[list[str], list[str]]: + return [], [] + + def outputs(self) -> tuple[list[str], list[str]]: + return [], [] + + @property + def name(self) -> str: + return "BaseParallelProcessor" + def _process_with_dask(self, metrics): import dask.bag as db from dask.distributed import Client @@ -162,7 +172,8 @@ def _process_with_dask(self, metrics): if self.dask_client is None: self.dask_client = Client() client = self.dask_client - from sdp.logging import logger + from sdp.logging import logger + logger.info(f"Using Dask client with dashboard at: {client.dashboard_link}") # Delegate manifest reading to read_manifest() which returns a Dask bag. @@ -211,10 +222,10 @@ def _process_with_multiprocessing(self, metrics): def _chunk_manifest(self): """Splits the input manifest into chunks of in_memory_chunksize size. - Only used in non-Dask (multiprocessing) mode. + Only used in non-Dask (multiprocessing) mode. """ manifest_chunk = [] - # When use_dask is False, read_manifest() returns an iterator. + # When use_backend is False, read_manifest() returns an iterator. for idx, data_entry in enumerate(self.read_manifest(), 1): manifest_chunk.append(data_entry) if idx % self.in_memory_chunksize == 0: @@ -226,38 +237,51 @@ def _chunk_manifest(self): def read_manifest(self): """ Reads entries from the input manifest. - + Behavior depends on the parallelization mode: - - When use_dask is True: + - When use_backend is "dask": If the input_manifest_file exists and is non-empty, returns a Dask bag (reading in 256KB blocks). Otherwise, logs the condition and returns an empty Dask bag. - - When use_dask is False: + - When use_backend is "curator": + ToDo + - When use_backend is None: If the input_manifest_file does not exist or is empty, logs the condition and returns an empty iterator. Otherwise, opens the file in text mode, strips each line, and yields the parsed JSON from non-empty lines. - + This unified behavior lets the processor run even in manifest-creation mode. """ - from sdp.logging import logger - if self.use_dask: + from sdp.logging import logger + + if self.use_backend == "dask": import dask.bag as db - if self.input_manifest_file and os.path.exists(self.input_manifest_file) and os.path.getsize(self.input_manifest_file) > 0: + + if ( + self.input_manifest_file + and os.path.exists(self.input_manifest_file) + and os.path.getsize(self.input_manifest_file) > 0 + ): bag = db.read_text(self.input_manifest_file, blocksize=2**18).map(json.loads) return bag else: - logger.info("No input manifest file provided or file is empty. Returning an empty Dask bag for manifest creation.") + logger.info( + "No input manifest file provided or file is empty. Returning an empty Dask bag for manifest creation." + ) return db.from_sequence([]) else: if not self.input_manifest_file or not os.path.exists(self.input_manifest_file): - logger.info("No input manifest file provided or file does not exist. Continuing with an empty manifest.") + logger.info( + "No input manifest file provided or file does not exist. Continuing with an empty manifest." + ) return iter([]) - else: - #if use_dask = False, we get here - def generator(): #Reading manifest line by line, adding only non emply lines + else: + # if self.use_backend = None, we get here + def generator(): # Reading manifest line by line, adding only non emply lines with open(self.input_manifest_file, "rt", encoding="utf8") as fin: for line in fin: - if line: - yield json.loads(line) + if line: + yield json.loads(line) + return generator() @abstractmethod @@ -271,38 +295,43 @@ def process_dataset_entry(self, data_entry) -> List[Any]: def finalize(self, metrics: List[Any]): """Outputs metrics about the processed data.""" from sdp.logging import logger + logger.info("Total number of entries after processing: %d", self.number_of_entries) if self.total_duration: logger.info("Total audio duration (hours) after processing: %.2f", self.total_duration / 3600) else: - logger.info("Unable to calculate total audio duration (hours). Ensure that the manifest file includes a 'duration' key.") + logger.info( + "Unable to calculate total audio duration (hours). Ensure that the manifest file includes a 'duration' key." + ) elapsed = time.time() - self.start_time logger.info("Processor completed in (seconds): %.2f", elapsed) def test(self): - """Applies processing to each test case and raises an error if the output does not match expected output.""" + """Applies processing to each test case and raises an error if the output does not match expected output.""" for test_case in self.test_cases: input_data = test_case["input"].copy() if isinstance(test_case["input"], dict) else test_case["input"] generated_outputs = self.process_dataset_entry(input_data) - expected_outputs = [test_case["output"]] if not isinstance(test_case["output"], list) else test_case["output"] + expected_outputs = ( + [test_case["output"]] if not isinstance(test_case["output"], list) else test_case["output"] + ) for gen_out, exp_out in zip(generated_outputs, expected_outputs): gen_data = gen_out.data if hasattr(gen_out, "data") else gen_out if gen_data != exp_out: raise RuntimeError( - "Runtime test failed.\nTest input: {}\nGenerated output: {}\nExpected output: {}" - .format(test_case["input"], gen_data, exp_out) + "Runtime test failed.\nTest input: {}\nGenerated output: {}\nExpected output: {}".format( + test_case["input"], gen_data, exp_out + ) ) - # ------------------ Legacy Parallel Processor ------------------ #Just for reference class LegacyParallelProcessor(BaseProcessor): """ A legacy parallel processor implementation using multiprocessing and process_map. - + This class processes the manifest in chunks (using process_map) and is provided for compatibility. Child classes must implement process_dataset_entry(). - + Args: max_workers (int): maximum number of workers that will be spawned during the parallel processing. @@ -313,12 +342,13 @@ class LegacyParallelProcessor(BaseProcessor): test_cases (list[dict]): an optional list of dicts containing test cases for checking that the processor makes the changes that we are expecting. - + The dicts must have a key ``input``, the value of which is a dictionary containing data which is our test's input manifest line, and a key ``output``, the value of which is a dictionary containing data which is the expected output manifest line. """ + def __init__( self, max_workers: int = -1, @@ -327,7 +357,7 @@ def __init__( test_cases: Optional[List[Dict]] = None, **kwargs, ): - kwargs.pop("use_dask", None) # + kwargs.pop("use_backend", None) # super().__init__(**kwargs) if max_workers == -1: max_workers = multiprocessing.cpu_count() @@ -479,9 +509,12 @@ def finalize(self, metrics): if self.total_duration: logger.info("Total audio duration (hours) after processing (legacy): %.2f", self.total_duration / 3600) else: - logger.info("Unable to calculate total audio duration (legacy). Please ensure that the manifest file includes a 'duration' key.") + logger.info( + "Unable to calculate total audio duration (legacy). Please ensure that the manifest file includes a 'duration' key." + ) elapsed = time.time() - self.start_time logger.info("Legacy processor completed in (seconds): %.2f", elapsed) + def test(self): """Applies processing to "test_cases" and raises an error in case of mismatch.""" for test_case in self.test_cases: @@ -499,4 +532,4 @@ def test(self): f"Test input: {test_case['input']}\n" f"Generated output: {generated_output}\n" f"Expected output: {expected_output}" - ) \ No newline at end of file + ) diff --git a/sdp/run_processors.py b/sdp/run_processors.py index 8c498cf2..471064f2 100644 --- a/sdp/run_processors.py +++ b/sdp/run_processors.py @@ -12,19 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import logging import os import tempfile import uuid from typing import List, Optional -import psutil -import json import hydra +import psutil from omegaconf import OmegaConf, open_dict +from ray_curator.backends.xenna import XennaExecutor +from ray_curator.pipeline import Pipeline +from ray_curator.tasks import _EmptyTask from sdp.logging import logger - from sdp.utils.import_manager import ImportManager # registering new resolvers to simplify config files @@ -46,16 +48,18 @@ logger.addHandler(handler) logger.propagate = False + def update_processor_imports(config_path: str, init_file: str = None): """ Update processor imports based on config file. - + Args: config_path: Path to the YAML config file init_file: Optional path to __init__.py file to update """ try: import yaml + manager = ImportManager() manager.sync_with_config(config_path, init_file) logger.info(f"Successfully updated imports for config: {config_path}") @@ -120,13 +124,14 @@ def run_processors(cfg): if cfg.get("use_import_manager", False): try: import yaml + yaml_path = cfg.get("config_path") if not yaml_path: raise ValueError("No configuration path provided in 'config_path'. Please specify the path.") if not os.path.exists(yaml_path): raise FileNotFoundError(f"Configuration file not found: {yaml_path}") - + logger.info(f"Managing imports for config: {yaml_path}") manager = ImportManager() manager.sync_with_config(yaml_path) @@ -144,19 +149,23 @@ def run_processors(cfg): # Detecting dask try: from dask.distributed import Client + dask_available = True except ImportError: logger.warning("Dask not installed; using multiprocessing for all processors") dask_available = False - + # look for global directions in cfg for dask usage - global_use_dask = bool(cfg.get("use_dask", True)) and dask_available + if bool(cfg.get("use_backend", None) == "dask") and dask_available: + global_use_backend = "dask" + else: + global_use_backend = cfg.get("use_backend", None) processors_to_run = cfg.get("processors_to_run", "all") if processors_to_run == "all": processors_to_run = ":" selected_cfgs = select_subset(cfg.processors, processors_to_run) - + # filtering out any processors that have should_run=False processors_cfgs = [] for processor_cfg in selected_cfgs: @@ -169,9 +178,7 @@ def run_processors(cfg): "Specified to run the following processors: %s ", [proc_cfg["_target_"] for proc_cfg in processors_cfgs], ) - - - + processors = [] # Create a temporary directory to hold intermediate files if needed. with tempfile.TemporaryDirectory() as tmp_dir: @@ -187,7 +194,7 @@ def run_processors(cfg): with open_dict(processors_cfgs[0]): processors_cfgs[0]["input_manifest_file"] = cfg.processors[idx - 1]["output_manifest_file"] break - + for idx, processor_cfg in enumerate(processors_cfgs): logger.info('=> Building processor "%s"', processor_cfg["_target_"]) @@ -205,27 +212,26 @@ def run_processors(cfg): if idx != len(processors_cfgs) - 1 and "input_manifest_file" not in processors_cfgs[idx + 1]: with open_dict(processors_cfgs[idx + 1]): processors_cfgs[idx + 1]["input_manifest_file"] = processor_cfg["output_manifest_file"] - - #check if we have processor level directions of using dask - flag=processor_cfg.get("use_dask", None) + + # check if we have processor level directions of using dask + flag = processor_cfg.get("use_backend", None) # if no processor-specific flag, fallback to global; otherwise use provided value if flag is None: - use_dask_flag = global_use_dask + use_backend_flag = global_use_backend else: - use_dask_flag = flag + use_backend_flag = flag processor = hydra.utils.instantiate(processor_cfg) - processor.use_dask = use_dask_flag + processor.use_backend = use_backend_flag # running runtime tests to fail right-away if something is not # matching users expectations processor.test() processors.append(processor) - # Start Dask client if any processor requires it dask_client = None - if any(p.use_dask for p in processors): + if any(p.use_backend for p in processors): try: num_cpus = psutil.cpu_count(logical=False) or 4 logger.info(f"Starting Dask client with {num_cpus} workers") @@ -237,17 +243,28 @@ def run_processors(cfg): # Run processors in order try: - for proc in processors: - if proc.use_dask and dask_client is not None: - proc.dask_client = dask_client - logger.info('=> Running processor "%s" with Dask', proc) - else: - logger.info('=> Running processor "%s" with Multiprocessing', proc) - proc.process() + if global_use_backend == "curator": + pipeline = Pipeline(name="processing", description="Process data from JSONL files") + for p in cfg.processors: + stage = hydra.utils.instantiate(processor_cfg) + pipeline.add_stage(stage) + + executor = XennaExecutor() + results = pipeline.run(executor) + # raise ValueError("results", results) + else: + for proc in processors: + if proc.use_backend == "dask" and dask_client is not None: + proc.dask_client = dask_client + logger.info('=> Running processor "%s" with Dask', proc) + else: + logger.info('=> Running processor "%s" with Multiprocessing', proc) + proc.process(_EmptyTask(task_id="empty", dataset_name="empty", data=None)) finally: if dask_client is not None: logger.info("Shutting down Dask client...") dask_client.close(timeout="60s") logger.info("Dask client shutdown complete") -#tmp_dir is removed here after all processing finishes. !!! + +# tmp_dir is removed here after all processing finishes. !!! diff --git a/tests/test_curator.py b/tests/test_curator.py new file mode 100644 index 00000000..21b4746c --- /dev/null +++ b/tests/test_curator.py @@ -0,0 +1,98 @@ +import json +import os +import sys +import tempfile +from pathlib import Path + +import hydra +import yaml +from hydra import compose, initialize +from omegaconf import DictConfig, OmegaConf, open_dict + +from sdp.run_processors import run_processors, update_processor_imports +from sdp.utils import BootstrapProcessor + + +def _write_config(file_path: Path, dict_conf): + with file_path.open("w") as file: + yaml.dump(dict_conf, file) + + +def read_yaml(config_path=".", config_name="config"): + with initialize(version_base=None, config_path=config_path): + cfg = compose(config_name=config_name) + return cfg + + +def make_dict(output_manifest_file, use_backend=None): + workspace_dir = os.path.join(os.getenv('TEST_DATA_ROOT'), "armenian/audio_books/mp3") + return { + "processors_to_run": "0:", + "use_backend": use_backend, + "processors": [ + { + "_target_": "sdp.processors.CreateInitialManifestByExt", + "raw_data_dir": workspace_dir, + "extension": "mp3", + "output_file_key": "audio_filepath", + "output_manifest_file": output_manifest_file, + }, + ], + } + + +def make_expected_output(): + workspace_dir = os.path.join(os.getenv('TEST_DATA_ROOT'), "armenian/audio_books/mp3") + return {'audio_filepath': os.path.join(workspace_dir, "Eleonora/Eleonora30s.mp3")} + + +def test_curator(): + with tempfile.TemporaryDirectory() as tmpdir: + output_path = os.path.join(tmpdir, "output_manifest_file.jsonl") + dict_conf = make_dict(output_manifest_file=output_path, use_backend="curator") + conf_path = Path(tmpdir) / "config.yaml" + _write_config(conf_path, dict_conf) + + cfg = OmegaConf.load(conf_path) + + run_processors(cfg) + with open(output_path, "r") as f: + output = json.load(f) + + expected_output = make_expected_output() + + assert output == expected_output, f"Expected {expected_output}, but got {output}" + + +def test_multiprocessing(): + with tempfile.TemporaryDirectory() as tmpdir: + output_path = os.path.join(tmpdir, "output_manifest_file.jsonl") + dict_conf = make_dict(output_manifest_file=output_path, use_backend=None) + conf_path = Path(tmpdir) / "config.yaml" + _write_config(conf_path, dict_conf) + + cfg = OmegaConf.load(conf_path) + + run_processors(cfg) + with open(output_path, "r") as f: + output = json.load(f) + + expected_output = make_expected_output() + assert output == expected_output, f"Expected {expected_output}, but got {output}" + + +def test_dask(): + with tempfile.TemporaryDirectory() as tmpdir: + output_path = os.path.join(tmpdir, "output_manifest_file.jsonl") + dict_conf = make_dict(output_manifest_file=output_path, use_backend="dask") + conf_path = Path(tmpdir) / "config.yaml" + _write_config(conf_path, dict_conf) + + cfg = OmegaConf.load(conf_path) + + run_processors(cfg) + with open(output_path, "r") as f: + output = json.load(f) + + expected_output = make_expected_output() + assert output == expected_output, f"Expected {expected_output}, but got {output}" diff --git a/tests/test_tts_sdp_end_to_end.py b/tests/test_tts_sdp_end_to_end.py index 291a8834..c50bed93 100644 --- a/tests/test_tts_sdp_end_to_end.py +++ b/tests/test_tts_sdp_end_to_end.py @@ -1,37 +1,37 @@ -import pytest -import ndjson -import boto3 import json import os import tarfile from pathlib import Path + +import boto3 +import pytest from omegaconf import OmegaConf + from sdp.run_processors import run_processors DATASET_CONFIGS_ROOT = Path(__file__).parents[1] / "dataset_configs" + @pytest.fixture def get_tts_ytc_data(tmpdir: str): # Download the data from S3 s3 = boto3.client( - 's3', - aws_access_key_id=os.getenv("AWS_ACCESS_KEY"), - aws_secret_access_key=os.getenv("AWS_SECRET_KEY") + 's3', aws_access_key_id=os.getenv("AWS_ACCESS_KEY"), aws_secret_access_key=os.getenv("AWS_SECRET_KEY") ) s3.download_file( - "sdp-test-data", - "test_data/tts/ytc/test_data_reference.json", - tmpdir/"test_data_reference.json", + "sdp-test-data", + "test_data/tts/ytc/test_data_reference.json", + tmpdir / "test_data_reference.json", ) s3.download_file( - "sdp-test-data", - "test_data/tts/ytc/ytc.en.tar.gz", - tmpdir/"ytc.en.tar.gz", + "sdp-test-data", + "test_data/tts/ytc/ytc.en.tar.gz", + tmpdir / "ytc.en.tar.gz", ) # Extract the tar.gz file - with tarfile.open(tmpdir/"ytc.en.tar.gz", "r:gz") as tar: + with tarfile.open(tmpdir / "ytc.en.tar.gz", "r:gz") as tar: tar.extractall(tmpdir) audio_files = Path(tmpdir).glob("audios/*") @@ -45,6 +45,7 @@ def get_tts_ytc_data(tmpdir: str): return tmpdir + def test_tts_sdp_end_to_end(get_tts_ytc_data): data_dir = get_tts_ytc_data assert os.path.exists(data_dir) @@ -72,15 +73,14 @@ def test_tts_sdp_end_to_end(get_tts_ytc_data): output_data = ndjson.load(f) for item in output_data: output_file_data[item["audio_item_id"]] = item - + reference_file_data = {} with open(reference_manifest_file, "r") as f: reference_data = ndjson.load(f) for item in reference_data: reference_file_data[item["audio_item_id"]] = item - + assert len(output_file_data) == len(reference_file_data) assert len(output_file_data) == 2 for audio_item_id in output_file_data: assert output_file_data[audio_item_id]["segments"] == reference_file_data[audio_item_id]["segments"] - From 846a633d56764768af4c9078270619292c4e86ff Mon Sep 17 00:00:00 2001 From: Nikolay Karpov Date: Fri, 27 Jun 2025 10:23:54 -0700 Subject: [PATCH 02/13] BaseProcessor Signed-off-by: Nikolay Karpov --- requirements/curator.txt | 26 +++ sdp/processors/base_processor.py | 10 + .../fleurs/create_initial_manifest.py | 5 +- .../hifitts2/remove_failed_chapters.py | 5 +- .../librispeech/create_initial_manifest.py | 7 +- sdp/processors/datasets/mls/restore_pc.py | 4 +- .../slr140/create_initial_manifest.py | 4 +- .../datasets/slr83/create_initial_manifest.py | 4 +- .../uzbekvoice/create_initial_manifest.py | 25 ++- .../huggingface/speech_recognition.py | 11 +- sdp/processors/langs/armenian.py | 7 +- sdp/processors/modify_manifest/common.py | 34 ++-- .../modify_manifest/data_to_data.py | 64 ++++--- .../modify_manifest/data_to_dropbool.py | 39 ++-- sdp/processors/nemo/asr_inference.py | 9 +- tests/test_curator.py | 45 ++--- tests/test_manifest_chunking.py | 172 +++++++++--------- 17 files changed, 272 insertions(+), 199 deletions(-) create mode 100644 requirements/curator.txt diff --git a/requirements/curator.txt b/requirements/curator.txt new file mode 100644 index 00000000..d55c4609 --- /dev/null +++ b/requirements/curator.txt @@ -0,0 +1,26 @@ +cd ray-api + +# pip install cosmos-xenna[gpu] +git clone https://github.com/NVIDIA-NeMo/Curator.git +git switch ray-api +# install NeMo +pip install "nemo_toolkit[all]" + +# install nemo_text_processing +pip install nemo_text_processing + +pip install -r requirements/main.txt +pip install -r requirements/tests.txt +pip install . +RAY_ADDRESS=10.110.41.40:8265 python -m pytest tests/test_curator.py + +# pip install loguru +# pip install -U "ray[default]" + +# cd ~/workspace/Curator/ray-curator && pip install . +# ray start --include-dashboard=True --head # ray status # ray stop + # import ray + # ray.init() + # RAY_ADDRESS='http://127.0.0.1:8265' ray job submit --working-dir . -- python my_script.py + +RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES="0" RAY_MAX_LIMIT_FROM_API_SERVER=40000 RAY_MAX_LIMIT_FROM_DATA_SOURCE=40000 ray start --include-dashboard=True --dashboard-host=0.0.0.0 --port=8265 --dashboard-port=8266 --head --temp-dir=/tmp diff --git a/sdp/processors/base_processor.py b/sdp/processors/base_processor.py index a9ff16d3..35cc8729 100644 --- a/sdp/processors/base_processor.py +++ b/sdp/processors/base_processor.py @@ -84,6 +84,16 @@ def test(self): There are not tests by default. """ + def inputs(self) -> tuple[list[str], list[str]]: + return [], [] + + def outputs(self) -> tuple[list[str], list[str]]: + return [], [] + + @property + def name(self) -> str: + return "BaseProcessor" + class BaseParallelProcessor(BaseProcessor, ProcessingStage[_EmptyTask, _EmptyTask]): """ diff --git a/sdp/processors/datasets/fleurs/create_initial_manifest.py b/sdp/processors/datasets/fleurs/create_initial_manifest.py index d571593a..7967c5b9 100644 --- a/sdp/processors/datasets/fleurs/create_initial_manifest.py +++ b/sdp/processors/datasets/fleurs/create_initial_manifest.py @@ -20,6 +20,8 @@ import typing from urllib.parse import parse_qs, urlparse +from ray_curator.tasks import _EmptyTask + from sdp.processors.base_processor import BaseProcessor, DataEntry from sdp.utils.common import download_file, extract_archive @@ -145,6 +147,7 @@ def download_extract_files(self, dst_folder: str) -> None: os.remove(file_path) print(f'File {file_name} already exists in {target_folder}, deleted from source.') - def process(self): + def process(self, task: _EmptyTask) -> _EmptyTask: self.download_extract_files(self.raw_data_dir) self.process_data(self.raw_data_dir, self.output_manifest_file) + return _EmptyTask(task_id="empty", dataset_name="empty", data=None) diff --git a/sdp/processors/datasets/hifitts2/remove_failed_chapters.py b/sdp/processors/datasets/hifitts2/remove_failed_chapters.py index b4cd5a8b..dff513a3 100644 --- a/sdp/processors/datasets/hifitts2/remove_failed_chapters.py +++ b/sdp/processors/datasets/hifitts2/remove_failed_chapters.py @@ -15,6 +15,8 @@ import json from pathlib import Path + +from ray_curator.tasks import _EmptyTask from tqdm import tqdm from sdp.processors.base_processor import BaseProcessor @@ -49,7 +51,7 @@ def __init__( super().__init__(**kwargs) self.error_file = Path(error_file) - def process(self): + def process(self, task: _EmptyTask) -> _EmptyTask: chapter_rows = load_manifest(self.error_file) audio_files_to_remove = set() for chapter_row in chapter_rows: @@ -64,3 +66,4 @@ def process(self): output_line = f"{json.dumps(row, ensure_ascii=False)}\n" output_f.write(output_line) + return _EmptyTask(task_id="empty", dataset_name="empty", data=None) diff --git a/sdp/processors/datasets/librispeech/create_initial_manifest.py b/sdp/processors/datasets/librispeech/create_initial_manifest.py index 83d42bde..86fec1dc 100644 --- a/sdp/processors/datasets/librispeech/create_initial_manifest.py +++ b/sdp/processors/datasets/librispeech/create_initial_manifest.py @@ -18,6 +18,8 @@ import os import typing +from ray_curator.tasks import _EmptyTask + from sdp.processors.base_processor import BaseProcessor from sdp.utils.common import download_file, extract_archive @@ -94,7 +96,7 @@ def process_transcript(self, file_path: str) -> list[dict[str, typing.Any]]: entries = [] root = os.path.dirname(file_path) - print(f"Processing transcript file: {file_path}") + print(f"Processing transcript file: {file_path}") with open(file_path, encoding="utf-8") as fin: for line in fin: id, text = line[: line.index(" ")], line[line.index(" ") + 1 :] @@ -135,6 +137,7 @@ def download_extract_files(self, dst_folder: str) -> None: data_file = f'{dst_folder}/{self.split}.tar.gz' extract_archive(str(data_file), str(dst_folder), force_extract=True) - def process(self): + def process(self, task: _EmptyTask) -> _EmptyTask: self.download_extract_files(self.raw_data_dir) self.process_data(self.raw_data_dir, self.output_manifest_file) + return _EmptyTask(task_id="empty", dataset_name="empty", data=None) diff --git a/sdp/processors/datasets/mls/restore_pc.py b/sdp/processors/datasets/mls/restore_pc.py index 33ff22b0..a79116fc 100644 --- a/sdp/processors/datasets/mls/restore_pc.py +++ b/sdp/processors/datasets/mls/restore_pc.py @@ -24,6 +24,7 @@ import regex from joblib import Parallel, delayed +from ray_curator.tasks import _EmptyTask from tqdm import tqdm from sdp.logging import logger @@ -454,7 +455,7 @@ def __init__( self.n_jobs = n_jobs self.show_conversion_breakdown = show_conversion_breakdown - def process(self): + def process(self, task: _EmptyTask) -> _EmptyTask: """Main processing happens here. * Download & extract lv_text. @@ -604,3 +605,4 @@ def process(self): with open(manifest, "r") as fin: for line in fin: fout.write(line) + return _EmptyTask(task_id="empty", dataset_name="empty", data=None) diff --git a/sdp/processors/datasets/slr140/create_initial_manifest.py b/sdp/processors/datasets/slr140/create_initial_manifest.py index 2da79027..2d701bf5 100644 --- a/sdp/processors/datasets/slr140/create_initial_manifest.py +++ b/sdp/processors/datasets/slr140/create_initial_manifest.py @@ -19,6 +19,7 @@ import numpy as np import sox +from ray_curator.tasks import _EmptyTask from tqdm import tqdm from tqdm.contrib.concurrent import thread_map @@ -145,7 +146,7 @@ def __init__(self, data_split: str, split_audio_dir: str, **kwargs): self.data_split = data_split self.split_audio_dir = split_audio_dir - def process(self): + def process(self, task: _EmptyTask) -> _EmptyTask: with open(self.input_manifest_file, "rt", encoding="utf8") as fin: manifest_data = [json.loads(line) for line in fin.readlines()] @@ -190,6 +191,7 @@ def process(self): logger.info("Total number of entries after processing: %d", number_of_entries) logger.info("Total audio duration (hours) after processing: %.2f", total_duration / 3600) + return _EmptyTask(task_id="empty", dataset_name="empty", data=None) def _accumulate_samples( self, manifest_data: List[dict], sample_idxs: List[int], duration_threshold: int diff --git a/sdp/processors/datasets/slr83/create_initial_manifest.py b/sdp/processors/datasets/slr83/create_initial_manifest.py index 030360f7..f5edd887 100644 --- a/sdp/processors/datasets/slr83/create_initial_manifest.py +++ b/sdp/processors/datasets/slr83/create_initial_manifest.py @@ -19,6 +19,7 @@ import numpy as np import sox +from ray_curator.tasks import _EmptyTask from tqdm import tqdm from sdp.logging import logger @@ -192,7 +193,7 @@ def __init__(self, dialect, data_split, **kwargs): self.dialect = dialect self.data_split = data_split - def process(self): + def process(self, task: _EmptyTask) -> _EmptyTask: with open(self.input_manifest_file, "rt", encoding="utf8") as fin: manifest_data = [json.loads(line) for line in fin.readlines()] @@ -238,6 +239,7 @@ def process(self): logger.info("Total number of entries after processing: %d", number_of_entries) logger.info("Total audio duration (hours) after processing: %.2f", total_duration / 3600) + return _EmptyTask(task_id="empty", dataset_name="empty", data=None) def _accumulate_samples( self, manifest_data: List[dict], sample_idxs: List[int], duration_threshold: int diff --git a/sdp/processors/datasets/uzbekvoice/create_initial_manifest.py b/sdp/processors/datasets/uzbekvoice/create_initial_manifest.py index 27117f2a..ee272dc9 100644 --- a/sdp/processors/datasets/uzbekvoice/create_initial_manifest.py +++ b/sdp/processors/datasets/uzbekvoice/create_initial_manifest.py @@ -16,11 +16,13 @@ import json import os import typing + import gdown +from ray_curator.tasks import _EmptyTask +from sdp.logging import logger from sdp.processors.base_processor import BaseProcessor from sdp.utils.common import extract_archive -from sdp.logging import logger class CreateInitialManifestUzbekvoice(BaseProcessor): @@ -30,7 +32,7 @@ class CreateInitialManifestUzbekvoice(BaseProcessor): Will download all files, extract them, and create a manifest file with the "audio_filepath", "text" and "duration" fields. - Args: + Args: raw_data_dir (str): Path to the folder where the data archive should be downloaded and extracted. Returns: @@ -59,8 +61,10 @@ def download_extract_files(self, dst_folder: str) -> None: # for big files google drive doesn't allow to try downlaoding them more than once # so, in case of receiveing gdown error we need to download them manually - #check if clisp.zip and uzbekvoice-dataset.zip are already in dst_folder - if os.path.exists(os.path.join(dst_folder, 'clips.zip')) and os.path.exists(os.path.join(dst_folder, 'uzbekvoice-dataset.zip')): + # check if clisp.zip and uzbekvoice-dataset.zip are already in dst_folder + if os.path.exists(os.path.join(dst_folder, 'clips.zip')) and os.path.exists( + os.path.join(dst_folder, 'uzbekvoice-dataset.zip') + ): print("Files already exist in the folder. Skipping download.") else: print(f"Downloading files from {self.URL}...") @@ -74,7 +78,6 @@ def download_extract_files(self, dst_folder: str) -> None: extract_archive(file, str(dst_folder), force_extract=True) print(f"Extracted {file}") - def process_transcript(self, file_path: str) -> list[dict[str, typing.Any]]: """ Parse transcript JSON file and put it inside manifest. @@ -93,13 +96,8 @@ def process_transcript(self, file_path: str) -> list[dict[str, typing.Any]]: utter_length = entry["clip_duration"] number_of_entries += 1 entries.append( - { - "audio_filepath": os.path.abspath(audio_file), - "text": transcript, - "duration": utter_length - } + {"audio_filepath": os.path.abspath(audio_file), "text": transcript, "duration": utter_length} ) - logger.info("Total number of entries after processing: %d", number_of_entries) logger.info("Total audio duration (hours) after processing: %.2f", total_duration / 3600) @@ -113,8 +111,7 @@ def process_data(self, data_folder: str, manifest_file: str) -> None: for m in entries: fout.write(json.dumps(m, ensure_ascii=False) + "\n") - - - def process(self): + def process(self, task: _EmptyTask) -> _EmptyTask: self.download_extract_files(self.raw_data_dir) self.process_data(self.raw_data_dir, self.output_manifest_file) + return _EmptyTask(task_id="empty", dataset_name="empty", data=None) diff --git a/sdp/processors/huggingface/speech_recognition.py b/sdp/processors/huggingface/speech_recognition.py index 2e64e7c4..9db7bb86 100644 --- a/sdp/processors/huggingface/speech_recognition.py +++ b/sdp/processors/huggingface/speech_recognition.py @@ -14,13 +14,15 @@ import json from pathlib import Path +from typing import Optional +from ray_curator.tasks import _EmptyTask from tqdm import tqdm from sdp.logging import logger from sdp.processors.base_processor import BaseProcessor from sdp.utils.common import load_manifest -from typing import Optional + class ASRTransformers(BaseProcessor): """This processor transcribes audio files using HuggingFace ASR Transformer models. @@ -99,7 +101,7 @@ def __init__( # Check if using Whisper/Seamless or NVIDIA model based on the model name self.is_whisper_or_seamless = any(x in self.pretrained_model.lower() for x in ['whisper', 'seamless']) - + # Only set language in generation config for Whisper/Seamless models if self.is_whisper_or_seamless and self.generate_language: self.model.generation_config.language = self.generate_language @@ -119,7 +121,7 @@ def __init__( device=self.device, ) - def process(self): + def process(self, task: _EmptyTask) -> _EmptyTask: json_list = load_manifest(Path(self.input_manifest_file)) json_list_sorted = sorted(json_list, key=lambda d: d[self.input_duration_key], reverse=True) @@ -131,7 +133,7 @@ def process(self): batch = json_list_sorted[start_index : start_index + self.batch_size] start_index += self.batch_size audio_files = [item[self.input_audio_key] for item in batch] - + # Only pass generate_kwargs for Whisper/Seamless models if self.is_whisper_or_seamless and self.generate_language and self.generate_task: results = self.pipe( @@ -143,3 +145,4 @@ def process(self): for i, item in enumerate(batch): item[self.output_text_key] = results[i]["text"] f.write(json.dumps(item, ensure_ascii=False) + "\n") + return _EmptyTask(task_id="empty", dataset_name="empty", data=None) diff --git a/sdp/processors/langs/armenian.py b/sdp/processors/langs/armenian.py index 586807ed..319a3bd3 100644 --- a/sdp/processors/langs/armenian.py +++ b/sdp/processors/langs/armenian.py @@ -16,6 +16,7 @@ from pathlib import Path import pandas as pd +from ray_curator.tasks import _EmptyTask from sdp.processors.base_processor import ( BaseParallelProcessor, @@ -62,9 +63,10 @@ class MakeTsv(BaseProcessor): """ - def process(self): + def process(self, task: _EmptyTask) -> _EmptyTask: df1 = pd.DataFrame.from_records(load_manifest(Path(self.input_manifest_file))) df1.to_csv(self.output_manifest_file, index=None, sep='\t') + return _EmptyTask(task_id="empty", dataset_name="empty", data=None) class RandomTsvPart(BaseProcessor): @@ -88,8 +90,9 @@ def __init__( self.part = part self.random_state = random_state - def process(self): + def process(self, task: _EmptyTask) -> _EmptyTask: df1 = pd.read_csv(self.input_manifest_file, sep='\t') df1.sample(frac=self.part, random_state=self.random_state).to_csv( self.output_manifest_file, index=None, sep='\t' ) + return _EmptyTask(task_id="empty", dataset_name="empty", data=None) diff --git a/sdp/processors/modify_manifest/common.py b/sdp/processors/modify_manifest/common.py index 98ad1fa3..b9fe40a8 100644 --- a/sdp/processors/modify_manifest/common.py +++ b/sdp/processors/modify_manifest/common.py @@ -15,7 +15,7 @@ import json import os from pathlib import Path -from typing import Dict, List, Union, Optional +from typing import Dict, List, Optional, Union import pandas as pd from tqdm import tqdm @@ -28,6 +28,7 @@ ) from sdp.utils.common import load_manifest + class CombineSources(BaseParallelProcessor): """Can be used to create a single field from two alternative sources. @@ -104,24 +105,24 @@ class AddConstantFields(BaseParallelProcessor): This processor adds constant fields to all manifest entries using Dask BaseParallelProcessor. It is useful when you want to attach fixed information (e.g., a language label or metadata) to each entry for downstream tasks such as language identification model training. - + Args: fields (dict): A dictionary containing key-value pairs of fields to add to each manifest entry. For example:: - + { "label": "en", "metadata": "mcv-11.0-2022-09-21" } - + Returns: dict: The same data as in the input manifest with the added constant fields as specified in the ``fields`` dictionary. - + Example: - + .. code-block:: yaml - + - _target_: sdp.processors.modify_manifest.common.AddConstantFields input_manifest_file: ${workspace_dir}/input_manifest.json output_manifest_file: ${workspace_dir}/output_manifest.json @@ -139,7 +140,6 @@ def process_dataset_entry(self, data_entry: Dict): return [DataEntry(data=data_entry)] - class DuplicateFields(BaseParallelProcessor): """This processor duplicates fields in all manifest entries. @@ -154,8 +154,8 @@ class DuplicateFields(BaseParallelProcessor): Returns: The same data as in the input manifest with duplicated fields - as specified in the ``duplicate_fields`` input dictionary. - + as specified in the ``duplicate_fields`` input dictionary. + Example: .. code-block:: yaml @@ -165,6 +165,7 @@ class DuplicateFields(BaseParallelProcessor): duplicate_fields: {"text":"answer"} """ + def __init__( self, duplicate_fields: Dict, @@ -334,7 +335,7 @@ def process(self): fout.write(json.dumps(line, ensure_ascii=False) + "\n") -class KeepOnlySpecifiedFields(BaseProcessor): +class KeepOnlySpecifiedFields(BaseParallelProcessor): """Saves a copy of a manifest but only with a subset of the fields. Typically will be the final processor to save only relevant fields @@ -354,14 +355,9 @@ def __init__(self, fields_to_keep: List[str], **kwargs): super().__init__(**kwargs) self.fields_to_keep = fields_to_keep - def process(self): - with open(self.input_manifest_file, "rt", encoding="utf8") as fin, open( - self.output_manifest_file, "wt", encoding="utf8" - ) as fout: - for line in tqdm(fin): - line = json.loads(line) - new_line = {field: line[field] for field in self.fields_to_keep} - fout.write(json.dumps(new_line, ensure_ascii=False) + "\n") + def process_dataset_entry(self, data_entry: Dict): + new_data_entry = {field: data_entry[field] for field in self.fields_to_keep} + return [DataEntry(data=new_data_entry)] class ApplyInnerJoin(BaseProcessor): diff --git a/sdp/processors/modify_manifest/data_to_data.py b/sdp/processors/modify_manifest/data_to_data.py index 16e1de6d..f2c0a749 100644 --- a/sdp/processors/modify_manifest/data_to_data.py +++ b/sdp/processors/modify_manifest/data_to_data.py @@ -13,6 +13,7 @@ # limitations under the License. import collections +import json import os import re from typing import Dict, List, Optional @@ -20,9 +21,9 @@ import soundfile import torchaudio from docx import Document +from ray_curator.tasks import _EmptyTask from sox import Transformer from tqdm import tqdm -import json from sdp.logging import logger from sdp.processors.base_processor import ( @@ -211,7 +212,7 @@ def __init__( # Extract workspace_dir from kwargs to avoid passing it to BaseProcessor if "workspace_dir" in kwargs: workspace_dir = kwargs.pop("workspace_dir") - + super().__init__(**kwargs) self.input_audio_file_key = input_audio_file_key self.output_audio_file_key = output_audio_file_key @@ -230,13 +231,13 @@ def prepare(self): def process_dataset_entry(self, data_entry): audio_path = data_entry[self.input_audio_file_key] - + # If workspace_dir is provided, join it with audio_path to get absolute path if self.workspace_dir is not None: full_audio_path = os.path.join(self.workspace_dir, audio_path) else: full_audio_path = audio_path - + # Debug print first file path if not hasattr(self, '_debug_printed'): logger.info(f"First audio_path from manifest: {audio_path}") @@ -678,6 +679,7 @@ def __init__( def prepare(self): from nemo_text_processing.text_normalization.normalize import Normalizer + try: self.normalizer = Normalizer(input_case=self.input_case, lang=self.input_language) except NotImplementedError as e: @@ -726,7 +728,10 @@ def __init__( self.verbose = verbose def prepare(self): - from nemo_text_processing.inverse_text_normalization.inverse_normalize import InverseNormalizer + from nemo_text_processing.inverse_text_normalization.inverse_normalize import ( + InverseNormalizer, + ) + try: self.inverse_normalizer = InverseNormalizer(input_case=self.input_case, lang=self.input_language) except NotImplementedError as e: @@ -747,7 +752,7 @@ class CopyManifestData(BaseParallelProcessor): Args: copy_path (str): The destination directory where files will be copied. - source_filepath (str): The key in the manifest that contains the path to + source_filepath (str): The key in the manifest that contains the path to the file to be copied. Default: "audio_path". Returns: @@ -763,6 +768,7 @@ class CopyManifestData(BaseParallelProcessor): copy_path: ${workspace_dir}/consolidated_data source_filepath: "audio_filepath" """ + def __init__( self, copy_path: str, @@ -931,15 +937,16 @@ class GetWER(BaseParallelProcessor): """This processor calculates Word Error Rate (WER) between predicted text and ground truth text. It computes the WER for each entry in the manifest and adds the result as a new field. - + Args: text_key (str): Key for the ground truth text field in the manifest. Default: "text". pred_text_key (str): Key for the predicted text field in the manifest. Default: "pred_text". - + Returns: - The same data as in the input manifest with an additional 'wer' field containing + The same data as in the input manifest with an additional 'wer' field containing the calculated Word Error Rate between the specified text fields. """ + def __init__( self, text_key: str = "text", @@ -983,6 +990,7 @@ class MakeSentence(BaseParallelProcessor): end_symbol: "." make_uppercase: true """ + def __init__( self, text_key: str = "text", @@ -1022,7 +1030,14 @@ class ASRFileCheck(BaseProcessor): A manifest with corrupted audio files removed. """ - def __init__(self, audio_filepath_key: str = "audio_filepath", corrupted_audio_dir: str = None, workspace_dir: str = None, **kwargs): + + def __init__( + self, + audio_filepath_key: str = "audio_filepath", + corrupted_audio_dir: str = None, + workspace_dir: str = None, + **kwargs, + ): """ Constructs the necessary attributes for the ASRFileCheck class. @@ -1038,31 +1053,33 @@ def __init__(self, audio_filepath_key: str = "audio_filepath", corrupted_audio_d """ super().__init__(**kwargs) self.audio_filepath_key = audio_filepath_key - + if corrupted_audio_dir is None: - raise ValueError("corrupted_audio_dir parameter is required. Please specify a directory to move corrupted files.") - + raise ValueError( + "corrupted_audio_dir parameter is required. Please specify a directory to move corrupted files." + ) + self.corrupted_audio_dir = corrupted_audio_dir self.workspace_dir = workspace_dir self.failed_files = [] - def process(self): + def process(self, task: _EmptyTask) -> _EmptyTask: """ Check each file listed in the manifest to ensure it can be loaded with torchaudio. This method reads through the manifest file, attempts to load each audio file using torchaudio, and moves corrupted files. A new manifest file is created with only the valid entries. - + Specific errors handled: - FileNotFoundError: File doesn't exist - RuntimeError: File format issues or codec problems - Other exceptions: General issues with file loading """ from sdp.logging import logger - + # Debug print to show workspace_dir logger.info(f"ASRFileCheck workspace_dir: {self.workspace_dir}") - + with open(self.input_manifest_file, 'r') as f: lines = f.readlines() @@ -1076,22 +1093,22 @@ def process(self): line = lines[idx] entry = json.loads(line) audio_path = entry[self.audio_filepath_key] - + # Debug print first file path if idx == 0: logger.info(f"First audio_path from manifest: {audio_path}") - + # If workspace_dir is provided, join it with audio_path to get absolute path if self.workspace_dir is not None: full_audio_path = os.path.join(self.workspace_dir, audio_path) else: full_audio_path = audio_path - + # Debug print first full path if idx == 0: logger.info(f"First full_audio_path: {full_audio_path}") logger.info(f"Path exists: {os.path.exists(full_audio_path)}") - + try: # Attempt to load the audio file to check if it is corrupted torchaudio.load(full_audio_path) @@ -1102,7 +1119,7 @@ def process(self): except RuntimeError as e: logger.warning(f"Audio format error in {audio_path}: {e}") self.failed_files.append(audio_path) - + # Move the corrupted audio file if os.path.exists(full_audio_path): dest_path = os.path.join(self.corrupted_audio_dir, os.path.basename(audio_path)) @@ -1111,7 +1128,7 @@ def process(self): except Exception as e: logger.warning(f"Unknown error loading {audio_path}: {e}") self.failed_files.append(audio_path) - + # Move the corrupted audio file if os.path.exists(full_audio_path): dest_path = os.path.join(self.corrupted_audio_dir, os.path.basename(audio_path)) @@ -1127,3 +1144,4 @@ def process(self): if self.failed_files: logger.warning(f"Failed to process {len(self.failed_files)} files.") logger.debug(f"Failed files: {self.failed_files}") + return _EmptyTask(task_id="empty", dataset_name="empty", data=None) diff --git a/sdp/processors/modify_manifest/data_to_dropbool.py b/sdp/processors/modify_manifest/data_to_dropbool.py index ff675e0a..8c03dc64 100644 --- a/sdp/processors/modify_manifest/data_to_dropbool.py +++ b/sdp/processors/modify_manifest/data_to_dropbool.py @@ -14,12 +14,13 @@ import collections import json +import os import re -import os -import json from operator import eq, ge, gt, le, lt, ne from typing import List, Union +from ray_curator.tasks import _EmptyTask + from sdp.logging import logger from sdp.processors.base_processor import ( BaseParallelProcessor, @@ -808,9 +809,9 @@ class DropRepeatedFields(BaseParallelProcessor): """Drops utterances from the current manifest if their text fields are present in other manifests. This class processes multiple manifest files and removes entries from the current manifest if the text field - matches any entry in the other manifests. It allows for optional punctuation removal from the text fields + matches any entry in the other manifests. It allows for optional punctuation removal from the text fields before performing the check. - + .. note:: It is better to process Test/Dev/Train and then Other.tsv @@ -819,19 +820,21 @@ class DropRepeatedFields(BaseParallelProcessor): current_manifest_file (str): Path to the current manifest file to be processed. punctuations (str): (Optional): String of punctuation characters to be removed from the text fields before checking for duplicates. Defaults to None. text_key (str): The key in the manifest entries that contains the text field. Defaults to "text". - + Returns: The same data as in the input manifest with some entries dropped. """ - def __init__(self, - manifests_paths: List[str], - current_manifest_file: str, - punctuations: str = None, - text_key: str = "text", - **kwargs - ): - super().__init__( **kwargs) + + def __init__( + self, + manifests_paths: List[str], + current_manifest_file: str, + punctuations: str = None, + text_key: str = "text", + **kwargs, + ): + super().__init__(**kwargs) self.manifests_paths = manifests_paths self.current_manifest_file = current_manifest_file self.text_key = text_key @@ -851,10 +854,10 @@ def load_data(self): if self.punctuations is not None and len(self.punctuations) > 0: line_text = self.remove_punctuation(line_text) self.text_set.add(line_text) - + def remove_punctuation(self, text): return re.sub(fr'[{self.punctuations}]', '', text) - + def process_dataset_entry(self, data_entry) -> List: text_for_check = data_entry[self.text_key] if self.punctuations is not None and len(self.punctuations) > 0: @@ -862,7 +865,7 @@ def process_dataset_entry(self, data_entry) -> List: if text_for_check in self.text_set: return [DataEntry(data=None, metrics=1)] return [DataEntry(data=data_entry, metrics=0)] - + def finalize(self, metrics: List): total_counter = 0 for counter in metrics: @@ -889,7 +892,7 @@ def __init__(self, drop_key: str = "text", **kwargs): self.drop_key = drop_key self.seen_texts = set() - def process(self): + def process(self, task: _EmptyTask) -> _EmptyTask: unique_entries = [] with open(self.input_manifest_file, 'r', encoding='utf-8') as file: for line in file: @@ -904,4 +907,4 @@ def process(self): fout.write(json.dumps(entry, ensure_ascii=False) + "\n") logger.info(f"Total number of entries after processing: {len(unique_entries)}") - return unique_entries + return _EmptyTask(task_id="empty", dataset_name="empty", data=None) diff --git a/sdp/processors/nemo/asr_inference.py b/sdp/processors/nemo/asr_inference.py index 4359f320..634674c2 100644 --- a/sdp/processors/nemo/asr_inference.py +++ b/sdp/processors/nemo/asr_inference.py @@ -17,6 +17,8 @@ from pathlib import Path from typing import Optional +from ray_curator.tasks import _EmptyTask + from sdp.processors.base_processor import BaseProcessor # Note that we do not re-use base parallel implementation, since the ASR @@ -44,7 +46,7 @@ class ASRInference(BaseProcessor): def __init__( self, - pretrained_model: Optional[str]=None, + pretrained_model: Optional[str] = None, batch_size: int = 32, **kwargs, ): @@ -53,7 +55,7 @@ def __init__( self.pretrained_model = pretrained_model self.batch_size = batch_size - def process(self): + def process(self, task: _EmptyTask) -> _EmptyTask: """This will add "pred_text" key into the output manifest.""" os.makedirs(os.path.dirname(self.output_manifest_file), exist_ok=True) if self.pretrained_model.endswith(".nemo"): @@ -75,4 +77,5 @@ def process(self): f"batch_size={self.batch_size} ", shell=True, check=True, - ) \ No newline at end of file + ) + return _EmptyTask(task_id="empty", dataset_name="empty", data=None) diff --git a/tests/test_curator.py b/tests/test_curator.py index 21b4746c..54232642 100644 --- a/tests/test_curator.py +++ b/tests/test_curator.py @@ -1,16 +1,26 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import json import os -import sys import tempfile from pathlib import Path -import hydra import yaml -from hydra import compose, initialize -from omegaconf import DictConfig, OmegaConf, open_dict +from omegaconf import OmegaConf -from sdp.run_processors import run_processors, update_processor_imports -from sdp.utils import BootstrapProcessor +from sdp.run_processors import run_processors def _write_config(file_path: Path, dict_conf): @@ -18,13 +28,7 @@ def _write_config(file_path: Path, dict_conf): yaml.dump(dict_conf, file) -def read_yaml(config_path=".", config_name="config"): - with initialize(version_base=None, config_path=config_path): - cfg = compose(config_name=config_name) - return cfg - - -def make_dict(output_manifest_file, use_backend=None): +def _make_dict(output_manifest_file, use_backend=None): workspace_dir = os.path.join(os.getenv('TEST_DATA_ROOT'), "armenian/audio_books/mp3") return { "processors_to_run": "0:", @@ -41,7 +45,7 @@ def make_dict(output_manifest_file, use_backend=None): } -def make_expected_output(): +def _make_expected_output(): workspace_dir = os.path.join(os.getenv('TEST_DATA_ROOT'), "armenian/audio_books/mp3") return {'audio_filepath': os.path.join(workspace_dir, "Eleonora/Eleonora30s.mp3")} @@ -49,7 +53,7 @@ def make_expected_output(): def test_curator(): with tempfile.TemporaryDirectory() as tmpdir: output_path = os.path.join(tmpdir, "output_manifest_file.jsonl") - dict_conf = make_dict(output_manifest_file=output_path, use_backend="curator") + dict_conf = _make_dict(output_manifest_file=output_path, use_backend="curator") conf_path = Path(tmpdir) / "config.yaml" _write_config(conf_path, dict_conf) @@ -59,15 +63,14 @@ def test_curator(): with open(output_path, "r") as f: output = json.load(f) - expected_output = make_expected_output() - + expected_output = _make_expected_output() assert output == expected_output, f"Expected {expected_output}, but got {output}" def test_multiprocessing(): with tempfile.TemporaryDirectory() as tmpdir: output_path = os.path.join(tmpdir, "output_manifest_file.jsonl") - dict_conf = make_dict(output_manifest_file=output_path, use_backend=None) + dict_conf = _make_dict(output_manifest_file=output_path, use_backend=None) conf_path = Path(tmpdir) / "config.yaml" _write_config(conf_path, dict_conf) @@ -77,14 +80,14 @@ def test_multiprocessing(): with open(output_path, "r") as f: output = json.load(f) - expected_output = make_expected_output() + expected_output = _make_expected_output() assert output == expected_output, f"Expected {expected_output}, but got {output}" def test_dask(): with tempfile.TemporaryDirectory() as tmpdir: output_path = os.path.join(tmpdir, "output_manifest_file.jsonl") - dict_conf = make_dict(output_manifest_file=output_path, use_backend="dask") + dict_conf = _make_dict(output_manifest_file=output_path, use_backend="dask") conf_path = Path(tmpdir) / "config.yaml" _write_config(conf_path, dict_conf) @@ -94,5 +97,5 @@ def test_dask(): with open(output_path, "r") as f: output = json.load(f) - expected_output = make_expected_output() + expected_output = _make_expected_output() assert output == expected_output, f"Expected {expected_output}, but got {output}" diff --git a/tests/test_manifest_chunking.py b/tests/test_manifest_chunking.py index ae1aa394..f836c6df 100644 --- a/tests/test_manifest_chunking.py +++ b/tests/test_manifest_chunking.py @@ -21,97 +21,93 @@ import json import pytest +from ray_curator.tasks import _EmptyTask -from sdp.processors import DropNonAlphabet -from sdp.processors import SubMakeLowercase +from sdp.processors import DropNonAlphabet, SubMakeLowercase -def test_submakelowercase_with_chunking(tmp_path): - input_lines = [ - {"text": "ABC"}, - {"text": "DEF"}, - {"text": "GHI"}, - {"text": "JKL"}, - {"text": "MNO"}, - {"text": "PQR"}, - {"text": "STU"}, - {"text": "VWX"}, - {"text": "YZ"}, - ] - - expected_output_lines = [ - {"text": "abc"}, - {"text": "def"}, - {"text": "ghi"}, - {"text": "jkl"}, - {"text": "mno"}, - {"text": "pqr"}, - {"text": "stu"}, - {"text": "vwx"}, - {"text": "yz"}, - ] - - - # save input lines to manifest: - input_manifest_file = tmp_path / "input_manifest.json" - with open(input_manifest_file, "w") as f: - for line in input_lines: - f.write(json.dumps(line) + "\n") - - # run make_lowercase processor: - output_manifest_file = tmp_path / "output_manifest_make_lowercase.json" - processor = SubMakeLowercase( - input_manifest_file=input_manifest_file, - output_manifest_file=output_manifest_file, - in_memory_chunksize=2 - ) - - processor.process() - - # check that output manifest matches expected lines: - with open(output_manifest_file, "r") as f: - output_lines = [json.loads(line) for line in f] - - assert output_lines == expected_output_lines +def test_submakelowercase_with_chunking(tmp_path): + input_lines = [ + {"text": "ABC"}, + {"text": "DEF"}, + {"text": "GHI"}, + {"text": "JKL"}, + {"text": "MNO"}, + {"text": "PQR"}, + {"text": "STU"}, + {"text": "VWX"}, + {"text": "YZ"}, + ] + + expected_output_lines = [ + {"text": "abc"}, + {"text": "def"}, + {"text": "ghi"}, + {"text": "jkl"}, + {"text": "mno"}, + {"text": "pqr"}, + {"text": "stu"}, + {"text": "vwx"}, + {"text": "yz"}, + ] + + # save input lines to manifest: + input_manifest_file = tmp_path / "input_manifest.json" + with open(input_manifest_file, "w") as f: + for line in input_lines: + f.write(json.dumps(line) + "\n") + + # run make_lowercase processor: + output_manifest_file = tmp_path / "output_manifest_make_lowercase.json" + processor = SubMakeLowercase( + input_manifest_file=input_manifest_file, output_manifest_file=output_manifest_file, in_memory_chunksize=2 + ) + + processor.process(_EmptyTask(task_id="empty", dataset_name="empty", data=None)) + + # check that output manifest matches expected lines: + with open(output_manifest_file, "r") as f: + output_lines = [json.loads(line) for line in f] + + assert output_lines == expected_output_lines def test_dropnonalphabet_with_chunking(tmp_path): - - input_lines = [ - {"text": "ABC"}, - {"text": "DEF"}, - {"text": "GHI"}, - {"text": "JKL"}, - {"text": "MNO"}, - {"text": "PQR"}, - {"text": "STU"}, - {"text": "VWX"}, - {"text": "YZ"}, - ] - - expected_output_lines = [ - {"text": "ABC"}, - ] - - # save input lines to manifest: - input_manifest_file = tmp_path / "input_manifest.json" - with open(input_manifest_file, "w") as f: - for line in input_lines: - f.write(json.dumps(line) + "\n") - - # run make_lowercase processor: - output_manifest_file = tmp_path / "output_manifest_make_lowercase.json" - processor = DropNonAlphabet( - input_manifest_file=input_manifest_file, - output_manifest_file=output_manifest_file, - in_memory_chunksize=2, - alphabet="ABC" - ) - - processor.process() - - # check that output manifest matches expected lines: - with open(output_manifest_file, "r") as f: - output_lines = [json.loads(line) for line in f] - - assert output_lines == expected_output_lines + input_lines = [ + {"text": "ABC"}, + {"text": "DEF"}, + {"text": "GHI"}, + {"text": "JKL"}, + {"text": "MNO"}, + {"text": "PQR"}, + {"text": "STU"}, + {"text": "VWX"}, + {"text": "YZ"}, + ] + + expected_output_lines = [ + {"text": "ABC"}, + ] + + # save input lines to manifest: + input_manifest_file = tmp_path / "input_manifest.json" + with open(input_manifest_file, "w") as f: + for line in input_lines: + f.write(json.dumps(line) + "\n") + + # run make_lowercase processor: + output_manifest_file = tmp_path / "output_manifest_make_lowercase.json" + processor = DropNonAlphabet( + input_manifest_file=input_manifest_file, + output_manifest_file=output_manifest_file, + in_memory_chunksize=2, + alphabet="ABC", + ) + + processor.process(_EmptyTask(task_id="empty", dataset_name="empty", data=None)) + + # check that output manifest matches expected lines: + with open(output_manifest_file, "r") as f: + output_lines = [json.loads(line) for line in f] + + assert output_lines == expected_output_lines From 907365c71921098784d93b4c2b80fe7c6e8b61e8 Mon Sep 17 00:00:00 2001 From: Nikolay Karpov Date: Sat, 5 Jul 2025 03:06:15 -0700 Subject: [PATCH 03/13] inherit Task Signed-off-by: Nikolay Karpov --- sdp/processors/base_processor.py | 10 +++++----- .../uzbekvoice/create_initial_manifest.py | 16 +++++++++------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/sdp/processors/base_processor.py b/sdp/processors/base_processor.py index 35cc8729..a0ed4fef 100644 --- a/sdp/processors/base_processor.py +++ b/sdp/processors/base_processor.py @@ -23,7 +23,7 @@ from typing import Any, Dict, List, Optional, Union from ray_curator.stages.base import ProcessingStage -from ray_curator.tasks import _EmptyTask +from ray_curator.tasks import Task, _EmptyTask from tqdm import tqdm from tqdm.contrib.concurrent import process_map @@ -95,7 +95,7 @@ def name(self) -> str: return "BaseProcessor" -class BaseParallelProcessor(BaseProcessor, ProcessingStage[_EmptyTask, _EmptyTask]): +class BaseParallelProcessor(BaseProcessor, ProcessingStage[Task, Task]): """ A processor that performs per-entry processing in parallel (using Dask or multiprocessing). @@ -125,7 +125,7 @@ def __init__( chunksize: int = 100, in_memory_chunksize: int = 100000, test_cases: Optional[List[Dict]] = None, - use_backend: bool = True, + use_backend: Optional[str] = None, dask_client=None, **kwargs, ): @@ -147,7 +147,7 @@ def prepare(self): """Can be used in derived classes to prepare the processing.""" pass - def process(self, task: _EmptyTask) -> _EmptyTask: + def process(self, task: Task) -> Task: """A fork in the road to pick dask or classic processing""" os.environ.setdefault("PATH", os.defpath) @@ -163,7 +163,7 @@ def process(self, task: _EmptyTask) -> _EmptyTask: self._process_with_multiprocessing(metrics) self.finalize(metrics) - return _EmptyTask(task_id="empty", dataset_name="empty", data=None) + return task def inputs(self) -> tuple[list[str], list[str]]: return [], [] diff --git a/sdp/processors/datasets/uzbekvoice/create_initial_manifest.py b/sdp/processors/datasets/uzbekvoice/create_initial_manifest.py index ee272dc9..78949dce 100644 --- a/sdp/processors/datasets/uzbekvoice/create_initial_manifest.py +++ b/sdp/processors/datasets/uzbekvoice/create_initial_manifest.py @@ -18,7 +18,7 @@ import typing import gdown -from ray_curator.tasks import _EmptyTask +from ray_curator.tasks import DocumentBatch, EmptyTask, _EmptyTask from sdp.logging import logger from sdp.processors.base_processor import BaseProcessor @@ -107,11 +107,13 @@ def process_transcript(self, file_path: str) -> list[dict[str, typing.Any]]: def process_data(self, data_folder: str, manifest_file: str) -> None: entries = self.process_transcript(os.path.join(data_folder, "uzbekvoice-dataset", "voice_dataset.json")) - with open(manifest_file, "w", encoding="utf-8") as fout: - for m in entries: - fout.write(json.dumps(m, ensure_ascii=False) + "\n") + if self.use_backend is None or self.use_backend == "dask": + with open(manifest_file, "w", encoding="utf-8") as fout: + for m in entries: + fout.write(json.dumps(m, ensure_ascii=False) + "\n") + return entries - def process(self, task: _EmptyTask) -> _EmptyTask: + def process(self, task: _EmptyTask) -> DocumentBatch: self.download_extract_files(self.raw_data_dir) - self.process_data(self.raw_data_dir, self.output_manifest_file) - return _EmptyTask(task_id="empty", dataset_name="empty", data=None) + entries = self.process_data(self.raw_data_dir, self.output_manifest_file) + return DocumentBatch(entries) From 2a6458807b3b53738e75bc13a558971f8f7586aa Mon Sep 17 00:00:00 2001 From: Nikolay Karpov Date: Mon, 21 Jul 2025 10:49:42 -0700 Subject: [PATCH 04/13] DataEntry from Task Signed-off-by: Nikolay Karpov --- sdp/processors/base_processor.py | 50 ++++++++++++++----- .../modify_manifest/create_manifest.py | 29 +++++++---- tests/test_curator.py | 36 +------------ 3 files changed, 57 insertions(+), 58 deletions(-) diff --git a/sdp/processors/base_processor.py b/sdp/processors/base_processor.py index a0ed4fef..45caedee 100644 --- a/sdp/processors/base_processor.py +++ b/sdp/processors/base_processor.py @@ -31,12 +31,19 @@ @dataclass -class DataEntry: +class DataEntry(Task[list]): """A wrapper for data entry + any additional metrics.""" data: Optional[Dict] # can be None to drop the entry metrics: Any = None + @property + def num_items(self) -> int: + return len(self.data) + + def validate(self) -> bool: + return True + class BaseProcessor(ABC): """Abstract class for SDP processors. @@ -107,14 +114,14 @@ class BaseParallelProcessor(BaseProcessor, ProcessingStage[Task, Task]): in_memory_chunksize (int): Maximum number of entries to load at once. test_cases (list[dict]): Optional list of test cases. use_backend (str): Use {None, dask, curator} for parallelization. Use None for multiprocessing. - dask_client: (Optional) An existing Dask client. + backend_client: (Optional) An existing backend client. """ def __getstate__(self): state = self.__dict__.copy() # Remove the Dask client from state (it is not picklable) - if 'dask_client' in state: - state['dask_client'] = None + if 'backend_client' in state: + state['backend_client'] = None return state def __init__( @@ -126,7 +133,7 @@ def __init__( in_memory_chunksize: int = 100000, test_cases: Optional[List[Dict]] = None, use_backend: Optional[str] = None, - dask_client=None, + backend_client=None, **kwargs, ): kwargs.pop("use_backend", None) # @@ -141,7 +148,7 @@ def __init__( self.start_time = time.time() self.test_cases = test_cases or [] self.use_backend = use_backend - self.dask_client = dask_client + self.backend_client = backend_client def prepare(self): """Can be used in derived classes to prepare the processing.""" @@ -157,10 +164,10 @@ def process(self, task: Task) -> Task: metrics = [] # Ability to work sa legacy and as dask - if self.use_backend == "dask": - self._process_with_dask(metrics) + if self.use_backend == "curator": + task = self._process_with_ray(metrics) else: - self._process_with_multiprocessing(metrics) + task = self._process_with_multiprocessing(metrics) self.finalize(metrics) return task @@ -169,7 +176,7 @@ def inputs(self) -> tuple[list[str], list[str]]: return [], [] def outputs(self) -> tuple[list[str], list[str]]: - return [], [] + return ["data"], [] @property def name(self) -> str: @@ -179,9 +186,9 @@ def _process_with_dask(self, metrics): import dask.bag as db from dask.distributed import Client - if self.dask_client is None: - self.dask_client = Client() - client = self.dask_client + if self.backend_client is None: + self.backend_client = Client() + client = self.backend_client from sdp.logging import logger logger.info(f"Using Dask client with dashboard at: {client.dashboard_link}") @@ -210,6 +217,22 @@ def _process_with_dask(self, metrics): self.total_duration += entry.data.get("duration", 0) logger.info(f"Processed {total_entries} entries using Dask.") + def _process_with_ray(self, metrics): + tasks = [] + with open(self.output_manifest_file, "wt", encoding="utf8") as fout: + for manifest_chunk in self._chunk_manifest(): + data = self.process_dataset_entry(manifest_chunk) + for data_entry in tqdm(data): + metrics.append(data_entry.metrics) + if data_entry.data is None: + continue + json.dump(data_entry.data, fout, ensure_ascii=False) + fout.write("\n") + self.number_of_entries += 1 + self.total_duration += data_entry.data.get("duration", 0) + tasks.extend(data) + return tasks + def _process_with_multiprocessing(self, metrics): with open(self.output_manifest_file, "wt", encoding="utf8") as fout: for manifest_chunk in self._chunk_manifest(): @@ -229,6 +252,7 @@ def _process_with_multiprocessing(self, metrics): fout.write("\n") self.number_of_entries += 1 self.total_duration += data_entry.data.get("duration", 0) + return data def _chunk_manifest(self): """Splits the input manifest into chunks of in_memory_chunksize size. diff --git a/sdp/processors/modify_manifest/create_manifest.py b/sdp/processors/modify_manifest/create_manifest.py index 1e416571..c0082fac 100644 --- a/sdp/processors/modify_manifest/create_manifest.py +++ b/sdp/processors/modify_manifest/create_manifest.py @@ -17,10 +17,7 @@ import pandas -from sdp.processors.base_processor import ( - BaseParallelProcessor, - DataEntry, -) +from sdp.processors.base_processor import BaseParallelProcessor, DataEntry class CreateInitialManifestByExt(BaseParallelProcessor): @@ -55,23 +52,30 @@ def read_manifest(self): def process_dataset_entry(self, data_entry): data = {self.output_file_key: data_entry} - return [DataEntry(data=data)] + return [ + DataEntry( + data=data, + task_id=0, + dataset_name=str(self.raw_data_dir / "*.") + self.extension, + ) + ] class CreateCombinedManifests(BaseParallelProcessor): """Reads JSON lines from specified files and creates a combined manifest. - This processor iterates over files listed in `manifest_list`, reads each file line by line, + This processor iterates over files listed in `manifest_list`, reads each file line by line, and yields the parsed JSON data from each line. Args: - manifest_list (list(str)): A list of file paths or directories to process. The processor will + manifest_list (list(str)): A list of file paths or directories to process. The processor will recursively read files within the directories and expect each file to contain JSON data. **kwargs: Additional keyword arguments passed to the base class `BaseParallelProcessor`. Returns: A generator that yields parsed JSON data from each line in the files listed in `manifest_list`. """ + def __init__( self, manifest_list: list[str], @@ -87,7 +91,10 @@ def read_manifest(self): yield json.loads(line) def process_dataset_entry(self, data_entry): - return [DataEntry(data=data_entry)] - - - + return [ + DataEntry( + data=data_entry, + task_id=0, + dataset_name=self.__class__.__name__, + ) + ] diff --git a/tests/test_curator.py b/tests/test_curator.py index 54232642..a7103cba 100644 --- a/tests/test_curator.py +++ b/tests/test_curator.py @@ -19,6 +19,8 @@ import yaml from omegaconf import OmegaConf +from ray_curator.stages.base import ProcessingStage +from ray_curator.tasks import Task, _EmptyTask from sdp.run_processors import run_processors @@ -65,37 +67,3 @@ def test_curator(): expected_output = _make_expected_output() assert output == expected_output, f"Expected {expected_output}, but got {output}" - - -def test_multiprocessing(): - with tempfile.TemporaryDirectory() as tmpdir: - output_path = os.path.join(tmpdir, "output_manifest_file.jsonl") - dict_conf = _make_dict(output_manifest_file=output_path, use_backend=None) - conf_path = Path(tmpdir) / "config.yaml" - _write_config(conf_path, dict_conf) - - cfg = OmegaConf.load(conf_path) - - run_processors(cfg) - with open(output_path, "r") as f: - output = json.load(f) - - expected_output = _make_expected_output() - assert output == expected_output, f"Expected {expected_output}, but got {output}" - - -def test_dask(): - with tempfile.TemporaryDirectory() as tmpdir: - output_path = os.path.join(tmpdir, "output_manifest_file.jsonl") - dict_conf = _make_dict(output_manifest_file=output_path, use_backend="dask") - conf_path = Path(tmpdir) / "config.yaml" - _write_config(conf_path, dict_conf) - - cfg = OmegaConf.load(conf_path) - - run_processors(cfg) - with open(output_path, "r") as f: - output = json.load(f) - - expected_output = _make_expected_output() - assert output == expected_output, f"Expected {expected_output}, but got {output}" From b8368e535c62853cd9b43930bfbd17ba1eb4b32f Mon Sep 17 00:00:00 2001 From: Nikolay Karpov Date: Mon, 21 Jul 2025 11:14:27 -0700 Subject: [PATCH 05/13] num_items 1 Signed-off-by: Nikolay Karpov --- sdp/processors/base_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdp/processors/base_processor.py b/sdp/processors/base_processor.py index 45caedee..eda34e03 100644 --- a/sdp/processors/base_processor.py +++ b/sdp/processors/base_processor.py @@ -39,7 +39,7 @@ class DataEntry(Task[list]): @property def num_items(self) -> int: - return len(self.data) + return 1 def validate(self) -> bool: return True From e7837e6890cff9c51c1ba7929a9888344a583c5c Mon Sep 17 00:00:00 2001 From: Nikolay Karpov Date: Mon, 21 Jul 2025 20:34:55 -0700 Subject: [PATCH 06/13] add SaveJsonl Signed-off-by: Nikolay Karpov --- sdp/processors/__init__.py | 28 +++++++------ sdp/processors/base_processor.py | 19 ++++----- .../modify_manifest/create_manifest.py | 42 ++++++++++++++++++- sdp/run_processors.py | 35 +++++++++------- tests/test_curator.py | 13 ++++-- 5 files changed, 92 insertions(+), 45 deletions(-) diff --git a/sdp/processors/__init__.py b/sdp/processors/__init__.py index c3ff70b6..da9b4495 100644 --- a/sdp/processors/__init__.py +++ b/sdp/processors/__init__.py @@ -25,9 +25,8 @@ CreateInitialManifestFleurs, ) from sdp.processors.datasets.hifitts2.download_dataset import DownloadHiFiTTS2 -from sdp.processors.datasets.hifitts2.remove_failed_chapters import RemovedFailedChapters -from sdp.processors.datasets.uzbekvoice.create_initial_manifest import ( - CreateInitialManifestUzbekvoice, +from sdp.processors.datasets.hifitts2.remove_failed_chapters import ( + RemovedFailedChapters, ) from sdp.processors.datasets.ksc2.create_initial_manifest import ( CreateInitialManifestKSC2, @@ -37,13 +36,15 @@ CreateInitialManifestLibrispeech, ) from sdp.processors.datasets.masc import ( - CreateInitialManifestMASC, AggregateSegments, + CreateInitialManifestMASC, + GetCaptionFileSegments, RegExpVttEntries, - GetCaptionFileSegments ) -from sdp.processors.datasets.mediaspeech.create_initial_manifest import CreateInitialManifestMediaSpeech from sdp.processors.datasets.mcv.create_initial_manifest import CreateInitialManifestMCV +from sdp.processors.datasets.mediaspeech.create_initial_manifest import ( + CreateInitialManifestMediaSpeech, +) from sdp.processors.datasets.mls.create_initial_manifest import CreateInitialManifestMLS from sdp.processors.datasets.mls.restore_pc import RestorePCForMLS from sdp.processors.datasets.mtedx.create_initial_manifest import ( @@ -60,18 +61,20 @@ CreateInitialManifestSLR140, CustomDataSplitSLR140, ) +from sdp.processors.datasets.uzbekvoice.create_initial_manifest import ( + CreateInitialManifestUzbekvoice, +) from sdp.processors.datasets.voxpopuli.create_initial_manifest import ( CreateInitialManifestVoxpopuli, ) from sdp.processors.datasets.voxpopuli.normalize_from_non_pc_text import ( NormalizeFromNonPCTextVoxpopuli, ) -from sdp.processors.datasets.ytc.create_initial_manifest import ( - CreateInitialManifestYTC, +from sdp.processors.datasets.ytc.create_initial_manifest import CreateInitialManifestYTC +from sdp.processors.huggingface.create_initial_manifest import ( + CreateInitialManifestHuggingFace, ) from sdp.processors.huggingface.speech_recognition import ASRTransformers -from sdp.processors.huggingface.create_initial_manifest import CreateInitialManifestHuggingFace - from sdp.processors.modify_manifest.common import ( AddConstantFields, ApplyInnerJoin, @@ -86,6 +89,7 @@ from sdp.processors.modify_manifest.create_manifest import ( CreateCombinedManifests, CreateInitialManifestByExt, + SaveJsonl, ) from sdp.processors.modify_manifest.data_to_data import ( ASRFileCheck, @@ -97,8 +101,8 @@ GetWER, InsIfASRInsertion, InverseNormalizeText, - NormalizeText, MakeSentence, + NormalizeText, ReadDocxLines, ReadTxtLines, SoxConvert, @@ -122,8 +126,8 @@ DropLowWordMatchRate, DropNonAlphabet, DropOnAttribute, - PreserveByValue, DropRepeatedFields, + PreserveByValue, ) from sdp.processors.modify_manifest.make_letters_uppercase_after_period import ( MakeLettersUppercaseAfterPeriod, diff --git a/sdp/processors/base_processor.py b/sdp/processors/base_processor.py index eda34e03..daab9e85 100644 --- a/sdp/processors/base_processor.py +++ b/sdp/processors/base_processor.py @@ -23,7 +23,7 @@ from typing import Any, Dict, List, Optional, Union from ray_curator.stages.base import ProcessingStage -from ray_curator.tasks import Task, _EmptyTask +from ray_curator.tasks import Task from tqdm import tqdm from tqdm.contrib.concurrent import process_map @@ -154,7 +154,7 @@ def prepare(self): """Can be used in derived classes to prepare the processing.""" pass - def process(self, task: Task) -> Task: + def process(self, tasks: Task) -> Task: """A fork in the road to pick dask or classic processing""" os.environ.setdefault("PATH", os.defpath) @@ -165,12 +165,11 @@ def process(self, task: Task) -> Task: # Ability to work sa legacy and as dask if self.use_backend == "curator": - task = self._process_with_ray(metrics) + tasks = self._process_with_ray(metrics) else: - task = self._process_with_multiprocessing(metrics) + tasks = self._process_with_multiprocessing(metrics) self.finalize(metrics) - - return task + return tasks def inputs(self) -> tuple[list[str], list[str]]: return [], [] @@ -219,15 +218,13 @@ def _process_with_dask(self, metrics): def _process_with_ray(self, metrics): tasks = [] - with open(self.output_manifest_file, "wt", encoding="utf8") as fout: - for manifest_chunk in self._chunk_manifest(): - data = self.process_dataset_entry(manifest_chunk) + for manifest_chunk in self._chunk_manifest(): + for row in manifest_chunk: + data = self.process_dataset_entry(row) for data_entry in tqdm(data): metrics.append(data_entry.metrics) if data_entry.data is None: continue - json.dump(data_entry.data, fout, ensure_ascii=False) - fout.write("\n") self.number_of_entries += 1 self.total_duration += data_entry.data.get("duration", 0) tasks.extend(data) diff --git a/sdp/processors/modify_manifest/create_manifest.py b/sdp/processors/modify_manifest/create_manifest.py index c0082fac..323bdac0 100644 --- a/sdp/processors/modify_manifest/create_manifest.py +++ b/sdp/processors/modify_manifest/create_manifest.py @@ -15,9 +15,47 @@ import json from pathlib import Path -import pandas +from ray_curator.stages.base import ProcessingStage +from ray_curator.stages.resources import Resources +from ray_curator.tasks import Task -from sdp.processors.base_processor import BaseParallelProcessor, DataEntry +from sdp.processors.base_processor import ( + BaseParallelProcessor, + BaseProcessor, + DataEntry, +) + + +class SaveJsonl(BaseProcessor, ProcessingStage[Task, Task]): + """ + Processor for creating an initial dataset manifest by saving filepaths with a common extension to the field specified in output_field. + + Args: + raw_data_dir (str): The root directory of the files to be added to the initial manifest. This processor will recursively look for files with the extension 'extension' inside this directory. + output_file_key (str): The key to store the paths to the files in the dataset. + extension (str): The file extension of the of the files to be added to the manifest. + **kwargs: Additional keyword arguments to be passed to the base class `BaseParallelProcessor`. + + """ + + name: str = "SaveManifest" + resources: Resources = Resources(cpus=1.0, gpu_memory_gb=10.0) + batch_size: int = 100000 + + def __init__( + self, + **kwargs, + ): + super().__init__(**kwargs) + + def setup(self, a): + # Path(self.output_manifest_file).touch() + open(self.output_manifest_file, 'w').close() + + def process(self, tasks: DataEntry) -> DataEntry: + with open(self.output_manifest_file, 'a', encoding="utf8") as f: + f.write(json.dumps(tasks.data) + '\n') + return tasks class CreateInitialManifestByExt(BaseParallelProcessor): diff --git a/sdp/run_processors.py b/sdp/run_processors.py index 471064f2..5496e33c 100644 --- a/sdp/run_processors.py +++ b/sdp/run_processors.py @@ -146,18 +146,20 @@ def run_processors(cfg): except Exception as e: logger.error(f"An unexpected error occurred during management of imports: {e}") - # Detecting dask + # Detecting ray try: - from dask.distributed import Client + import ray - dask_available = True + ray.init() + + ray_available = True except ImportError: logger.warning("Dask not installed; using multiprocessing for all processors") - dask_available = False + ray_available = False # look for global directions in cfg for dask usage - if bool(cfg.get("use_backend", None) == "dask") and dask_available: - global_use_backend = "dask" + if bool(cfg.get("use_backend", None) == "ray") and ray_available: + global_use_backend = "ray" else: global_use_backend = cfg.get("use_backend", None) @@ -230,40 +232,41 @@ def run_processors(cfg): processors.append(processor) # Start Dask client if any processor requires it - dask_client = None + backend_client = None if any(p.use_backend for p in processors): try: num_cpus = psutil.cpu_count(logical=False) or 4 logger.info(f"Starting Dask client with {num_cpus} workers") - dask_client = Client(n_workers=num_cpus, processes=True) - logger.info(f"Dask dashboard at: {dask_client.dashboard_link}") + backend_client = Client(n_workers=num_cpus, processes=True) + logger.info(f"Dask dashboard at: {backend_client.dashboard_link}") except Exception as e: logger.warning(f"Failed to start Dask client: {e}") - dask_client = None + backend_client = None # Run processors in order try: if global_use_backend == "curator": pipeline = Pipeline(name="processing", description="Process data from JSONL files") for p in cfg.processors: - stage = hydra.utils.instantiate(processor_cfg) + stage = hydra.utils.instantiate(p, use_backend="curator", backend_client=backend_client) pipeline.add_stage(stage) executor = XennaExecutor() results = pipeline.run(executor) - # raise ValueError("results", results) + # raise ValueError("results" + results) else: for proc in processors: - if proc.use_backend == "dask" and dask_client is not None: - proc.dask_client = dask_client + if proc.use_backend == "dask" and backend_client is not None: + proc.backend_client = backend_client logger.info('=> Running processor "%s" with Dask', proc) else: logger.info('=> Running processor "%s" with Multiprocessing', proc) + # raise ValueError("_EmptyTask_EmptyTask") proc.process(_EmptyTask(task_id="empty", dataset_name="empty", data=None)) finally: - if dask_client is not None: + if backend_client is not None: logger.info("Shutting down Dask client...") - dask_client.close(timeout="60s") + backend_client.close(timeout="60s") logger.info("Dask client shutdown complete") diff --git a/tests/test_curator.py b/tests/test_curator.py index a7103cba..697d547c 100644 --- a/tests/test_curator.py +++ b/tests/test_curator.py @@ -23,6 +23,7 @@ from ray_curator.tasks import Task, _EmptyTask from sdp.run_processors import run_processors +from sdp.utils.common import load_manifest def _write_config(file_path: Path, dict_conf): @@ -41,6 +42,9 @@ def _make_dict(output_manifest_file, use_backend=None): "raw_data_dir": workspace_dir, "extension": "mp3", "output_file_key": "audio_filepath", + }, + { + "_target_": "sdp.processors.SaveJsonl", "output_manifest_file": output_manifest_file, }, ], @@ -49,12 +53,13 @@ def _make_dict(output_manifest_file, use_backend=None): def _make_expected_output(): workspace_dir = os.path.join(os.getenv('TEST_DATA_ROOT'), "armenian/audio_books/mp3") - return {'audio_filepath': os.path.join(workspace_dir, "Eleonora/Eleonora30s.mp3")} + return [{'audio_filepath': os.path.join(workspace_dir, "Eleonora/Eleonora30s.mp3")}] def test_curator(): with tempfile.TemporaryDirectory() as tmpdir: - output_path = os.path.join(tmpdir, "output_manifest_file.jsonl") + # output_path = os.path.join(tmpdir, "output_manifest_file.jsonl") + output_path = "/tmp/output_manifest_file.jsonl" dict_conf = _make_dict(output_manifest_file=output_path, use_backend="curator") conf_path = Path(tmpdir) / "config.yaml" _write_config(conf_path, dict_conf) @@ -62,8 +67,8 @@ def test_curator(): cfg = OmegaConf.load(conf_path) run_processors(cfg) - with open(output_path, "r") as f: - output = json.load(f) + + output = load_manifest(output_path) expected_output = _make_expected_output() assert output == expected_output, f"Expected {expected_output}, but got {output}" From 7554d9933aca180b1bdbc828ba6e7a7b89392119 Mon Sep 17 00:00:00 2001 From: Nikolay Karpov Date: Tue, 22 Jul 2025 10:55:19 -0700 Subject: [PATCH 07/13] BaseProcessor inherit ProcessingStage Signed-off-by: Nikolay Karpov --- sdp/processors/base_processor.py | 4 ++-- .../modify_manifest/create_manifest.py | 16 ++++------------ tests/test_curator.py | 15 +++++++-------- 3 files changed, 13 insertions(+), 22 deletions(-) diff --git a/sdp/processors/base_processor.py b/sdp/processors/base_processor.py index daab9e85..2f87d093 100644 --- a/sdp/processors/base_processor.py +++ b/sdp/processors/base_processor.py @@ -45,7 +45,7 @@ def validate(self) -> bool: return True -class BaseProcessor(ABC): +class BaseProcessor(ProcessingStage[Task, Task]): """Abstract class for SDP processors. All processor classes inherit from the ``BaseProcessor`` class. @@ -102,7 +102,7 @@ def name(self) -> str: return "BaseProcessor" -class BaseParallelProcessor(BaseProcessor, ProcessingStage[Task, Task]): +class BaseParallelProcessor(BaseProcessor): """ A processor that performs per-entry processing in parallel (using Dask or multiprocessing). diff --git a/sdp/processors/modify_manifest/create_manifest.py b/sdp/processors/modify_manifest/create_manifest.py index 323bdac0..02be56b3 100644 --- a/sdp/processors/modify_manifest/create_manifest.py +++ b/sdp/processors/modify_manifest/create_manifest.py @@ -26,30 +26,22 @@ ) -class SaveJsonl(BaseProcessor, ProcessingStage[Task, Task]): +class SaveJsonl(BaseProcessor): """ - Processor for creating an initial dataset manifest by saving filepaths with a common extension to the field specified in output_field. + Processor for saving tasks as a one JSONL file. Args: - raw_data_dir (str): The root directory of the files to be added to the initial manifest. This processor will recursively look for files with the extension 'extension' inside this directory. - output_file_key (str): The key to store the paths to the files in the dataset. - extension (str): The file extension of the of the files to be added to the manifest. - **kwargs: Additional keyword arguments to be passed to the base class `BaseParallelProcessor`. + **kwargs: Additional keyword arguments to be passed to the base class `BaseProcessor`. """ - name: str = "SaveManifest" - resources: Resources = Resources(cpus=1.0, gpu_memory_gb=10.0) - batch_size: int = 100000 - def __init__( self, **kwargs, ): super().__init__(**kwargs) - def setup(self, a): - # Path(self.output_manifest_file).touch() + def setup_on_node(self, _, __): open(self.output_manifest_file, 'w').close() def process(self, tasks: DataEntry) -> DataEntry: diff --git a/tests/test_curator.py b/tests/test_curator.py index 697d547c..adb0ecda 100644 --- a/tests/test_curator.py +++ b/tests/test_curator.py @@ -57,18 +57,17 @@ def _make_expected_output(): def test_curator(): - with tempfile.TemporaryDirectory() as tmpdir: - # output_path = os.path.join(tmpdir, "output_manifest_file.jsonl") - output_path = "/tmp/output_manifest_file.jsonl" - dict_conf = _make_dict(output_manifest_file=output_path, use_backend="curator") - conf_path = Path(tmpdir) / "config.yaml" - _write_config(conf_path, dict_conf) + tmpdir = tempfile.TemporaryDirectory() + output_path = os.path.join(tmpdir.name, "output_manifest_file.jsonl") + dict_conf = _make_dict(output_manifest_file=output_path, use_backend="curator") + conf_path = Path(tmpdir.name) / "config.yaml" + _write_config(conf_path, dict_conf) - cfg = OmegaConf.load(conf_path) + cfg = OmegaConf.load(conf_path) run_processors(cfg) output = load_manifest(output_path) - + tmpdir.cleanup() expected_output = _make_expected_output() assert output == expected_output, f"Expected {expected_output}, but got {output}" From 941abec82814f56eeb3fadcdc127bc70a8c478ac Mon Sep 17 00:00:00 2001 From: Nikolay Karpov Date: Tue, 22 Jul 2025 14:04:42 -0700 Subject: [PATCH 08/13] save if output_manifest_file not None Signed-off-by: Nikolay Karpov --- sdp/processors/base_processor.py | 11 ++++++++-- .../modify_manifest/create_manifest.py | 4 ++++ sdp/run_processors.py | 15 ++++++-------- tests/test_curator.py | 20 ++++++++++++++++--- 4 files changed, 36 insertions(+), 14 deletions(-) diff --git a/sdp/processors/base_processor.py b/sdp/processors/base_processor.py index 2f87d093..de3bdc7a 100644 --- a/sdp/processors/base_processor.py +++ b/sdp/processors/base_processor.py @@ -67,7 +67,9 @@ class BaseProcessor(ProcessingStage[Task, Task]): as ``input_manifest_file``. """ - def __init__(self, output_manifest_file: str, input_manifest_file: Optional[str] = None, **kwargs): + def __init__( + self, output_manifest_file: Optional[str] = None, input_manifest_file: Optional[str] = None, **kwargs + ): if output_manifest_file and input_manifest_file and (output_manifest_file == input_manifest_file): # we cannot have the same input and output manifest file specified because we need to be able to # read from the input_manifest_file and write to the output_manifest_file at the same time @@ -77,7 +79,7 @@ def __init__(self, output_manifest_file: str, input_manifest_file: Optional[str] self.input_manifest_file = input_manifest_file @abstractmethod - def process(self): + def process(self, tasks: Task) -> Task: """Should be overriden by the child classes to implement some data processing.""" pass @@ -217,6 +219,8 @@ def _process_with_dask(self, metrics): logger.info(f"Processed {total_entries} entries using Dask.") def _process_with_ray(self, metrics): + if self.output_manifest_file: + fout = open(self.output_manifest_file, "wt", encoding="utf8") tasks = [] for manifest_chunk in self._chunk_manifest(): for row in manifest_chunk: @@ -225,6 +229,9 @@ def _process_with_ray(self, metrics): metrics.append(data_entry.metrics) if data_entry.data is None: continue + if self.output_manifest_file: + json.dump(data_entry.data, fout, ensure_ascii=False) + fout.write("\n") self.number_of_entries += 1 self.total_duration += data_entry.data.get("duration", 0) tasks.extend(data) diff --git a/sdp/processors/modify_manifest/create_manifest.py b/sdp/processors/modify_manifest/create_manifest.py index 02be56b3..a56ce4ec 100644 --- a/sdp/processors/modify_manifest/create_manifest.py +++ b/sdp/processors/modify_manifest/create_manifest.py @@ -74,6 +74,10 @@ def __init__( self.output_file_key = output_file_key self.extension = extension + def setup_on_node(self, _, __): + if self.output_manifest_file: + open(self.output_manifest_file, 'w').close() + def read_manifest(self): # Get all files with the specified extension files = list(self.raw_data_dir.rglob('*.' + self.extension)) diff --git a/sdp/run_processors.py b/sdp/run_processors.py index 5496e33c..46707344 100644 --- a/sdp/run_processors.py +++ b/sdp/run_processors.py @@ -24,7 +24,7 @@ from omegaconf import OmegaConf, open_dict from ray_curator.backends.xenna import XennaExecutor from ray_curator.pipeline import Pipeline -from ray_curator.tasks import _EmptyTask +from ray_curator.tasks import EmptyTask, _EmptyTask from sdp.logging import logger from sdp.utils.import_manager import ImportManager @@ -150,8 +150,6 @@ def run_processors(cfg): try: import ray - ray.init() - ray_available = True except ImportError: logger.warning("Dask not installed; using multiprocessing for all processors") @@ -236,8 +234,8 @@ def run_processors(cfg): if any(p.use_backend for p in processors): try: num_cpus = psutil.cpu_count(logical=False) or 4 - logger.info(f"Starting Dask client with {num_cpus} workers") - backend_client = Client(n_workers=num_cpus, processes=True) + logger.info(f"Starting Ray client with {num_cpus} workers") + backend_client = ray.init() # Client(n_workers=num_cpus, processes=True) logger.info(f"Dask dashboard at: {backend_client.dashboard_link}") except Exception as e: logger.warning(f"Failed to start Dask client: {e}") @@ -252,17 +250,16 @@ def run_processors(cfg): pipeline.add_stage(stage) executor = XennaExecutor() - results = pipeline.run(executor) - # raise ValueError("results" + results) + pipeline.run(executor) else: + t = EmptyTask for proc in processors: if proc.use_backend == "dask" and backend_client is not None: proc.backend_client = backend_client logger.info('=> Running processor "%s" with Dask', proc) else: logger.info('=> Running processor "%s" with Multiprocessing', proc) - # raise ValueError("_EmptyTask_EmptyTask") - proc.process(_EmptyTask(task_id="empty", dataset_name="empty", data=None)) + t = proc.process(t) finally: if backend_client is not None: logger.info("Shutting down Dask client...") diff --git a/tests/test_curator.py b/tests/test_curator.py index adb0ecda..6b99a61e 100644 --- a/tests/test_curator.py +++ b/tests/test_curator.py @@ -42,9 +42,6 @@ def _make_dict(output_manifest_file, use_backend=None): "raw_data_dir": workspace_dir, "extension": "mp3", "output_file_key": "audio_filepath", - }, - { - "_target_": "sdp.processors.SaveJsonl", "output_manifest_file": output_manifest_file, }, ], @@ -71,3 +68,20 @@ def test_curator(): tmpdir.cleanup() expected_output = _make_expected_output() assert output == expected_output, f"Expected {expected_output}, but got {output}" + + +def test_multiprocessing(): + tmpdir = tempfile.TemporaryDirectory() + output_path = os.path.join(tmpdir.name, "output_manifest_file.jsonl") + dict_conf = _make_dict(output_manifest_file=output_path, use_backend=None) + conf_path = Path(tmpdir.name) / "config.yaml" + _write_config(conf_path, dict_conf) + + cfg = OmegaConf.load(conf_path) + + run_processors(cfg) + + output = load_manifest(output_path) + tmpdir.cleanup() + expected_output = _make_expected_output() + assert output == expected_output, f"Expected {expected_output}, but got {output}" From c32e77f4f75601832f1ac320dc14fa152aa7c328 Mon Sep 17 00:00:00 2001 From: Nikolay Karpov Date: Tue, 22 Jul 2025 14:30:30 -0700 Subject: [PATCH 09/13] dataset_name and task_id optional Signed-off-by: Nikolay Karpov --- sdp/processors/base_processor.py | 5 ++++- sdp/processors/modify_manifest/create_manifest.py | 1 - 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/sdp/processors/base_processor.py b/sdp/processors/base_processor.py index de3bdc7a..975349f8 100644 --- a/sdp/processors/base_processor.py +++ b/sdp/processors/base_processor.py @@ -35,7 +35,10 @@ class DataEntry(Task[list]): """A wrapper for data entry + any additional metrics.""" data: Optional[Dict] # can be None to drop the entry - metrics: Any = None + + def __init__(self, metrics: Any = None, dataset_name: str = "", task_id: int = 0, **kwargs): + self.metrics = metrics + super().__init__(task_id=task_id, dataset_name=dataset_name, **kwargs) @property def num_items(self) -> int: diff --git a/sdp/processors/modify_manifest/create_manifest.py b/sdp/processors/modify_manifest/create_manifest.py index a56ce4ec..c1905b96 100644 --- a/sdp/processors/modify_manifest/create_manifest.py +++ b/sdp/processors/modify_manifest/create_manifest.py @@ -89,7 +89,6 @@ def process_dataset_entry(self, data_entry): return [ DataEntry( data=data, - task_id=0, dataset_name=str(self.raw_data_dir / "*.") + self.extension, ) ] From 83262deb4e975652780665e2b1b1224cefc278bf Mon Sep 17 00:00:00 2001 From: Nikolay Karpov Date: Tue, 22 Jul 2025 17:53:51 -0700 Subject: [PATCH 10/13] fix CreateInitialManifestUzbekvoice Signed-off-by: Nikolay Karpov --- .../uzbekvoice/create_initial_manifest.py | 17 ++++++++--------- tests/test_curator.py | 2 +- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/sdp/processors/datasets/uzbekvoice/create_initial_manifest.py b/sdp/processors/datasets/uzbekvoice/create_initial_manifest.py index 78949dce..96ecb787 100644 --- a/sdp/processors/datasets/uzbekvoice/create_initial_manifest.py +++ b/sdp/processors/datasets/uzbekvoice/create_initial_manifest.py @@ -18,11 +18,12 @@ import typing import gdown -from ray_curator.tasks import DocumentBatch, EmptyTask, _EmptyTask +import pandas as pd +from ray_curator.tasks import DocumentBatch, EmptyTask, Task, _EmptyTask from sdp.logging import logger -from sdp.processors.base_processor import BaseProcessor -from sdp.utils.common import extract_archive +from sdp.processors.base_processor import BaseProcessor, DataEntry +from sdp.utils.common import extract_archive, load_manifest, save_manifest class CreateInitialManifestUzbekvoice(BaseProcessor): @@ -107,13 +108,11 @@ def process_transcript(self, file_path: str) -> list[dict[str, typing.Any]]: def process_data(self, data_folder: str, manifest_file: str) -> None: entries = self.process_transcript(os.path.join(data_folder, "uzbekvoice-dataset", "voice_dataset.json")) - if self.use_backend is None or self.use_backend == "dask": - with open(manifest_file, "w", encoding="utf-8") as fout: - for m in entries: - fout.write(json.dumps(m, ensure_ascii=False) + "\n") + if self.output_manifest_file: + save_manifest(entries, manifest_file) return entries - def process(self, task: _EmptyTask) -> DocumentBatch: + def process(self, _: Task) -> DataEntry: self.download_extract_files(self.raw_data_dir) entries = self.process_data(self.raw_data_dir, self.output_manifest_file) - return DocumentBatch(entries) + return entries diff --git a/tests/test_curator.py b/tests/test_curator.py index 6b99a61e..2b0685cb 100644 --- a/tests/test_curator.py +++ b/tests/test_curator.py @@ -23,7 +23,7 @@ from ray_curator.tasks import Task, _EmptyTask from sdp.run_processors import run_processors -from sdp.utils.common import load_manifest +from sdp.utils.common import load_manifest, save_manifest def _write_config(file_path: Path, dict_conf): From e1d2f0ae1f44251ebf56e6e649b3b39c48ed0c27 Mon Sep 17 00:00:00 2001 From: Nikolay Karpov Date: Fri, 25 Jul 2025 16:25:04 -0700 Subject: [PATCH 11/13] comment fix Signed-off-by: Nikolay Karpov --- requirements/curator.txt | 4 +- sdp/processors/nemo/asr_inference.py | 6 +- sdp/processors/nemo/transcribe_speech.py | 286 +++++++++++++++-------- 3 files changed, 197 insertions(+), 99 deletions(-) diff --git a/requirements/curator.txt b/requirements/curator.txt index d55c4609..33b874c9 100644 --- a/requirements/curator.txt +++ b/requirements/curator.txt @@ -3,6 +3,8 @@ cd ray-api # pip install cosmos-xenna[gpu] git clone https://github.com/NVIDIA-NeMo/Curator.git git switch ray-api +pip install . + # install NeMo pip install "nemo_toolkit[all]" @@ -11,7 +13,7 @@ pip install nemo_text_processing pip install -r requirements/main.txt pip install -r requirements/tests.txt -pip install . + RAY_ADDRESS=10.110.41.40:8265 python -m pytest tests/test_curator.py # pip install loguru diff --git a/sdp/processors/nemo/asr_inference.py b/sdp/processors/nemo/asr_inference.py index 634674c2..d5672191 100644 --- a/sdp/processors/nemo/asr_inference.py +++ b/sdp/processors/nemo/asr_inference.py @@ -53,7 +53,7 @@ def __init__( super().__init__(**kwargs) self.script_path = Path(__file__).parents[1] / "nemo" / "transcribe_speech.py" self.pretrained_model = pretrained_model - self.batch_size = batch_size + self.batch_size_asr = batch_size def process(self, task: _EmptyTask) -> _EmptyTask: """This will add "pred_text" key into the output manifest.""" @@ -64,7 +64,7 @@ def process(self, task: _EmptyTask) -> _EmptyTask: f"model_path={self.pretrained_model} " f"dataset_manifest={self.input_manifest_file} " f"output_filename={self.output_manifest_file} " - f"batch_size={self.batch_size} ", + f"batch_size={self.batch_size_asr} ", shell=True, check=True, ) @@ -74,7 +74,7 @@ def process(self, task: _EmptyTask) -> _EmptyTask: f"pretrained_name={self.pretrained_model} " f"dataset_manifest={self.input_manifest_file} " f"output_filename={self.output_manifest_file} " - f"batch_size={self.batch_size} ", + f"batch_size={self.batch_size_asr} ", shell=True, check=True, ) diff --git a/sdp/processors/nemo/transcribe_speech.py b/sdp/processors/nemo/transcribe_speech.py index bb04047b..05803923 100644 --- a/sdp/processors/nemo/transcribe_speech.py +++ b/sdp/processors/nemo/transcribe_speech.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,34 +12,43 @@ # See the License for the specific language governing permissions and # limitations under the License. -# This file is copied over from https://github.com/NVIDIA/NeMo/blob/v1.23.0/examples/asr/transcribe_speech.py. -# It is currently only compatible with NeMo v1.23.0. To use a different version of NeMo, please modify the file. +# This file is copied over from https://github.com/NVIDIA/NeMo/blob/r2.4.0/examples/asr/transcribe_speech.py. +# It is currently only compatible with NeMo r2.4.0. To use a different version of NeMo, please modify the file. -import contextlib +import json import os -from dataclasses import dataclass, is_dataclass +from dataclasses import dataclass, field, is_dataclass from typing import List, Optional, Union -import pytorch_lightning as pl +import lightning.pytorch as pl +import numpy as np import torch -from omegaconf import OmegaConf, open_dict - -from nemo.collections.asr.models import EncDecCTCModel, EncDecHybridRNNTCTCModel, EncDecMultiTaskModel +from nemo.collections.asr.models import ( + EncDecCTCModel, + EncDecHybridRNNTCTCModel, + EncDecRNNTModel, +) +from nemo.collections.asr.models.aed_multitask_models import parse_multitask_prompt from nemo.collections.asr.modules.conformer_encoder import ConformerChangeConfig from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecodingConfig -from nemo.collections.asr.parts.submodules.multitask_decoding import MultiTaskDecoding, MultiTaskDecodingConfig +from nemo.collections.asr.parts.submodules.multitask_decoding import ( + MultiTaskDecoding, + MultiTaskDecodingConfig, +) from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig from nemo.collections.asr.parts.utils.eval_utils import cal_write_wer from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis from nemo.collections.asr.parts.utils.transcribe_utils import ( compute_output_filename, prepare_audio_data, + restore_transcription_order, setup_model, - transcribe_partial_audio, write_transcription, ) from nemo.core.config import hydra_runner from nemo.utils import logging +from nemo.utils.timers import SimpleTimer +from omegaconf import OmegaConf, open_dict """ Transcribe audio file on a single CPU/GPU. Useful for transcription of moderate amounts of audio data. @@ -48,21 +57,17 @@ model_path: path to .nemo ASR checkpoint pretrained_name: name of pretrained ASR model (from NGC registry) audio_dir: path to directory with audio files - dataset_manifest: path to dataset JSON manifest file (in NeMo format) - - compute_timestamps: Bool to request greedy time stamp information (if the model supports it) + dataset_manifest: path to dataset JSON manifest file (in NeMo formats compute_langs: Bool to request language ID information (if the model supports it) + timestamps: Bool to request greedy time stamp information (if the model supports it) by default None (Optionally: You can limit the type of timestamp computations using below overrides) - ctc_decoding.ctc_timestamp_type="all" # (default all, can be [all, char, word]) - rnnt_decoding.rnnt_timestamp_type="all" # (default all, can be [all, char, word]) - - (Optionally: You can limit the type of timestamp computations using below overrides) - ctc_decoding.ctc_timestamp_type="all" # (default all, can be [all, char, word]) - rnnt_decoding.rnnt_timestamp_type="all" # (default all, can be [all, char, word]) + ctc_decoding.ctc_timestamp_type="all" # (default all, can be [all, char, word, segment]) + rnnt_decoding.rnnt_timestamp_type="all" # (default all, can be [all, char, word, segment]) output_filename: Output filename where the transcriptions will be written batch_size: batch size during inference + presort_manifest: sorts the provided manifest by audio length for faster inference (default: True) cuda: Optional int to enable or disable execution of model on certain CUDA device. allow_mps: Bool to allow using MPS (Apple Silicon M-series GPU) device if available @@ -79,6 +84,8 @@ langid: Str used for convert_num_to_words during groundtruth cleaning use_cer: Bool to use Character Error Rate (CER) or Word Error Rate (WER) + calculate_rtfx: Bool to calculate the RTFx throughput to transcribe the input dataset. + # Usage ASR model can be specified by either "model_path" or "pretrained_name". Data for transcription can be defined with either "audio_dir" or "dataset_manifest". @@ -95,7 +102,7 @@ clean_groundtruth_text=True \ langid='en' \ batch_size=32 \ - compute_timestamps=False \ + timestamps=False \ compute_langs=False \ cuda=0 \ amp=True \ @@ -106,13 +113,19 @@ @dataclass class ModelChangeConfig: + """ + Sub-config for changes specific to the Conformer Encoder + """ - # Sub-config for changes specific to the Conformer Encoder - conformer: ConformerChangeConfig = ConformerChangeConfig() + conformer: ConformerChangeConfig = field(default_factory=ConformerChangeConfig) @dataclass class TranscriptionConfig: + """ + Transcription Configuration for audio to text transcription. + """ + # Required configs model_path: Optional[str] = None # Path to a .nemo file pretrained_name: Optional[str] = None # Name of a pretrained model @@ -123,6 +136,7 @@ class TranscriptionConfig: ] = None # Used to select a single channel from multichannel audio, or use average across channels audio_key: str = 'audio_filepath' # Used to override the default audio key in dataset_manifest eval_config_yaml: Optional[str] = None # Path to a yaml file of config of evaluation + presort_manifest: bool = True # Significant inference speedup on short-form data due to padding reduction # General configs output_filename: Optional[str] = None @@ -132,10 +146,11 @@ class TranscriptionConfig: pred_name_postfix: Optional[str] = None # If you need to use another model name, rather than standard one. random_seed: Optional[int] = None # seed number going to be used in seed_everything() - # Set to True to output greedy timestamp information (only supported models) - compute_timestamps: bool = False - # set to True if need to return full alignment information - preserve_alignment: bool = False + # Set to True to output greedy timestamp information (only supported models) and returns full alignment hypotheses + timestamps: Optional[bool] = None + + # Set to True to return hypotheses instead of text from the transcribe function + return_hypotheses: bool = False # Set to True to output language ID information compute_langs: bool = False @@ -147,19 +162,33 @@ class TranscriptionConfig: allow_mps: bool = False # allow to select MPS device (Apple Silicon M-series GPU) amp: bool = False amp_dtype: str = "float16" # can be set to "float16" or "bfloat16" when using amp + compute_dtype: Optional[ + str + ] = None # "float32", "bfloat16" or "float16"; if None (default): bfloat16 if available else float32 + matmul_precision: str = "high" # Literal["highest", "high", "medium"] audio_type: str = "wav" # Recompute model transcription, even if the output folder exists with scores. overwrite_transcripts: bool = True # Decoding strategy for CTC models - ctc_decoding: CTCDecodingConfig = CTCDecodingConfig() + ctc_decoding: CTCDecodingConfig = field(default_factory=CTCDecodingConfig) # Decoding strategy for RNNT models - rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig(fused_batch_size=-1) + # enable CUDA graphs for transcription + rnnt_decoding: RNNTDecodingConfig = field(default_factory=lambda: RNNTDecodingConfig(fused_batch_size=-1)) # Decoding strategy for AED models - multitask_decoding: MultiTaskDecodingConfig = MultiTaskDecodingConfig() + multitask_decoding: MultiTaskDecodingConfig = field(default_factory=MultiTaskDecodingConfig) + # Prompt slots for prompted models, e.g. Canary-1B. Examples of acceptable prompt inputs: + # Implicit single-turn assuming default role='user' (works with Canary-1B) + # +prompt.source_lang=en +prompt.target_lang=es +prompt.task=asr +prompt.pnc=yes + # Explicit single-turn prompt: + # +prompt.role=user +prompt.slots.source_lang=en +prompt.slots.target_lang=es + # +prompt.slots.task=s2t_translation +prompt.slots.pnc=yes + # Explicit multi-turn prompt: + # +prompt.turns='[{role:user,slots:{source_lang:en,target_lang:es,task:asr,pnc:yes}}]' + prompt: dict = field(default_factory=dict) # decoder type: ctc or rnnt, can be used to switch between CTC and RNNT decoder for Hybrid RNNT/CTC models decoder_type: Optional[str] = None @@ -167,7 +196,7 @@ class TranscriptionConfig: att_context_size: Optional[list] = None # Use this for model-specific changes before transcription - model_change: ModelChangeConfig = ModelChangeConfig() + model_change: ModelChangeConfig = field(default_factory=ModelChangeConfig) # Config for word / character error rate calculation calculate_wer: bool = True @@ -179,20 +208,22 @@ class TranscriptionConfig: # if True, will also skip writing anything to the output file return_transcriptions: bool = False - # Set to False to return text instead of hypotheses from the transcribe function, so as to save memory - return_hypotheses: bool = True - # key for groundtruth text in manifest gt_text_attr_name: str = "text" + gt_lang_attr_name: str = "lang" + + extract_nbest: bool = False # Extract n-best hypotheses from the model - # Use model's transcribe() function instead of transcribe_partial_audio() by default - # Only use transcribe_partial_audio() when the audio is too long to fit in memory - # Your manifest input should have `offset` field to use transcribe_partial_audio() - allow_partial_transcribe: bool = False + calculate_rtfx: bool = False + warmup_steps: int = 0 # by default - no warmup + run_steps: int = 1 # by default - single run @hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig) def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis]]: + """ + Transcribes the input audio and can be used to infer with Encoder-Decoder models. + """ logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') for key in cfg: @@ -217,6 +248,7 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis logging.info(f"Will apply on-the-fly augmentation on samples during transcription: {augmentor} ") # setup GPU + torch.set_float32_matmul_precision(cfg.matmul_precision) if cfg.cuda is None: if torch.cuda.is_available(): device = [0] # use 0th CUDA device @@ -247,11 +279,29 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis asr_model.set_trainer(trainer) asr_model = asr_model.eval() + if (cfg.compute_dtype is not None and cfg.compute_dtype != "float32") and cfg.amp: + raise ValueError("amp=true is mutually exclusive with a compute_dtype other than float32") + + amp_dtype = torch.float16 if cfg.amp_dtype == "float16" else torch.bfloat16 + + compute_dtype: torch.dtype + if cfg.compute_dtype is None: + can_use_bfloat16 = (not cfg.amp) and map_location.type == "cuda" and torch.cuda.is_bf16_supported() + if can_use_bfloat16: + compute_dtype = torch.bfloat16 + else: + compute_dtype = torch.float32 + else: + assert cfg.compute_dtype in {"float32", "bfloat16", "float16"} + compute_dtype = getattr(torch, cfg.compute_dtype) + + asr_model.to(compute_dtype) + # we will adjust this flag if the model does not support it - compute_timestamps = cfg.compute_timestamps compute_langs = cfg.compute_langs - # has to be True if timestamps are required - preserve_alignment = True if cfg.compute_timestamps else cfg.preserve_alignment + + if cfg.timestamps: + cfg.return_hypotheses = True # Check whether model and decoder type match if isinstance(asr_model, EncDecCTCModel): @@ -260,7 +310,7 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis elif isinstance(asr_model, EncDecHybridRNNTCTCModel): if cfg.decoder_type and cfg.decoder_type not in ['ctc', 'rnnt']: raise ValueError('Hybrid model only support ctc or rnnt decoding!') - else: # rnnt model, there could be other models needs to be addressed. + elif isinstance(asr_model, EncDecRNNTModel): if cfg.decoder_type and cfg.decoder_type != 'rnnt': raise ValueError('RNNT model only support rnnt decoding!') @@ -271,7 +321,9 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis if hasattr(asr_model, 'change_decoding_strategy') and hasattr(asr_model, 'decoding'): if isinstance(asr_model.decoding, MultiTaskDecoding): cfg.multitask_decoding.compute_langs = cfg.compute_langs - cfg.multitask_decoding.preserve_alignments = cfg.preserve_alignment + if cfg.extract_nbest: + cfg.multitask_decoding.beam.return_best_hypothesis = False + cfg.return_hypotheses = True asr_model.change_decoding_strategy(cfg.multitask_decoding) elif cfg.decoder_type is not None: # TODO: Support compute_langs in CTC eventually @@ -279,9 +331,9 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis raise ValueError("CTC models do not support `compute_langs` at the moment") decoding_cfg = cfg.rnnt_decoding if cfg.decoder_type == 'rnnt' else cfg.ctc_decoding - decoding_cfg.compute_timestamps = cfg.compute_timestamps # both ctc and rnnt support it - if 'preserve_alignments' in decoding_cfg: - decoding_cfg.preserve_alignments = preserve_alignment + if cfg.extract_nbest: + decoding_cfg.beam.return_best_hypothesis = False + cfg.return_hypotheses = True if 'compute_langs' in decoding_cfg: decoding_cfg.compute_langs = cfg.compute_langs if hasattr(asr_model, 'cur_decoder'): @@ -291,17 +343,19 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis # Check if ctc or rnnt model elif hasattr(asr_model, 'joint'): # RNNT model + if cfg.extract_nbest: + cfg.rnnt_decoding.beam.return_best_hypothesis = False + cfg.return_hypotheses = True cfg.rnnt_decoding.fused_batch_size = -1 - cfg.rnnt_decoding.compute_timestamps = cfg.compute_timestamps cfg.rnnt_decoding.compute_langs = cfg.compute_langs - if 'preserve_alignments' in cfg.rnnt_decoding: - cfg.rnnt_decoding.preserve_alignments = preserve_alignment asr_model.change_decoding_strategy(cfg.rnnt_decoding) else: if cfg.compute_langs: raise ValueError("CTC models do not support `compute_langs` at the moment.") - cfg.ctc_decoding.compute_timestamps = cfg.compute_timestamps + if cfg.extract_nbest: + cfg.ctc_decoding.beam.return_best_hypothesis = False + cfg.return_hypotheses = True asr_model.change_decoding_strategy(cfg.ctc_decoding) @@ -311,31 +365,16 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis isinstance(asr_model, EncDecHybridRNNTCTCModel) and cfg.decoder_type == "ctc" ): cfg.decoding = cfg.ctc_decoding + elif isinstance(asr_model.decoding, MultiTaskDecoding): + cfg.decoding = cfg.multitask_decoding else: cfg.decoding = cfg.rnnt_decoding - if isinstance(asr_model, EncDecMultiTaskModel): - # Special case for EncDecMultiTaskModel, where the input manifest is directly passed into the model's transcribe() function - partial_audio = False - filepaths = cfg.dataset_manifest - assert cfg.dataset_manifest is not None - else: - # prepare audio filepaths and decide wether it's partial audio - filepaths, partial_audio = prepare_audio_data(cfg) + filepaths, sorted_manifest_path = prepare_audio_data(cfg) - if not cfg.allow_partial_transcribe: - # by defatul, use model's transcribe() function, unless partial audio is required - partial_audio = False - - # setup AMP (optional) - if cfg.amp and torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'): - logging.info("AMP enabled!\n") - autocast = torch.cuda.amp.autocast - else: + remove_path_after_done = sorted_manifest_path if sorted_manifest_path is not None else None - @contextlib.contextmanager - def autocast(dtype=None): - yield + filepaths = sorted_manifest_path if sorted_manifest_path is not None else filepaths # Compute output filename cfg = compute_output_filename(cfg, model_name) @@ -350,37 +389,82 @@ def autocast(dtype=None): # transcribe audio - amp_dtype = torch.float16 if cfg.amp_dtype == "float16" else torch.bfloat16 + if cfg.calculate_rtfx: + total_duration = 0.0 + + with open(cfg.dataset_manifest, "rt") as fh: + for line in fh: + item = json.loads(line) + if "duration" not in item: + raise ValueError( + f"Requested calculate_rtfx=True, but line {line} in manifest {cfg.dataset_manifest} \ + lacks a 'duration' field." + ) + total_duration += item["duration"] + + if cfg.warmup_steps == 0: + logging.warning( + "RTFx measurement enabled, but warmup_steps=0. " + "At least one warmup step is recommended to measure RTFx" + ) - with autocast(dtype=amp_dtype): + timer = SimpleTimer() + model_measurements = [] + with torch.amp.autocast('cuda' if torch.cuda.is_available() else 'cpu', dtype=amp_dtype, enabled=cfg.amp): with torch.no_grad(): - if partial_audio: - transcriptions = transcribe_partial_audio( - asr_model=asr_model, - path2manifest=cfg.dataset_manifest, - batch_size=cfg.batch_size, - num_workers=cfg.num_workers, - return_hypotheses=cfg.return_hypotheses, - channel_selector=cfg.channel_selector, - augmentor=augmentor, - decoder_type=cfg.decoder_type, - ) - else: + override_cfg = asr_model.get_transcribe_config() + override_cfg.batch_size = cfg.batch_size + override_cfg.num_workers = cfg.num_workers + override_cfg.return_hypotheses = cfg.return_hypotheses + override_cfg.channel_selector = cfg.channel_selector + override_cfg.augmentor = augmentor + override_cfg.text_field = cfg.gt_text_attr_name + override_cfg.lang_field = cfg.gt_lang_attr_name + override_cfg.timestamps = cfg.timestamps + if hasattr(override_cfg, "prompt"): + override_cfg.prompt = parse_multitask_prompt(OmegaConf.to_container(cfg.prompt)) + + device = next(asr_model.parameters()).device + for run_step in range(cfg.warmup_steps + cfg.run_steps): + if run_step < cfg.warmup_steps: + logging.info(f"Running warmup step {run_step}") + # reset timer + timer.reset() + timer.start(device=device) + # call transcribe transcriptions = asr_model.transcribe( - paths2audio_files=filepaths, - batch_size=cfg.batch_size, - num_workers=cfg.num_workers, - return_hypotheses=cfg.return_hypotheses, - channel_selector=cfg.channel_selector, - augmentor=augmentor, + audio=filepaths, + override_config=override_cfg, + timestamps=cfg.timestamps, ) + # stop timer, log time + timer.stop(device=device) + logging.info(f"Model time for iteration {run_step}: {timer.total_sec():.3f}") + if run_step >= cfg.warmup_steps: + model_measurements.append(timer.total_sec()) + + model_measurements_np = np.asarray(model_measurements) + logging.info( + f"Model time avg: {model_measurements_np.mean():.3f}" + + (f" (std: {model_measurements_np.std():.3f})" if cfg.run_steps > 1 else "") + ) - logging.info(f"Finished transcribing {len(filepaths)} files !") + if cfg.dataset_manifest is not None: + logging.info(f"Finished transcribing from manifest file: {cfg.dataset_manifest}") + if cfg.presort_manifest: + transcriptions = restore_transcription_order(cfg.dataset_manifest, transcriptions) + else: + logging.info(f"Finished transcribing {len(filepaths)} files !") logging.info(f"Writing transcriptions into file: {cfg.output_filename}") - # if transcriptions form a tuple (from RNNT), extract just "best" hypothesis + # if transcriptions form a tuple of (best_hypotheses, all_hypotheses) if type(transcriptions) == tuple and len(transcriptions) == 2: - transcriptions = transcriptions[0] + if cfg.extract_nbest: + # extract all hypotheses if exists + transcriptions = transcriptions[1] + else: + # extract just best hypothesis + transcriptions = transcriptions[0] if cfg.return_transcriptions: return transcriptions @@ -392,10 +476,15 @@ def autocast(dtype=None): model_name, filepaths=filepaths, compute_langs=compute_langs, - compute_timestamps=compute_timestamps, + timestamps=cfg.timestamps, ) logging.info(f"Finished writing predictions to {output_filename}!") + # clean-up + if cfg.presort_manifest is not None: + if remove_path_after_done is not None: + os.unlink(remove_path_after_done) + if cfg.calculate_wer: output_manifest_w_wer, total_res, _ = cal_write_wer( pred_manifest=output_filename, @@ -410,8 +499,15 @@ def autocast(dtype=None): logging.info(f"Writing prediction and error rate of each sample to {output_manifest_w_wer}!") logging.info(f"{total_res}") + if cfg.calculate_rtfx: + rtfx_measurements = total_duration / model_measurements_np + logging.info( + f"Model RTFx on the dataset: {rtfx_measurements.mean():.3f}" + + (f" (std: {rtfx_measurements.std():.3f})" if cfg.run_steps > 1 else "") + ) + return cfg if __name__ == '__main__': - main() # noqa pylint: disable=no-value-for-parameter \ No newline at end of file + main() # noqa pylint: disable=no-value-for-parameter From c83483a6b6c7e450d1acd96d5fdd712529201765 Mon Sep 17 00:00:00 2001 From: Nikolay Karpov Date: Fri, 25 Jul 2025 18:33:27 -0700 Subject: [PATCH 12/13] add DataEntry Signed-off-by: Nikolay Karpov --- sdp/processors/base_processor.py | 8 +++--- sdp/processors/modify_manifest/common.py | 33 +++++++++++++++++++----- sdp/processors/nemo/lid_inference.py | 10 ++++--- 3 files changed, 38 insertions(+), 13 deletions(-) diff --git a/sdp/processors/base_processor.py b/sdp/processors/base_processor.py index 51cc974a..5cad8cea 100644 --- a/sdp/processors/base_processor.py +++ b/sdp/processors/base_processor.py @@ -33,11 +33,10 @@ class DataEntry(Task[list]): """A wrapper for data entry + any additional metrics.""" - data: Optional[Dict] # can be None to drop the entry - - def __init__(self, metrics: Any = None, dataset_name: str = "", task_id: int = 0, **kwargs): + def __init__(self, data: Dict = None, metrics: Any = None, dataset_name: str = "", task_id: int = 0, **kwargs): + self.data = data # data can be None to drop the entry self.metrics = metrics - super().__init__(task_id=task_id, dataset_name=dataset_name, **kwargs) + super().__init__(data=data, task_id=task_id, dataset_name=dataset_name, **kwargs) @property def num_items(self) -> int: @@ -240,6 +239,7 @@ def _process_with_ray(self, metrics): return tasks def _process_with_multiprocessing(self, metrics): + data = [] with open(self.output_manifest_file, "wt", encoding="utf8") as fout: for manifest_chunk in self._chunk_manifest(): data = itertools.chain( diff --git a/sdp/processors/modify_manifest/common.py b/sdp/processors/modify_manifest/common.py index 4b44e536..2b2febc0 100644 --- a/sdp/processors/modify_manifest/common.py +++ b/sdp/processors/modify_manifest/common.py @@ -69,7 +69,7 @@ def __init__( self.arg_separator = arg_separator self.cmd = cmd - def process(self): + def process(self, tasks: DataEntry) -> DataEntry: os.makedirs(os.path.dirname(self.output_manifest_file), exist_ok=True) if self.cmd.find(self.input_manifest_file) != -1 or self.cmd.find(self.output_manifest_file) != -1: logger.error( @@ -92,6 +92,7 @@ def process(self): if self.output_manifest_arg: process_args.extend([self.output_manifest_arg + self.arg_separator + self.output_manifest_file]) subprocess.run(" ".join(process_args), shell=True) + return tasks class CombineSources(BaseParallelProcessor): @@ -454,14 +455,34 @@ def __init__( self.right_manifest_file = right_manifest_file self.column_id = column_id - def process(self): - m1 = pd.DataFrame.from_records(load_manifest(Path(self.left_manifest_file))) - m2 = pd.DataFrame.from_records(load_manifest(Path(self.right_manifest_file))) + def process(self, tasks: DataEntry, tasks2: DataEntry = None) -> DataEntry: + if self.left_manifest_file: + m1 = pd.DataFrame.from_records(load_manifest(Path(self.left_manifest_file))) + elif tasks: + logger.warning("batch_size should be as big as the dataset size") + m1 = tasks.toDataFrame() + else: + raise ValueError("tasks or self.input_manifest_file or self.left_manifest_file must be not None") + + if self.right_manifest_file: + m2 = pd.DataFrame.from_records(load_manifest(Path(self.right_manifest_file))) + elif tasks2: + logger.warning("batch_size should be as big as the dataset size") + m2 = tasks.toDataFrame() + else: + raise ValueError("tasks2 or self.right_manifest_file must be not None") + m3 = pd.merge(m1, m2, on=self.column_id, how="inner") - with open(self.output_manifest_file, "wt", encoding="utf8") as fout: - for _, line in m3.iterrows(): + if self.output_manifest_file: + fout = open(self.output_manifest_file, "wt", encoding="utf8") + items = [] + for _, line in m3.iterrows(): + item = DataEntry(data=dict(line)) + if self.output_manifest_file: fout.write(json.dumps(dict(line), ensure_ascii=False) + "\n") + items.append(item) + return items class DropSpecifiedFields(BaseProcessor): diff --git a/sdp/processors/nemo/lid_inference.py b/sdp/processors/nemo/lid_inference.py index 0f1b871f..69bd9456 100644 --- a/sdp/processors/nemo/lid_inference.py +++ b/sdp/processors/nemo/lid_inference.py @@ -5,7 +5,7 @@ from tqdm import tqdm from sdp.logging import logger -from sdp.processors.base_processor import BaseProcessor +from sdp.processors.base_processor import BaseProcessor, DataEntry from sdp.utils.common import load_manifest @@ -45,7 +45,7 @@ def __init__( self.random_seed = random_seed self.device = device - def process(self): + def process(self, tasks: DataEntry) -> DataEntry: import nemo.collections.asr as nemo_asr import torch # importing after nemo to make sure users first install nemo, instead of torch, then nemo @@ -59,7 +59,10 @@ def process(self): else: model = model.to(self.device) - manifest = load_manifest(Path(self.input_manifest_file)) + if self.input_manifest_file: + manifest = load_manifest(Path(self.input_manifest_file)) + else: + manifest = tasks.data Path(self.output_manifest_file).parent.mkdir(exist_ok=True, parents=True) with Path(self.output_manifest_file).open('w') as f: @@ -75,3 +78,4 @@ def process(self): if lang: item[self.output_lang_key] = lang f.write(json.dumps(item, ensure_ascii=False) + '\n') + return tasks From a49eb1fca7b3a38601c213adf83284ea8403fbf2 Mon Sep 17 00:00:00 2001 From: Nikolay Karpov Date: Fri, 25 Jul 2025 18:41:38 -0700 Subject: [PATCH 13/13] input EmptyTask Signed-off-by: Nikolay Karpov --- tests/test_modify_manifest.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_modify_manifest.py b/tests/test_modify_manifest.py index 99583c26..3542b54e 100644 --- a/tests/test_modify_manifest.py +++ b/tests/test_modify_manifest.py @@ -18,6 +18,7 @@ from typing import Dict, List, Union import pytest +from ray_curator.tasks import EmptyTask from sdp.processors import ApplyInnerJoin, DropNonAlphabet @@ -161,7 +162,7 @@ def test_apply_inner_join( output_manifest_file=manifest_out, ) - processor.process() + processor.process(EmptyTask) with open(manifest_out, "r") as f: output_lines = [json.loads(line) for line in f]