From e6dec85b6a02d1b3fa6ceb896a45cdf563ff25ad Mon Sep 17 00:00:00 2001 From: Steve Han Date: Wed, 29 Apr 2026 12:10:58 -0400 Subject: [PATCH 1/6] feat(data-designer-retrieval-sdg): add retrieval SDG plugin Adds a new Data Designer plugin for retriever synthetic data generation. The plugin provides: - A retrieval-sdg-dedup column generator that deduplicates QA pairs by embedding cosine similarity, registered via data_designer.plugins. - A four-column SDG pipeline (artifact extraction, QA generation, dedup, quality evaluation) accessible as build_qa_generation_pipeline(). - Conversion utilities for exporting raw SDG output to NeMo Retriever training format (train.json, val.json), BEIR evaluation format, and a parquet corpus with merlin metadata. - A data-designer-retrieval-sdg CLI with generate and convert subcommands. Refreshes auto-derived metadata (docs/catalog.md, .github/CODEOWNERS) and adds data_designer_retrieval_sdg to root pyproject.toml known-first-party. Local CI (lint, isolated-venv test, validate, check) is green; 47 plugin tests pass. Signed-off-by: Steve Han Made-with: Cursor --- .github/CODEOWNERS | 1 + docs/catalog.md | 1 + .../data-designer-retrieval-sdg/CODEOWNERS | 3 + plugins/data-designer-retrieval-sdg/README.md | 96 ++ .../pyproject.toml | 43 + .../data_designer_retrieval_sdg/__init__.py | 33 + .../src/data_designer_retrieval_sdg/cli.py | 360 ++++++ .../src/data_designer_retrieval_sdg/config.py | 46 + .../data_designer_retrieval_sdg/convert.py | 1017 +++++++++++++++++ .../src/data_designer_retrieval_sdg/dedup.py | 127 ++ .../src/data_designer_retrieval_sdg/ingest.py | 724 ++++++++++++ .../src/data_designer_retrieval_sdg/models.py | 128 +++ .../data_designer_retrieval_sdg/pipeline.py | 349 ++++++ .../src/data_designer_retrieval_sdg/plugin.py | 12 + .../postprocess.py | 375 ++++++ .../data_designer_retrieval_sdg/prompts.py | 298 +++++ .../tests/test_convert.py | 182 +++ .../tests/test_dedup.py | 74 ++ .../tests/test_ingest.py | 145 +++ .../tests/test_models.py | 81 ++ .../tests/test_plugin.py | 10 + .../tests/test_postprocess.py | 87 ++ pyproject.toml | 2 +- uv.lock | 46 +- 24 files changed, 4238 insertions(+), 2 deletions(-) create mode 100644 plugins/data-designer-retrieval-sdg/CODEOWNERS create mode 100644 plugins/data-designer-retrieval-sdg/README.md create mode 100644 plugins/data-designer-retrieval-sdg/pyproject.toml create mode 100644 plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/__init__.py create mode 100644 plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/cli.py create mode 100644 plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/config.py create mode 100644 plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/convert.py create mode 100644 plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/dedup.py create mode 100644 plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/ingest.py create mode 100644 plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/models.py create mode 100644 plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/pipeline.py create mode 100644 plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/plugin.py create mode 100644 plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/postprocess.py create mode 100644 plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/prompts.py create mode 100644 plugins/data-designer-retrieval-sdg/tests/test_convert.py create mode 100644 
plugins/data-designer-retrieval-sdg/tests/test_dedup.py create mode 100644 plugins/data-designer-retrieval-sdg/tests/test_ingest.py create mode 100644 plugins/data-designer-retrieval-sdg/tests/test_models.py create mode 100644 plugins/data-designer-retrieval-sdg/tests/test_plugin.py create mode 100644 plugins/data-designer-retrieval-sdg/tests/test_postprocess.py diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index eb06565..135062f 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -7,4 +7,5 @@ /.github/ @NVIDIA-NeMo/data_designer_reviewers # Plugins +/plugins/data-designer-retrieval-sdg/ @NVIDIA-NeMo/data_designer_reviewers @shan-nvidia /plugins/data-designer-template/ @NVIDIA-NeMo/data_designer_reviewers diff --git a/docs/catalog.md b/docs/catalog.md index d3c1211..4d6d7b1 100644 --- a/docs/catalog.md +++ b/docs/catalog.md @@ -4,4 +4,5 @@ Auto-generated from plugin metadata. Do not edit manually. | Plugin | Version | Column Type | Description | |--------|---------|-------------|-------------| +| data-designer-retrieval-sdg | 0.1.0 | `retrieval-sdg-dedup` | Multi-step retriever SDG pipeline (artifact extraction, QA generation, dedup, evaluation) with Automodel-compatible data conversion; registers a retrieval-sdg-dedup column plugin | | data-designer-template | 0.1.0 | `text-transform` | Template Data Designer plugin — text transform column generator | diff --git a/plugins/data-designer-retrieval-sdg/CODEOWNERS b/plugins/data-designer-retrieval-sdg/CODEOWNERS new file mode 100644 index 0000000..3c525c6 --- /dev/null +++ b/plugins/data-designer-retrieval-sdg/CODEOWNERS @@ -0,0 +1,3 @@ +# Owner(s) of this plugin — used to generate the root CODEOWNERS file. +# GitHub accepts @username, @org/team, or email format. +* @NVIDIA-NeMo/data_designer_reviewers @shan-nvidia diff --git a/plugins/data-designer-retrieval-sdg/README.md b/plugins/data-designer-retrieval-sdg/README.md new file mode 100644 index 0000000..be20051 --- /dev/null +++ b/plugins/data-designer-retrieval-sdg/README.md @@ -0,0 +1,96 @@ +# data-designer-retrieval-sdg + +Data Designer plugin for **retriever synthetic data generation**. Generates +multi-hop QA pairs from text documents and converts them into training +formats compatible with [Automodel](https://github.com/NVIDIA-NeMo/Automodel) +retriever finetuning. + +## Features + +- **Retrieval-sdg-dedup column plugin** — embedding-based QA-pair deduplication + registered as a `data_designer.plugins` entry point. +- **Four-column SDG pipeline** — artifact extraction → QA generation → + deduplication → quality evaluation, all orchestrated via DataDesigner. +- **Data conversion** — convert raw SDG output to NeMo Retriever training + format (`train.json`, `val.json`), BEIR evaluation format, and corpus + parquet with `merlin_metadata.json`. +- **CLI** — `data-designer-retrieval-sdg generate` and + `data-designer-retrieval-sdg convert` subcommands. + +## Installation + +```bash +pip install data-designer-retrieval-sdg +``` + +Or, for development inside the monorepo: + +```bash +make sync # from the repo root +``` + +## Development setup + +When working inside the monorepo the CLI and library are installed into the +workspace virtual environment. 
Activate it before running commands: + +```bash +make sync # install all packages into .venv +source .venv/bin/activate # activate the virtual environment +``` + +Alternatively, prefix any command with `uv run` to execute inside the venv +without activating it: + +```bash +uv run data-designer-retrieval-sdg generate --help +``` + +## Quick start + +### Generate QA pairs + +```bash +data-designer-retrieval-sdg generate \ + --input-dir ./my_documents \ + --output-dir ./generated_output \ + --num-pairs 7 +``` + +### Convert to training format + +```bash +data-designer-retrieval-sdg convert ./generated_output \ + --corpus-id my_corpus +``` + +### Use as a library + +```python +from data_designer_retrieval_sdg import ( + build_qa_generation_pipeline, + load_text_files_from_directory, +) + +seed_df = load_text_files_from_directory(Path("./docs")) +config_builder = build_qa_generation_pipeline(seed_df) +``` + +## Plugin column type + +The package registers the `retrieval-sdg-dedup` column type. Use it in a +DataDesigner pipeline to deduplicate QA pairs by embedding cosine +similarity: + +```python +from data_designer_retrieval_sdg.config import RetrievalSdgDedupColumnConfig + +config_builder.add_column( + RetrievalSdgDedupColumnConfig( + name="deduplicated_qa_pairs", + qa_pairs_column="qa_generation", + embedding_alias="embed", + dedupe_similarity_threshold=0.9, + ) +) +``` diff --git a/plugins/data-designer-retrieval-sdg/pyproject.toml b/plugins/data-designer-retrieval-sdg/pyproject.toml new file mode 100644 index 0000000..42cf649 --- /dev/null +++ b/plugins/data-designer-retrieval-sdg/pyproject.toml @@ -0,0 +1,43 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +[project] +name = "data-designer-retrieval-sdg" +version = "0.1.0" +description = "Multi-step retriever SDG pipeline (artifact extraction, QA generation, dedup, evaluation) with Automodel-compatible data conversion; registers a retrieval-sdg-dedup column plugin" +requires-python = ">=3.10" +dependencies = [ + "data-designer>=0.5.7", + "nltk>=3.9.2", + "pyyaml>=6.0", + "pyarrow>=14.0", +] +license = "Apache-2.0" +readme = "README.md" +authors = [ + {name = "NVIDIA Corporation"}, +] +classifiers = [ + "Development Status :: 3 - Alpha", + "Programming Language :: Python :: 3", + "Topic :: Scientific/Engineering :: Artificial Intelligence", +] + +[project.entry-points."data_designer.plugins"] +retrieval-sdg-dedup = "data_designer_retrieval_sdg.plugin:plugin" + +[project.scripts] +data-designer-retrieval-sdg = "data_designer_retrieval_sdg.cli:main" + +[project.urls] +Repository = "https://github.com/NVIDIA-NeMo/DataDesignerPlugins" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/data_designer_retrieval_sdg"] + +[tool.ruff] +extend = "../../pyproject.toml" diff --git a/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/__init__.py b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/__init__.py new file mode 100644 index 0000000..c187c03 --- /dev/null +++ b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/__init__.py @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Data Designer plugin for retriever synthetic data generation. 
+ +Provides a multi-step pipeline that generates QA pairs from text documents +for retriever finetuning, plus utilities for converting raw SDG output into +Automodel-compatible training formats. + +Public API: + +- :func:`build_qa_generation_pipeline` -- build the four-column DD pipeline +- :func:`load_text_files_from_directory` -- load and chunk text files +- :func:`postprocess_retriever_data` -- flatten to BEIR format +- :func:`filter_qa_pairs_by_quality` -- quality-based filtering +- :func:`load_positive_docs_with_modality` -- load BEIR docs with modality +""" + +from data_designer_retrieval_sdg.ingest import load_text_files_from_directory +from data_designer_retrieval_sdg.pipeline import build_qa_generation_pipeline +from data_designer_retrieval_sdg.postprocess import ( + filter_qa_pairs_by_quality, + load_positive_docs_with_modality, + postprocess_retriever_data, +) + +__all__ = [ + "build_qa_generation_pipeline", + "filter_qa_pairs_by_quality", + "load_positive_docs_with_modality", + "load_text_files_from_directory", + "postprocess_retriever_data", +] diff --git a/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/cli.py b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/cli.py new file mode 100644 index 0000000..e298eed --- /dev/null +++ b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/cli.py @@ -0,0 +1,360 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""CLI entry points for the data-designer-retrieval-sdg package. + +Provides two subcommands: +- ``generate`` -- run the full SDG pipeline on a directory of text files +- ``convert`` -- convert raw SDG output to Automodel-compatible formats +""" + +from __future__ import annotations + +import argparse +import sys +import time +from pathlib import Path + +import data_designer.config as dd +from data_designer.interface import DataDesigner +from data_designer.logging import LoggerConfig, LoggingConfig, OutputConfig, configure_logging + +from data_designer_retrieval_sdg.convert import run_conversion +from data_designer_retrieval_sdg.ingest import load_text_files_from_directory +from data_designer_retrieval_sdg.pipeline import build_model_providers, build_qa_generation_pipeline + + +def _format_duration(seconds: float) -> str: + """Format a duration in seconds to a human-readable string.""" + seconds = max(0, int(seconds)) + if seconds < 60: + return f"{seconds}s" + minutes, secs = divmod(seconds, 60) + if minutes < 60: + return f"{minutes}m {secs}s" + hours, minutes = divmod(minutes, 60) + return f"{hours}h {minutes}m" + + +# --------------------------------------------------------------------------- +# ``generate`` subcommand +# --------------------------------------------------------------------------- + + +def _add_generate_parser(subparsers: argparse._SubParsersAction) -> None: + """Register the ``generate`` subcommand.""" + p = subparsers.add_parser( + "generate", + help="Generate synthetic QA pairs from a directory of text files", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + p.add_argument("--input-dir", type=Path, required=True, help="Directory containing text files") + p.add_argument("--output-dir", type=Path, required=True, help="Directory to save generated output") + p.add_argument("--min-text-length", type=int, default=50, help="Minimum document text length") + p.add_argument("--sentences-per-chunk", type=int, default=5, help="Sentences per chunk") + 
p.add_argument("--num-sections", type=int, default=1, help="Sections to divide chunks into") + p.add_argument("--max-artifacts-per-type", type=int, default=2, help="Max artifacts per type") + p.add_argument("--num-pairs", type=int, default=7, help="QA pairs per document") + p.add_argument("--min-hops", type=int, default=2, help="Min hops for multi-hop questions") + p.add_argument("--max-hops", type=int, default=4, help="Max hops for multi-hop questions") + p.add_argument("--min-complexity", type=int, default=4, help="Min question complexity") + p.add_argument("--preview", action="store_true", help="Preview without full generation") + p.add_argument("--file-extensions", nargs="+", default=None, help="File extensions to include") + p.add_argument("--artifact-path", type=Path, default=Path("./artifacts"), help="DD artifact path") + p.add_argument("--num-files", type=int, default=None, help="Max files to process") + p.add_argument("--batch-size", type=int, default=200, help="Records per batch") + p.add_argument("--start-batch-index", type=int, default=0, help="Batch index to start from") + p.add_argument("--end-batch-index", type=int, default=-1, help="Batch index to end at (exclusive)") + + g = p.add_argument_group("multi-document bundling") + g.add_argument("--multi-doc", action="store_true", help="Enable multi-doc bundling") + g.add_argument("--bundle-size", type=int, default=2, help="Docs per bundle") + g.add_argument( + "--bundle-strategy", + choices=["sequential", "doc_balanced", "interleaved"], + default="sequential", + help="Segment splitting strategy", + ) + g.add_argument("--max-docs-per-bundle", type=int, default=3, help="Max docs per bundle") + g.add_argument("--multi-doc-manifest", type=Path, default=None, help="Manifest for explicit bundles") + + g = p.add_argument_group("logging") + g.add_argument("--log-level", choices=["DEBUG", "INFO", "WARNING", "ERROR"], default="INFO") + + g = p.add_argument_group("model configuration") + g.add_argument("--artifact-extraction-model", default="nvidia/nemotron-3-nano-30b-a3b") + g.add_argument("--artifact-extraction-provider", default="nvidia") + g.add_argument("--qa-generation-model", default="nvidia/nemotron-3-nano-30b-a3b") + g.add_argument("--qa-generation-provider", default="nvidia") + g.add_argument("--quality-judge-model", default="nvidia/nemotron-3-nano-30b-a3b") + g.add_argument("--quality-judge-provider", default="nvidia") + g.add_argument("--embed-model", default="nvidia/llama-3.2-nv-embedqa-1b-v2") + g.add_argument("--embed-provider", default="nvidia") + g.add_argument("--max-parallel-requests-for-gen", type=int, default=None) + + g = p.add_argument_group("custom provider") + g.add_argument("--custom-provider-endpoint", default=None, help="Base URL for custom provider") + g.add_argument("--custom-provider-name", default="custom") + g.add_argument("--custom-provider-type", default="openai") + g.add_argument("--custom-provider-api-key", default=None) + g.add_argument("--model-providers-file", type=Path, default=None, help="YAML/JSON providers file") + + p.set_defaults(func=_run_generate) + + +def _run_generate(args: argparse.Namespace) -> None: + """Execute the ``generate`` subcommand.""" + file_extensions = args.file_extensions or [".txt", ".md", ".text", ""] + + print(f"Loading text files from {args.input_dir}...") + if args.multi_doc: + print(f"Multi-doc mode enabled: bundle_size={args.bundle_size}, strategy={args.bundle_strategy}") + + text_files_df = load_text_files_from_directory( + input_dir=args.input_dir, + 
file_extensions=file_extensions, + min_text_length=args.min_text_length, + sentences_per_chunk=args.sentences_per_chunk, + num_sections=args.num_sections, + num_files=args.num_files, + multi_doc=args.multi_doc, + bundle_size=args.bundle_size, + bundle_strategy=args.bundle_strategy, + max_docs_per_bundle=args.max_docs_per_bundle, + multi_doc_manifest=args.multi_doc_manifest, + ) + + row_type = "bundles" if args.multi_doc else "text files" + print(f"\nLoaded {len(text_files_df)} {row_type}") + + configure_logging( + LoggingConfig( + logger_configs=[LoggerConfig(name="data_designer", level=args.log_level)], + output_configs=[OutputConfig(destination=sys.stderr, structured=(args.log_level == "DEBUG"))], + root_level=args.log_level, + ) + ) + + model_providers, custom_providers = build_model_providers( + custom_provider_endpoint=args.custom_provider_endpoint, + custom_provider_name=args.custom_provider_name, + custom_provider_type=args.custom_provider_type, + custom_provider_api_key=args.custom_provider_api_key, + model_providers_file=args.model_providers_file, + ) + + data_designer = DataDesigner(artifact_path=args.artifact_path, model_providers=model_providers) + data_designer.set_run_config(dd.RunConfig(disable_early_shutdown=True)) + + args.output_dir.mkdir(parents=True, exist_ok=True) + + total_records = len(text_files_df) + num_batches = (total_records + args.batch_size - 1) // args.batch_size + actual_end_batch = num_batches if args.end_batch_index == -1 else min(args.end_batch_index, num_batches) + + model_kwargs: dict = { + "max_parallel_requests_for_gen": args.max_parallel_requests_for_gen, + "artifact_extraction_model": args.artifact_extraction_model, + "artifact_extraction_provider": args.artifact_extraction_provider, + "qa_generation_model": args.qa_generation_model, + "qa_generation_provider": args.qa_generation_provider, + "quality_judge_model": args.quality_judge_model, + "quality_judge_provider": args.quality_judge_provider, + "embed_model": args.embed_model, + "embed_provider": args.embed_provider, + } + + _print_model_config(args, custom_providers) + + if args.preview: + _run_preview(data_designer, text_files_df, total_records, args, model_kwargs) + return + + _run_batches( + data_designer, + text_files_df, + total_records, + num_batches, + args.start_batch_index, + actual_end_batch, + args, + model_kwargs, + ) + + +def _print_model_config(args: argparse.Namespace, custom_providers: list) -> None: + """Print model configuration to stdout.""" + print("\nModel configuration:") + print(f" Artifact extraction: {args.artifact_extraction_model} ({args.artifact_extraction_provider})") + print(f" QA generation: {args.qa_generation_model} ({args.qa_generation_provider})") + print(f" Quality judge: {args.quality_judge_model} ({args.quality_judge_provider})") + print(f" Embedding: {args.embed_model} ({args.embed_provider})") + if custom_providers: + print("\nCustom model providers:") + for p in custom_providers: + print(f" {p.name}: {p.endpoint} (type={p.provider_type}, api_key={p.api_key or 'none'})") + + +def _run_preview( + data_designer: DataDesigner, + text_files_df: object, + total_records: int, + args: argparse.Namespace, + model_kwargs: dict, +) -> None: + """Run a single-record preview of the pipeline.""" + config_builder = build_qa_generation_pipeline( + seed_dataset=text_files_df, + start_index=0, + end_index=min(args.batch_size - 1, total_records - 1), + max_artifacts_per_type=args.max_artifacts_per_type, + num_pairs=args.num_pairs, + min_hops=args.min_hops, + 
max_hops=args.max_hops, + min_complexity=args.min_complexity, + **model_kwargs, + ) + print("\nPreviewing generation...") + try: + preview_result = data_designer.preview(config_builder, num_records=1) + preview_result.display_sample_record() + except Exception as e: + print(f"Preview error: {e}") + + +def _run_batches( + data_designer: DataDesigner, + text_files_df: object, + total_records: int, + num_batches: int, + start_batch: int, + end_batch: int, + args: argparse.Namespace, + model_kwargs: dict, +) -> None: + """Process the pipeline in batches.""" + total_batches_to_run = end_batch - start_batch + batch_times: list[float] = [] + + print(f"\nTotal records: {total_records}") + print(f"Batch size: {args.batch_size}") + print(f"Total batches: {num_batches}") + print(f"Starting from batch index: {start_batch}") + print(f"Ending at batch index: {end_batch} (exclusive)") + + for batch_idx in range(start_batch, end_batch): + start_idx = batch_idx * args.batch_size + end_idx = min(start_idx + args.batch_size - 1, total_records - 1) + num_in_batch = end_idx - start_idx + 1 + + print(f"\n{'=' * 60}") + print(f"Processing batch {batch_idx}/{num_batches - 1} (records {start_idx}-{end_idx})") + print(f"{'=' * 60}") + + batch_start = time.monotonic() + + config_builder = build_qa_generation_pipeline( + seed_dataset=text_files_df, + start_index=start_idx, + end_index=end_idx, + max_artifacts_per_type=args.max_artifacts_per_type, + num_pairs=args.num_pairs, + min_hops=args.min_hops, + max_hops=args.max_hops, + min_complexity=args.min_complexity, + **model_kwargs, + ) + + input_basename = args.input_dir.name + dataset_name = f"{input_basename}_batch{batch_idx}_{start_idx}_{end_idx}" + result = data_designer.create(config_builder, num_records=num_in_batch, dataset_name=dataset_name) + generated_df = result.load_dataset() + + output_filename = f"generated_batch{batch_idx}_{start_idx}_{end_idx}.json" + generated_df.to_json(args.output_dir / output_filename, orient="records", indent=2) + + batch_elapsed = time.monotonic() - batch_start + batch_times.append(batch_elapsed) + + batches_done = batch_idx - start_batch + 1 + batches_remaining = end_batch - batch_idx - 1 + + print(f"Batch {batch_idx}/{num_batches - 1} done in {_format_duration(batch_elapsed)}") + print(f" Saved to {output_filename} ({len(generated_df)} records)") + if batches_remaining > 0: + avg_time = sum(batch_times) / len(batch_times) + eta = avg_time * batches_remaining + print(f" Progress: {batches_done}/{total_batches_to_run} batches") + print(f" ETA: ~{_format_duration(eta)} remaining") + + print(f"\n{'=' * 60}") + print(f"Generation complete! 
All batches saved to {args.output_dir}") + print(f"Total batches processed: {end_batch - start_batch}") + + +# --------------------------------------------------------------------------- +# ``convert`` subcommand +# --------------------------------------------------------------------------- + + +def _add_convert_parser(subparsers: argparse._SubParsersAction) -> None: + """Register the ``convert`` subcommand.""" + p = subparsers.add_parser( + "convert", + help="Convert SDG output to retriever training/evaluation formats", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + p.add_argument("input_path", help="Path to JSON file or directory of batch files") + p.add_argument("--corpus-id", required=True, help="Corpus identifier") + p.add_argument("--output-dir", default=None, help="Output directory") + p.add_argument("--eval-only", action="store_true", help="BEIR eval only (no train/val)") + p.add_argument("--train-ratio", type=float, default=0.8, help="Training split ratio") + p.add_argument("--val-ratio", type=float, default=0.1, help="Validation split ratio") + p.add_argument("--seed", type=int, default=42, help="Random seed") + p.add_argument("--quality-threshold", type=float, default=7.0, help="Min quality score") + p.add_argument("--max-pos-docs", type=int, default=5, help="Max positive docs per query") + p.add_argument("--use-group-id-in-eval", action="store_true", help="Use group_id in qrels") + p.add_argument("--split-strategy", choices=["random", "dedupped", "cluster"], default="random") + p.add_argument("--groups-json", nargs="+", default=None, help="Dedup groups JSON paths") + + p.set_defaults(func=_run_convert) + + +def _run_convert(args: argparse.Namespace) -> None: + """Execute the ``convert`` subcommand.""" + run_conversion( + input_path=args.input_path, + corpus_id=args.corpus_id, + output_dir=args.output_dir, + eval_only=args.eval_only, + train_ratio=args.train_ratio, + val_ratio=args.val_ratio, + seed=args.seed, + quality_threshold=args.quality_threshold, + max_pos_docs=args.max_pos_docs, + use_group_id_in_eval=args.use_group_id_in_eval, + split_strategy=args.split_strategy, + groups_json=args.groups_json, + ) + + +# --------------------------------------------------------------------------- +# Main entry point +# --------------------------------------------------------------------------- + + +def main() -> None: + """CLI entry point for ``data-designer-retrieval-sdg``.""" + parser = argparse.ArgumentParser( + prog="data-designer-retrieval-sdg", + description="SDG Pipeline for Retriever Evaluation Dataset Generation", + ) + subparsers = parser.add_subparsers(dest="command", required=True) + + _add_generate_parser(subparsers) + _add_convert_parser(subparsers) + + args = parser.parse_args() + args.func(args) diff --git a/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/config.py b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/config.py new file mode 100644 index 0000000..38ae6ad --- /dev/null +++ b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/config.py @@ -0,0 +1,46 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +"""Column configuration for the retrieval deduplication plugin.""" + +from __future__ import annotations + +from typing import Literal + +from data_designer.config.base import SingleColumnConfig + + +class RetrievalSdgDedupColumnConfig(SingleColumnConfig): + """Deduplicate QA pairs from a retrieval generation set via embedding similarity. + + This column reads QA pairs from a source column, embeds each question, + and removes near-duplicates whose cosine similarity exceeds a threshold. + + Args: + qa_pairs_column: Name of the upstream column containing QA pairs + with a ``pairs`` key. + embedding_alias: Model alias registered in the DataDesigner model + registry to use for computing embeddings. + column_type: Fixed literal identifying this column type. + dedupe_similarity_threshold: Cosine similarity threshold above which + two questions are considered duplicates. Defaults to ``0.9``. + """ + + qa_pairs_column: str + embedding_alias: str + column_type: Literal["retrieval-sdg-dedup"] = "retrieval-sdg-dedup" + dedupe_similarity_threshold: float = 0.9 + + @property + def required_columns(self) -> list[str]: + """Columns that must be present before this column can run.""" + return [self.qa_pairs_column] + + @property + def side_effect_columns(self) -> list[str]: + """Additional columns produced as side effects.""" + return [] + + def get_column_emoji(self) -> str: + """Emoji displayed in logs for this column type.""" + return "🔍" diff --git a/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/convert.py b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/convert.py new file mode 100644 index 0000000..3ba7a8b --- /dev/null +++ b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/convert.py @@ -0,0 +1,1017 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Convert raw SDG output to Automodel-compatible retriever training formats. + +Produces: +- ``train.json`` / ``val.json`` -- NeMo Retriever training format +- ``eval_beir/`` -- BEIR-compatible evaluation format +- ``corpus/`` -- parquet corpus + merlin metadata + +Supports random, dedupped (union-find merged), and cluster split strategies. +""" + +from __future__ import annotations + +import glob as glob_mod +import hashlib +import json +import os +import random +from collections import defaultdict + +import pandas as pd + +from data_designer_retrieval_sdg.postprocess import filter_qa_pairs_by_quality + +# --------------------------------------------------------------------------- +# Record loading +# --------------------------------------------------------------------------- + + +def filter_mismatched_records(records: list[dict]) -> tuple[list[dict], int]: + """Drop records where evaluation and pair counts disagree. + + Args: + records: Raw JSON records from the SDG pipeline. + + Returns: + Tuple of ``(filtered_records, dropped_count)``. 
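+
+    Example (illustrative):
+        A record whose ``qa_evaluations["evaluations"]`` holds three entries
+        but whose ``deduplicated_qa_pairs`` holds only two is dropped and
+        counted; records where both lists have the same length are kept.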
+ """ + filtered: list[dict] = [] + dropped_count = 0 + + for record in records: + qa_evals = record.get("qa_evaluations", {}).get("evaluations", []) + dedup_pairs = record.get("deduplicated_qa_pairs", []) + if len(qa_evals) == len(dedup_pairs): + filtered.append(record) + else: + dropped_count += 1 + file_name = record.get("file_name", "unknown") + display = file_name if isinstance(file_name, str) else ", ".join(file_name) if file_name else "unknown" + print( + f" Dropping record '{display}': " + f"qa_evaluations={len(qa_evals)}, deduplicated_qa_pairs={len(dedup_pairs)}" + ) + + return filtered, dropped_count + + +def normalize_file_name(file_name: object) -> list[str]: + """Normalise *file_name* to a list of strings. + + Provides backward compatibility for old data where ``file_name`` was a + plain string. + + Args: + file_name: String, list of strings, or other. + + Returns: + List of file-name strings. + """ + if isinstance(file_name, str): + return [file_name] + if isinstance(file_name, list): + return file_name + return [str(file_name)] + + +def load_generated_json_files(input_path: str) -> pd.DataFrame: + """Load generated JSON from a single file or a directory of batch files. + + Args: + input_path: Path to a merged JSON file **or** a directory containing + ``generated_batch*.json`` files. + + Returns: + Combined DataFrame with all records. + + Raises: + ValueError: If no JSON files are found. + """ + all_records: list[dict] = [] + + if os.path.isfile(input_path): + print(f"Loading single JSON file: {input_path}") + with open(input_path, encoding="utf-8") as f: + records = json.load(f) + if isinstance(records, list): + all_records.extend(records) + else: + all_records.append(records) + else: + json_files = sorted(glob_mod.glob(os.path.join(input_path, "generated_batch*.json"))) + if not json_files: + json_files = sorted(glob_mod.glob(os.path.join(input_path, "*.json"))) + if not json_files: + raise ValueError(f"No JSON files found in {input_path}") + + print(f"Found {len(json_files)} JSON files") + for json_file in json_files: + print(f" Loading: {json_file}") + with open(json_file, encoding="utf-8") as f: + records = json.load(f) + if isinstance(records, list): + all_records.extend(records) + else: + all_records.append(records) + + print("Normalizing file_name fields...") + for record in all_records: + if "file_name" in record: + record["file_name"] = normalize_file_name(record["file_name"]) + + print("Filtering mismatched records...") + all_records, dropped_count = filter_mismatched_records(all_records) + if dropped_count > 0: + print(f"Dropped {dropped_count} records with mismatched qa_evaluations/deduplicated_qa_pairs sizes") + + df = pd.DataFrame(all_records) + print(f"Loaded {len(df)} total records") + return df + + +# --------------------------------------------------------------------------- +# Corpus / chunk mapping +# --------------------------------------------------------------------------- + + +def get_corpus_id(text: str) -> str: + """Generate a hash-based corpus ID from text content. + + Args: + text: Document text. + + Returns: + ID in ``d_<16-hex-char>`` format. + """ + return "d_" + hashlib.sha256(text.encode()).hexdigest()[:16] + + +def extract_base_filename(file_path: str) -> str: + """Return the base filename without extension. + + Args: + file_path: Absolute or relative file path. + + Returns: + Filename stem. 
+ """ + return os.path.splitext(os.path.basename(file_path))[0] + + +def get_file_identifier(file_name_list: list[str]) -> str: + """Derive a canonical identifier from a file-name list. + + Single-document bundles use the base filename; multi-document bundles + use a truncated hash of sorted paths. + + Args: + file_name_list: List of file names in the bundle. + + Returns: + String identifier for chunk-mapping lookups. + """ + if not file_name_list: + return "" + if len(file_name_list) == 1: + return extract_base_filename(file_name_list[0]) + return hashlib.md5("||".join(sorted(file_name_list)).encode()).hexdigest()[:16] + + +def build_corpus_and_mappings( + generated_df: pd.DataFrame, +) -> tuple[dict[str, str], dict[tuple[str, int], str]]: + """Build a deduplicated corpus and chunk-mapping from generated data. + + Args: + generated_df: DataFrame with ``file_name`` and ``chunks`` columns. + + Returns: + Tuple of ``(corpus, chunk_mapping)`` where *corpus* maps + ``text -> corpus_id`` and *chunk_mapping* maps + ``(file_identifier, chunk_id) -> text``. + """ + corpus: dict[str, str] = {} + chunk_mapping: dict[tuple[str, int], str] = {} + + print("Building corpus and chunk mappings...") + + for _, row in generated_df.iterrows(): + file_name_list = row.get("file_name", []) + chunks = row.get("chunks", []) + + if not chunks or not file_name_list: + continue + + file_identifier = get_file_identifier(file_name_list) + + if hasattr(chunks, "tolist"): + chunks = chunks.tolist() + + for chunk in chunks: + if isinstance(chunk, dict): + chunk_id = chunk.get("chunk_id") + text = chunk.get("text", "") + else: + chunk_id = getattr(chunk, "chunk_id", None) + text = getattr(chunk, "text", "") + + if chunk_id is None or not text: + continue + + chunk_mapping[(file_identifier, chunk_id)] = text + if text not in corpus: + corpus[text] = get_corpus_id(text) + + print(f"Built corpus with {len(corpus)} unique documents from {len(chunk_mapping)} total chunks") + return corpus, chunk_mapping + + +# --------------------------------------------------------------------------- +# Split strategies +# --------------------------------------------------------------------------- + + +def file_tuple_in_set(file_name: object, file_set: set[tuple[str, ...]]) -> bool: + """Check whether *file_name* (list or str) belongs to *file_set*. + + Args: + file_name: A list of strings or a single string. + file_set: Set of tuples to test membership against. + + Returns: + ``True`` when the normalised tuple is in *file_set*. + """ + file_tuple = tuple(file_name) if isinstance(file_name, list) else (file_name,) + return file_tuple in file_set + + +def create_train_val_test_split( + filtered_qa_df: pd.DataFrame, + train_ratio: float, + val_ratio: float, + seed: int, +) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: + """Randomly split QA pairs by file/bundle into train, val, and test. + + Args: + filtered_qa_df: DataFrame with filtered QA pairs. + train_ratio: Fraction of files for training. + val_ratio: Fraction of files for validation. + seed: Random seed. + + Returns: + ``(train_df, val_df, test_df)`` + + Raises: + ValueError: If ``train_ratio + val_ratio > 1.0``. 
+ """ + random.seed(seed) + + test_ratio = 1.0 - train_ratio - val_ratio + if test_ratio < 0: + raise ValueError(f"train_ratio ({train_ratio}) + val_ratio ({val_ratio}) must be <= 1.0") + + unique_file_tuples = list({tuple(f) if isinstance(f, list) else (f,) for f in filtered_qa_df["file_name"]}) + random.shuffle(unique_file_tuples) + + n_train = int(len(unique_file_tuples) * train_ratio) + n_val = int(len(unique_file_tuples) * val_ratio) + + train_files = set(unique_file_tuples[:n_train]) + val_files = set(unique_file_tuples[n_train : n_train + n_val]) + test_files = set(unique_file_tuples[n_train + n_val :]) + + train_df = filtered_qa_df[filtered_qa_df["file_name"].apply(lambda f: file_tuple_in_set(f, train_files))] + val_df = filtered_qa_df[filtered_qa_df["file_name"].apply(lambda f: file_tuple_in_set(f, val_files))] + test_df = filtered_qa_df[filtered_qa_df["file_name"].apply(lambda f: file_tuple_in_set(f, test_files))] + + print( + f"Split: {len(train_files)} train files/bundles ({len(train_df)} QA pairs), " + f"{len(val_files)} val files/bundles ({len(val_df)} QA pairs), " + f"{len(test_files)} test files/bundles ({len(test_df)} QA pairs)" + ) + + return train_df, val_df, test_df + + +# --------------------------------------------------------------------------- +# Group-aware split helpers (dedupped / cluster) +# --------------------------------------------------------------------------- + + +class UnionFind: + """Disjoint-set / Union-Find with path compression and union by rank.""" + + def __init__(self) -> None: + self._parent: dict[str, str] = {} + self._rank: dict[str, int] = {} + + def find(self, x: str) -> str: + """Find the root representative of *x*.""" + if x not in self._parent: + self._parent[x] = x + self._rank[x] = 0 + if self._parent[x] != x: + self._parent[x] = self.find(self._parent[x]) + return self._parent[x] + + def union(self, x: str, y: str) -> None: + """Merge the sets containing *x* and *y*.""" + rx, ry = self.find(x), self.find(y) + if rx == ry: + return + if self._rank[rx] < self._rank[ry]: + rx, ry = ry, rx + self._parent[ry] = rx + if self._rank[rx] == self._rank[ry]: + self._rank[rx] += 1 + + +def load_dedup_groups(json_paths: list[str]) -> dict[str, list[str]]: + """Load groups/clusters from dedup_groups.json files. + + Auto-detects method keys (``exact``, ``fuzzy``, ``semantic``) and + extracts ``groups`` or ``clusters``. + + Args: + json_paths: Paths to dedup group JSON files. + + Returns: + Unified mapping of ``group_id -> [doc_id, ...]``. + """ + all_groups: dict[str, list[str]] = {} + + for path in json_paths: + print(f" Loading dedup groups from: {path}") + with open(path, encoding="utf-8") as f: + data = json.load(f) + + for method_key in ("exact", "fuzzy", "semantic"): + if method_key not in data: + continue + method_data = data[method_key] + groups = method_data.get("groups") or method_data.get("clusters", {}) + n_before = len(all_groups) + for group_id, doc_list in groups.items(): + all_groups[group_id] = doc_list + n_added = len(all_groups) - n_before + n_docs = sum(len(v) for v in groups.values()) + print(f" {method_key}: {n_added} groups, {n_docs} docs") + + print(f" Total loaded: {len(all_groups)} groups") + return all_groups + + +def merge_groups_union_find(all_groups: dict[str, list[str]]) -> dict[str, list[str]]: + """Transitively merge overlapping groups via Union-Find. + + Args: + all_groups: ``group_id -> [doc_id, ...]`` mapping. + + Returns: + Merged super-groups (only groups with 2+ members). 
+ """ + uf = UnionFind() + + for doc_list in all_groups.values(): + if len(doc_list) < 2: + continue + anchor = doc_list[0] + for doc_id in doc_list[1:]: + uf.union(anchor, doc_id) + + all_docs: set[str] = set() + for doc_list in all_groups.values(): + all_docs.update(doc_list) + + components: dict[str, set[str]] = defaultdict(set) + for doc_id in all_docs: + root = uf.find(doc_id) + components[root].add(doc_id) + + merged: dict[str, list[str]] = {} + for i, (_, members) in enumerate(sorted(components.items(), key=lambda x: -len(x[1])), 1): + if len(members) >= 2: + merged[f"merged_{i:04d}"] = sorted(members) + + total_docs = sum(len(v) for v in merged.values()) + print(f" Merged into {len(merged)} super-groups covering {total_docs} docs (from {len(all_groups)} input groups)") + return merged + + +def build_file_to_group_mapping( + groups: dict[str, list[str]], + qa_file_names: set[str], +) -> dict[str, str]: + """Map QA file names to group IDs with fallback matching. + + Matching order: exact string, strip extension, basename. + + Args: + groups: ``group_id -> [doc_id, ...]``. + qa_file_names: Set of individual file paths from the QA DataFrame. + + Returns: + Mapping of ``file_name -> group_id`` (only matched files). + """ + doc_to_group: dict[str, str] = {} + for group_id, doc_list in groups.items(): + for doc_id in doc_list: + doc_to_group[doc_id] = group_id + + noext_to_doc = {os.path.splitext(d)[0]: d for d in doc_to_group} + basename_to_doc = {extract_base_filename(d): d for d in doc_to_group} + + file_to_group: dict[str, str] = {} + matched = 0 + unmatched = 0 + + for fname in qa_file_names: + if fname in doc_to_group: + file_to_group[fname] = doc_to_group[fname] + matched += 1 + continue + + fname_noext = os.path.splitext(fname)[0] + if fname_noext in doc_to_group: + file_to_group[fname] = doc_to_group[fname_noext] + matched += 1 + continue + if fname_noext in noext_to_doc: + file_to_group[fname] = doc_to_group[noext_to_doc[fname_noext]] + matched += 1 + continue + + bn = extract_base_filename(fname) + if bn in basename_to_doc: + file_to_group[fname] = doc_to_group[basename_to_doc[bn]] + matched += 1 + continue + + unmatched += 1 + + print(f" File matching: {matched} matched, {unmatched} unmatched (out of {len(qa_file_names)} QA files)") + return file_to_group + + +def create_group_aware_split( + filtered_qa_df: pd.DataFrame, + file_to_group: dict[str, str], + train_ratio: float, + val_ratio: float, + seed: int, +) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: + """Split QA pairs into train/val/test respecting group boundaries. + + Uses greedy bin-packing sorted by weight (QA-pair count) descending. + + Args: + filtered_qa_df: DataFrame with filtered QA pairs. + file_to_group: Mapping from individual file paths to group IDs. + train_ratio: Target ratio for training. + val_ratio: Target ratio for validation. + seed: Random seed. + + Returns: + ``(train_df, val_df, test_df)`` + + Raises: + ValueError: If ``train_ratio + val_ratio > 1.0``. 
+ """ + random.seed(seed) + + test_ratio = 1.0 - train_ratio - val_ratio + if test_ratio < 0: + raise ValueError(f"train_ratio ({train_ratio}) + val_ratio ({val_ratio}) must be <= 1.0") + + unique_file_tuples = list({tuple(f) if isinstance(f, list) else (f,) for f in filtered_qa_df["file_name"]}) + + file_tuple_counts: dict[tuple[str, ...], int] = {} + for ft in unique_file_tuples: + mask = filtered_qa_df["file_name"].apply(lambda f, _ft=ft: (tuple(f) if isinstance(f, list) else (f,)) == _ft) + file_tuple_counts[ft] = int(mask.sum()) + + group_to_file_tuples: dict[str, list[tuple[str, ...]]] = defaultdict(list) + singleton_file_tuples: list[tuple[str, ...]] = [] + + for ft in unique_file_tuples: + matched_group = None + for fname in ft: + if fname in file_to_group: + matched_group = file_to_group[fname] + break + if matched_group is not None: + group_to_file_tuples[matched_group].append(ft) + else: + singleton_file_tuples.append(ft) + + units: list[tuple[str, list[tuple[str, ...]], int]] = [] + for group_id, file_tuples in group_to_file_tuples.items(): + weight = sum(file_tuple_counts[ft] for ft in file_tuples) + units.append((group_id, file_tuples, weight)) + for ft in singleton_file_tuples: + units.append((f"singleton_{ft}", [ft], file_tuple_counts[ft])) + + random.shuffle(units) + units.sort(key=lambda x: -x[2]) + + total_qa = sum(u[2] for u in units) + targets = {"train": total_qa * train_ratio, "val": total_qa * val_ratio, "test": total_qa * test_ratio} + current: dict[str, int] = {"train": 0, "val": 0, "test": 0} + split_assignments: dict[str, set[tuple[str, ...]]] = {"train": set(), "val": set(), "test": set()} + + for _, file_tuples, weight in units: + deficits = {s: targets[s] - current[s] for s in targets} + best_split = max(deficits, key=deficits.get) # type: ignore[arg-type] + for ft in file_tuples: + split_assignments[best_split].add(ft) + current[best_split] += weight + + train_df = filtered_qa_df[ + filtered_qa_df["file_name"].apply(lambda f: file_tuple_in_set(f, split_assignments["train"])) + ] + val_df = filtered_qa_df[filtered_qa_df["file_name"].apply(lambda f: file_tuple_in_set(f, split_assignments["val"]))] + test_df = filtered_qa_df[ + filtered_qa_df["file_name"].apply(lambda f: file_tuple_in_set(f, split_assignments["test"])) + ] + + n_groups = len(group_to_file_tuples) + n_singletons = len(singleton_file_tuples) + print(f" Groups: {n_groups} multi-file groups, {n_singletons} singletons") + print( + f" Split: train={len(train_df)} QA pairs ({len(split_assignments['train'])} files), " + f"val={len(val_df)} ({len(split_assignments['val'])} files), " + f"test={len(test_df)} ({len(split_assignments['test'])} files)" + ) + if total_qa > 0: + print( + f" Actual ratios: train={len(train_df) / total_qa:.3f}, " + f"val={len(val_df) / total_qa:.3f}, test={len(test_df) / total_qa:.3f}" + ) + + return train_df, val_df, test_df + + +# --------------------------------------------------------------------------- +# Output generation +# --------------------------------------------------------------------------- + + +def generate_training_set( + corpus: dict[str, str], + chunk_mapping: dict[tuple[str, int], str], + train_df: pd.DataFrame, + output_dir: str, + corpus_id: str, + max_pos_docs: int = 5, + output_filename: str = "train.json", + set_name: str = "training", + write_corpus: bool = True, +) -> None: + """Generate a training/validation set in NeMo Retriever format. + + Args: + corpus: ``text -> corpus_id`` mapping. 
+ chunk_mapping: ``(file_identifier, chunk_id) -> text`` mapping. + train_df: DataFrame with QA pairs for this split. + output_dir: Output directory path. + corpus_id: Corpus identifier string. + max_pos_docs: Maximum positive docs per query. + output_filename: Name of the output JSON file. + set_name: Label for log messages (e.g. ``"training"``). + write_corpus: Whether to write corpus parquet and metadata. + """ + print(f"Generating {set_name} set...") + + corpus_dir = os.path.join(output_dir, "corpus") + os.makedirs(corpus_dir, exist_ok=True) + + training_data: list[dict] = [] + question_counter = 0 + skipped_queries = 0 + skipped_too_many_pos = 0 + + for _, qa_pair in train_df.iterrows(): + file_name_list = qa_pair.get("file_name", []) + file_identifier = get_file_identifier(file_name_list) if file_name_list else "" + segment_ids = qa_pair.get("segment_ids", []) + question = qa_pair.get("question", "") + + if not question: + skipped_queries += 1 + continue + + if hasattr(segment_ids, "tolist"): + segment_ids = segment_ids.tolist() + + if len(segment_ids) > max_pos_docs: + skipped_too_many_pos += 1 + continue + + pos_docs: list[dict] = [] + all_segments_exist = True + for segment_id in segment_ids: + key = (file_identifier, segment_id) + if key not in chunk_mapping: + all_segments_exist = False + break + text = chunk_mapping[key] + pos_docs.append({"id": corpus[text]}) + + if not all_segments_exist or not pos_docs: + skipped_queries += 1 + continue + + training_data.append( + { + "question_id": f"q{question_counter}", + "question": question, + "corpus_id": corpus_id, + "pos_doc": pos_docs, + "neg_doc": [], + } + ) + question_counter += 1 + + print(f" Generated {len(training_data)} {set_name} queries") + if skipped_queries > 0: + print(f" Skipped {skipped_queries} queries (missing segments or empty question)") + if skipped_too_many_pos > 0: + print(f" Skipped {skipped_too_many_pos} queries (exceeded max_pos_docs={max_pos_docs})") + + train_json_path = os.path.join(output_dir, output_filename) + with open(train_json_path, "w", encoding="utf-8") as f: + json.dump({"corpus": {"path": "./corpus/"}, "data": training_data}, f, indent=2, sort_keys=False) + print(f" Wrote {train_json_path}") + + if write_corpus: + corpus_list = [{"id": doc_id, "text": text} for text, doc_id in corpus.items()] + corpus_df = pd.DataFrame(corpus_list) + parquet_path = os.path.join(corpus_dir, "train.parquet") + corpus_df.to_parquet(parquet_path, index=False) + print(f" Wrote {parquet_path} with {len(corpus_list)} documents") + + metadata_path = os.path.join(corpus_dir, "merlin_metadata.json") + with open(metadata_path, "w", encoding="utf-8") as f: + json.dump({"corpus_id": corpus_id, "class": "TextQADataset"}, f, indent=2, sort_keys=False) + print(f" Wrote {metadata_path}") + + +def generate_eval_set( + corpus: dict[str, str], + chunk_mapping: dict[tuple[str, int], str], + eval_df: pd.DataFrame, + output_dir: str, + max_pos_docs: int = 5, + eval_only: bool = False, + use_group_id_in_eval: bool = False, +) -> None: + """Generate an evaluation set in BEIR format. + + Args: + corpus: ``text -> corpus_id`` mapping. + chunk_mapping: ``(file_identifier, chunk_id) -> text`` mapping. + eval_df: DataFrame with QA pairs for evaluation. + output_dir: Output directory path. + max_pos_docs: Maximum positive docs per query. + eval_only: If ``True`` write directly to *output_dir* instead of + an ``eval_beir/`` sub-directory. + use_group_id_in_eval: Use hash-based group ID in qrels instead of + sequential BEIR IDs. 
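+
+    Note:
+        Writes ``corpus.jsonl``, ``queries.jsonl``, and ``qrels/test.tsv``
+        into the target directory, mirroring the BEIR dataset layout.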
+ """ + print("Generating evaluation set...") + + eval_dir = output_dir if eval_only else os.path.join(output_dir, "eval_beir") + os.makedirs(eval_dir, exist_ok=True) + + corpus_path = os.path.join(eval_dir, "corpus.jsonl") + corpus_id_counter = 0 + text_to_beir_id: dict[str, str] = {} + + with open(corpus_path, "w", encoding="utf-8") as corpus_file: + for text, hash_id in corpus.items(): + beir_id = f"d{corpus_id_counter}" + text_to_beir_id[text] = beir_id + + corpus_entry: dict = {"_id": beir_id, "metadata": {}, "text": text, "title": ""} + if use_group_id_in_eval: + corpus_entry["group_id"] = hash_id + corpus_file.write(json.dumps(corpus_entry) + "\n") + corpus_id_counter += 1 + + print(f" Wrote {corpus_path} with {corpus_id_counter} documents") + + queries_path = os.path.join(eval_dir, "queries.jsonl") + query_mappings: list[tuple[str, str, list]] = [] + query_counter = 0 + skipped_queries = 0 + skipped_too_many_pos = 0 + + with open(queries_path, "w", encoding="utf-8") as queries_file: + for _, qa_pair in eval_df.iterrows(): + file_name_list = qa_pair.get("file_name", []) + file_identifier = get_file_identifier(file_name_list) if file_name_list else "" + segment_ids = qa_pair.get("segment_ids", []) + question = qa_pair.get("question", "") + + if not question: + skipped_queries += 1 + continue + + if hasattr(segment_ids, "tolist"): + segment_ids = segment_ids.tolist() + + if len(segment_ids) > max_pos_docs: + skipped_too_many_pos += 1 + continue + + all_segments_exist = True + for segment_id in segment_ids: + key = (file_identifier, segment_id) + if key not in chunk_mapping: + all_segments_exist = False + break + + if not all_segments_exist: + skipped_queries += 1 + continue + + query_id = f"q{query_counter}" + query_mappings.append((query_id, file_identifier, segment_ids)) + + metadata: dict = {} + for field in ( + "query_type", + "reasoning_type", + "hop_count", + "question_complexity", + "quality_score", + "answer", + "hop_contexts", + ): + val = qa_pair.get(field) + if val is not None: + if hasattr(val, "tolist"): + val = val.tolist() + metadata[field] = val + + metadata["file_name"] = file_name_list + metadata["segment_ids"] = segment_ids + + query_entry = {"_id": query_id, "metadata": metadata, "text": question} + queries_file.write(json.dumps(query_entry) + "\n") + query_counter += 1 + + print(f" Wrote {queries_path} with {query_counter} queries") + if skipped_queries > 0: + print(f" Skipped {skipped_queries} queries (missing segments or empty question)") + if skipped_too_many_pos > 0: + print(f" Skipped {skipped_too_many_pos} queries (exceeded max_pos_docs={max_pos_docs})") + + qrels_dir = os.path.join(eval_dir, "qrels") + os.makedirs(qrels_dir, exist_ok=True) + + qrels_path = os.path.join(qrels_dir, "test.tsv") + qrels_count = 0 + + with open(qrels_path, "w", encoding="utf-8") as qrels_file: + qrels_file.write("query-id\tcorpus-id\tscore\n") + for query_id, file_identifier, segment_ids in query_mappings: + for segment_id in segment_ids: + key = (file_identifier, segment_id) + text = chunk_mapping[key] + if use_group_id_in_eval: + doc_id = corpus[text] + else: + doc_id = text_to_beir_id[text] + qrels_file.write(f"{query_id}\t{doc_id}\t1\n") + qrels_count += 1 + + id_type = "group_id" if use_group_id_in_eval else "_id" + print(f" Wrote {qrels_path} with {qrels_count} mappings (using {id_type})") + + +# --------------------------------------------------------------------------- +# Top-level conversion orchestrator +# 
--------------------------------------------------------------------------- + + +def run_conversion( + input_path: str, + corpus_id: str, + output_dir: str | None = None, + eval_only: bool = False, + train_ratio: float = 0.8, + val_ratio: float = 0.1, + seed: int = 42, + quality_threshold: float = 7.0, + max_pos_docs: int = 5, + use_group_id_in_eval: bool = False, + split_strategy: str = "random", + groups_json: list[str] | None = None, +) -> None: + """Run the full SDG-to-retriever-data conversion pipeline. + + Args: + input_path: Path to a merged JSON file or directory of batch files. + corpus_id: Corpus identifier. + output_dir: Output directory (auto-derived if ``None``). + eval_only: Generate only BEIR evaluation data. + train_ratio: Training split ratio. + val_ratio: Validation split ratio. + seed: Random seed. + quality_threshold: Minimum quality score. + max_pos_docs: Maximum positive docs per query. + use_group_id_in_eval: Use hash-based group IDs in eval qrels. + split_strategy: ``"random"``, ``"dedupped"``, or ``"cluster"``. + groups_json: Paths to dedup group JSON files. + """ + abs_input = os.path.abspath(input_path) + if not os.path.exists(abs_input): + raise ValueError(f"Input path does not exist: {abs_input}") + + if output_dir is None: + suffix = "_eval" if eval_only else "_train_eval" + if os.path.isfile(abs_input): + input_basename = os.path.splitext(os.path.basename(abs_input))[0] + output_dir = os.path.join(os.path.dirname(abs_input), f"{input_basename}{suffix}") + else: + output_dir = os.path.abspath(abs_input.rstrip("/") + suffix) + else: + output_dir = os.path.abspath(output_dir) + os.makedirs(output_dir, exist_ok=True) + + _print_conversion_header( + abs_input, + output_dir, + corpus_id, + eval_only, + train_ratio, + val_ratio, + split_strategy, + groups_json, + seed, + quality_threshold, + max_pos_docs, + use_group_id_in_eval, + ) + + generated_df = load_generated_json_files(abs_input) + corpus, chunk_mapping = build_corpus_and_mappings(generated_df) + filtered_qa_df, skipped_files = filter_qa_pairs_by_quality(generated_df, quality_threshold) + + if eval_only: + generate_eval_set( + corpus, + chunk_mapping, + filtered_qa_df, + output_dir, + max_pos_docs, + eval_only=True, + use_group_id_in_eval=use_group_id_in_eval, + ) + else: + train_df, val_df, test_df = _compute_split( + filtered_qa_df, + train_ratio, + val_ratio, + seed, + split_strategy, + groups_json, + ) + generate_training_set( + corpus, + chunk_mapping, + train_df, + output_dir, + corpus_id, + max_pos_docs, + output_filename="train.json", + set_name="training", + ) + generate_training_set( + corpus, + chunk_mapping, + val_df, + output_dir, + corpus_id, + max_pos_docs, + output_filename="val.json", + set_name="validation", + write_corpus=False, + ) + generate_eval_set( + corpus, + chunk_mapping, + test_df, + output_dir, + max_pos_docs, + eval_only=False, + use_group_id_in_eval=use_group_id_in_eval, + ) + + _print_conversion_footer(output_dir, eval_only, skipped_files) + + +# --------------------------------------------------------------------------- +# Conversion internal helpers +# --------------------------------------------------------------------------- + + +def _compute_split( + filtered_qa_df: pd.DataFrame, + train_ratio: float, + val_ratio: float, + seed: int, + split_strategy: str, + groups_json: list[str] | None, +) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: + """Route to the correct split strategy.""" + if split_strategy == "random": + return create_train_val_test_split(filtered_qa_df, 
train_ratio, val_ratio, seed) + + if not groups_json: + raise ValueError(f"--groups-json is required when split_strategy={split_strategy}") + + groups = load_dedup_groups(groups_json) + if split_strategy == "dedupped": + groups = merge_groups_union_find(groups) + + qa_file_names: set[str] = set() + for fnames in filtered_qa_df["file_name"]: + if isinstance(fnames, list): + qa_file_names.update(fnames) + else: + qa_file_names.add(fnames) + + ftg = build_file_to_group_mapping(groups, qa_file_names) + return create_group_aware_split(filtered_qa_df, ftg, train_ratio, val_ratio, seed) + + +def _print_conversion_header( + input_path: str, + output_dir: str, + corpus_id: str, + eval_only: bool, + train_ratio: float, + val_ratio: float, + split_strategy: str, + groups_json: list[str] | None, + seed: int, + quality_threshold: float, + max_pos_docs: int, + use_group_id_in_eval: bool, +) -> None: + """Print a banner with the conversion settings.""" + print("=" * 80) + print("SDG to Retriever Data Converter") + print("=" * 80) + print(f"Input path: {input_path}") + print(f"Output directory: {output_dir}") + print(f"Corpus ID: {corpus_id}") + if eval_only: + print("Mode: Evaluation only (BEIR format)") + else: + test_ratio = 1.0 - train_ratio - val_ratio + print("Mode: Train/Val/Test split") + print(f"Split strategy: {split_strategy}") + print(f"Split ratios: train={train_ratio}, val={val_ratio}, test={test_ratio:.2f}") + if groups_json: + for gj in groups_json: + print(f" Groups JSON: {gj}") + print(f"Random seed: {seed}") + print(f"Quality threshold: {quality_threshold}") + print(f"Max positive docs: {max_pos_docs}") + print(f"Eval qrels ID type: {'group_id' if use_group_id_in_eval else '_id'}") + print() + + +def _print_conversion_footer(output_dir: str, eval_only: bool, skipped_files: list[dict]) -> None: + """Print completion summary.""" + print() + print("=" * 80) + print("Conversion complete!") + print("=" * 80) + print(f"Output location: {output_dir}") + if eval_only: + print("Generated (BEIR format):") + print(" - corpus.jsonl") + print(" - queries.jsonl") + print(" - qrels/test.tsv") + else: + print("Generated:") + print(" - train.json (retriever training format)") + print(" - val.json (retriever validation format)") + print(" - corpus/ (parquet + metadata)") + print(" - eval_beir/ (BEIR test/evaluation format)") + + if skipped_files: + print() + print("=" * 80) + print(f"Skipped Files ({len(skipped_files)} total)") + print("=" * 80) + for item in skipped_files: + print(f" - {item['file_name']}: {item['reason']}") diff --git a/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/dedup.py b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/dedup.py new file mode 100644 index 0000000..f9d6585 --- /dev/null +++ b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/dedup.py @@ -0,0 +1,127 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +"""Column generator that deduplicates QA pairs via embedding cosine similarity.""" + +from __future__ import annotations + +import logging +from concurrent.futures import ThreadPoolExecutor + +import numpy as np +from data_designer.engine.column_generators.generators.base import ColumnGeneratorCellByCell + +from data_designer_retrieval_sdg.config import RetrievalSdgDedupColumnConfig + +logger = logging.getLogger(__name__) + + +class RetrievalSdgDedupColumnGenerator(ColumnGeneratorCellByCell[RetrievalSdgDedupColumnConfig]): + """Remove near-duplicate QA pairs using embedding cosine similarity. + + For each cell the generator: + 1. Reads QA pairs from the configured source column. + 2. Embeds every question in parallel via the registered embedding model. + 3. Computes pairwise cosine similarity and greedily drops duplicates + whose similarity exceeds ``dedupe_similarity_threshold``. + 4. Returns the surviving pairs under the column name. + """ + + @property + def embedder(self): + """Resolve the embedding model from the resource provider.""" + return self.resource_provider.model_registry.get_model( + model_alias=self.config.embedding_alias, + ) + + def embed_text(self, text: str) -> list[float]: + """Compute an embedding vector for *text* using the configured model. + + Args: + text: Input string to embed. + + Returns: + List of floats representing the embedding vector. + """ + vectors = self.embedder.generate_text_embeddings( + input_texts=[text], + encoding_format="float", + ) + return vectors[0] + + def dedupe_qa_pairs(self, embeddings: list[list[float]]) -> list[int]: + """Return indices of QA pairs to keep after greedy deduplication. + + Computes pairwise cosine similarity. For every pair above the + threshold the later item is dropped. + + Args: + embeddings: 2-D list of embedding vectors, one per QA pair. + + Returns: + Sorted list of integer indices to retain. + + Raises: + ValueError: If *embeddings* is not a 2-D structure. + """ + if not embeddings: + return [] + + matrix = np.asarray(embeddings, dtype=float) + if matrix.ndim != 2: + raise ValueError("Embeddings must be a 2D array of shape (n, d).") + + norms = np.linalg.norm(matrix, axis=1, keepdims=True) + norms[norms == 0] = 1.0 + normalized = matrix / norms + + cosine_sim = np.clip(normalized @ normalized.T, -1.0, 1.0) + + threshold = self.config.dedupe_similarity_threshold + keep_indexes: list[int] = [] + dropped = np.zeros(len(embeddings), dtype=bool) + + for i in range(len(embeddings)): + if dropped[i]: + continue + keep_indexes.append(i) + if i == len(embeddings) - 1: + continue + close_matches = np.where(cosine_sim[i, i + 1 :] > threshold)[0] + i + 1 + dropped[close_matches] = True + + return keep_indexes + + def generate(self, data: dict) -> dict: + """Deduplicate QA pairs for a single record. + + Args: + data: Row dict containing at least the ``qa_pairs_column``. + + Returns: + Updated row dict with the deduplicated pairs stored under + ``self.config.name``. 
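+
+        Example (illustrative sketch; assumes ``qa_pairs_column="qa_generation"``
+        and ``name="deduplicated_qa_pairs"``, matching the default pipeline)::
+
+            row = {"qa_generation": {"pairs": [{"question": "What is X?", "answer": "..."}]}}
+            row = generator.generate(row)
+            row["deduplicated_qa_pairs"]  # surviving QA pair dicts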
+ """ + logger.debug("Deduplicating QA pairs from column: %s", self.config.qa_pairs_column) + + qa_pairs: list = data[self.config.qa_pairs_column]["pairs"] + max_parallel = self.embedder.max_parallel_requests + workers = max(1, max_parallel or 1) + + with ThreadPoolExecutor(max_workers=workers) as executor: + embeddings = list(executor.map(self.embed_text, [qa["question"] for qa in qa_pairs])) + + retained_indexes = self.dedupe_qa_pairs(embeddings) + dropped = len(qa_pairs) - len(retained_indexes) + if dropped > 0: + logger.info( + "Dedup: retained %d of %d QA pairs (%d duplicates removed)", + len(retained_indexes), + len(qa_pairs), + dropped, + ) + else: + logger.debug("Dedup: retained all %d QA pairs (no duplicates)", len(qa_pairs)) + + retained_qa_pairs = [qa_pairs[i] for i in retained_indexes] + return data | {self.config.name: retained_qa_pairs} diff --git a/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/ingest.py b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/ingest.py new file mode 100644 index 0000000..c80fd77 --- /dev/null +++ b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/ingest.py @@ -0,0 +1,724 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Text ingestion, chunking, and section-building utilities. + +This module handles loading text files from a directory, chunking them by +sentence boundaries, and organising chunks into sections using various +strategies (sequential, doc-balanced, interleaved). It supports both +single-document and multi-document (bundled) modes. +""" + +from __future__ import annotations + +import hashlib +import math +import re +from collections import defaultdict, deque +from pathlib import Path +from typing import Literal + +import nltk +import pandas as pd +from nltk.tokenize import sent_tokenize + +# --------------------------------------------------------------------------- +# File-matching helpers +# --------------------------------------------------------------------------- + + +def is_traditional_extension(suffix: str) -> bool: + """Check whether *suffix* looks like a real file extension. + + Traditional extensions are short (1-10 chars), start with a period, and + contain only alphanumeric characters. For example ``.txt``, ``.md``, + ``.json`` are traditional, whereas + ``.com_publication_2001-08_user-programmable`` is not. + + Args: + suffix: The file suffix (including leading ``'.'``). + + Returns: + ``True`` when the suffix matches the traditional pattern. + """ + if not suffix or not suffix.startswith("."): + return False + ext_part = suffix[1:] + return len(ext_part) <= 10 and ext_part.replace("_", "").isalnum() + + +def file_matches_extensions(file_path: Path, file_extensions: list[str]) -> bool: + """Decide whether *file_path* has one of the allowed extensions. + + Files whose suffix is not *traditional* (see + :func:`is_traditional_extension`) are treated as having no extension + and matched against ``""`` in *file_extensions*. + + Args: + file_path: Path to the file. + file_extensions: Allowed extensions, e.g. ``[".txt", ".md", ""]``. + + Returns: + ``True`` when the file matches. 
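+
+    Example (illustrative; file names are hypothetical)::
+
+        file_matches_extensions(Path("notes.md"), [".txt", ".md", ""])      # True
+        file_matches_extensions(Path("a.com_publication_2001-08"), [""])    # True (non-traditional suffix)
+        file_matches_extensions(Path("image.png"), [".txt", ".md"])         # False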
+ """ + suffix = file_path.suffix.lower() + if is_traditional_extension(suffix): + return suffix in file_extensions + return "" in file_extensions + + +# --------------------------------------------------------------------------- +# Multi-document bundling helpers +# --------------------------------------------------------------------------- + + +def load_multi_doc_manifest(manifest_path: Path | None) -> list[list[str]]: + """Load a multi-doc manifest file. + + Supports JSON or YAML format:: + + [["doc1.txt", "doc2.txt"], ["doc3.txt"]] + {"bundles": [{"docs": ["doc1.txt", "doc2.txt"]}]} + + Args: + manifest_path: Path to the manifest file, or ``None``. + + Returns: + List of bundles, each a list of file-path strings. + """ + import json + + import yaml + + if not manifest_path: + return [] + + try: + manifest_text = manifest_path.read_text(encoding="utf-8") + except Exception as exc: + print(f"Warning: Unable to read multi_doc_manifest at {manifest_path}: {exc}") + return [] + + data = None + try: + data = json.loads(manifest_text) + except json.JSONDecodeError: + try: + data = yaml.safe_load(manifest_text) + except Exception as exc: + print(f"Warning: Failed to parse multi_doc_manifest: {exc}") + return [] + + if isinstance(data, dict) and "bundles" in data: + data = data["bundles"] + + bundles: list[list[str]] = [] + if isinstance(data, list): + for entry in data: + if isinstance(entry, dict) and "docs" in entry: + docs = entry["docs"] + elif isinstance(entry, list): + docs = entry + else: + docs = [] + clean_docs = [str(doc) for doc in docs if doc] + if clean_docs: + bundles.append(clean_docs) + else: + print("Warning: multi_doc_manifest must be a list or dict with 'bundles'") + + return bundles + + +def build_bundle_id(bundle_members: list[str]) -> str: + """Generate a unique bundle ID from member paths. + + Args: + bundle_members: List of file paths in the bundle. + + Returns: + MD5 hex digest of sorted, normalised paths. + """ + if not bundle_members: + return "" + normalized = "||".join(sorted(str(Path(member).resolve()) for member in bundle_members)) + return hashlib.md5(normalized.encode()).hexdigest() + + +def build_bundles( + file_paths: list[Path], + bundle_size: int = 2, + max_docs_per_bundle: int = 3, + manifest_bundles: list[list[str]] | None = None, + input_dir: Path | None = None, +) -> list[list[Path]]: + """Group file paths into document bundles. + + Manifest-defined bundles take priority. Remaining documents are grouped + sequentially according to *bundle_size*. + + Args: + file_paths: All candidate file paths. + bundle_size: Documents per automatic bundle. + max_docs_per_bundle: Hard cap on bundle size. + manifest_bundles: Pre-defined bundles from a manifest file. + input_dir: Root directory for resolving relative manifest paths. + + Returns: + List of bundles, each a list of resolved ``Path`` objects. + + Raises: + ValueError: If any bundle exceeds *max_docs_per_bundle*. 
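+
+    Example (illustrative; file names are hypothetical)::
+
+        build_bundles([Path("a.txt"), Path("b.txt"), Path("c.txt")], bundle_size=2)
+        # -> two bundles of resolved paths: [a.txt, b.txt] and [c.txt]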
+ """ + if not file_paths: + return [] + + resolved_paths = [path.resolve() for path in file_paths] + seen: set[Path] = set() + bundles: list[list[Path]] = [] + + if manifest_bundles: + for entry in manifest_bundles: + resolved_bundle: list[Path] = [] + for raw_doc in entry: + candidate = Path(raw_doc) + if not candidate.is_absolute() and input_dir: + candidate = (input_dir / raw_doc).resolve() + candidate = candidate.resolve() + if candidate in resolved_paths and candidate not in seen: + resolved_bundle.append(candidate) + seen.add(candidate) + if resolved_bundle: + bundles.append(resolved_bundle) + + remaining = [p for p in resolved_paths if p not in seen] + for start in range(0, len(remaining), bundle_size): + bundle = remaining[start : start + bundle_size] + if bundle: + bundles.append(bundle) + + for i, bundle in enumerate(bundles): + if len(bundle) > max_docs_per_bundle: + raise ValueError( + f"Bundle {i} has {len(bundle)} documents, which exceeds " + f"max_docs_per_bundle={max_docs_per_bundle}. " + f"Either reduce the bundle size in your manifest or increase max_docs_per_bundle." + ) + + return [b for b in bundles if b] + + +# --------------------------------------------------------------------------- +# Section-building strategies +# --------------------------------------------------------------------------- + + +def group_chunks_by_doc(chunks: list[dict]) -> dict[str, list[tuple[int, dict]]]: + """Group chunks by their ``doc_id`` field. + + Args: + chunks: Chunk dicts, each optionally containing ``'doc_id'``. + + Returns: + Mapping from ``doc_id`` to ``(global_index, chunk)`` pairs. + """ + grouped: dict[str, list[tuple[int, dict]]] = defaultdict(list) + for idx, chunk in enumerate(chunks): + doc_id = chunk.get("doc_id", "default") + grouped[doc_id].append((idx, chunk)) + return dict(grouped) + + +def format_section_chunks(indexed_chunks: list[tuple[int, dict]], section_number: int) -> str: + """Render a list of indexed chunks into a section string. + + Args: + indexed_chunks: ``(global_index, chunk)`` tuples. + section_number: Section ordinal for the header. + + Returns: + Formatted section text, or ``""`` if no content. + """ + section_lines: list[str] = [] + for _, chunk in indexed_chunks: + text = chunk.get("text", "").strip() + if not text: + continue + segment_id = chunk.get("chunk_id", 1) + doc_id = chunk.get("doc_id", "") + start_time = "00:00:00" + end_time = "00:00:00" + if doc_id: + segment_info = f"Segment {segment_id} [Doc: {doc_id}] ({start_time} - {end_time}): {text}" + else: + segment_info = f"Segment {segment_id} ({start_time} - {end_time}): {text}" + section_lines.append(segment_info) + + if section_lines: + return f"=== Section {section_number} ===\n" + "\n".join(section_lines) + return "" + + +def chunks_to_sections_sequential(chunks: list[dict], num_sections: int = 1) -> list[str]: + """Split chunks sequentially into *num_sections* sections. + + Args: + chunks: Chunk dicts in document order. + num_sections: How many sections to produce. + + Returns: + List of formatted section strings. 
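+
+    Example (illustrative)::
+
+        chunks = [{"text": "First fact.", "chunk_id": 1}, {"text": "Second fact.", "chunk_id": 2}]
+        chunks_to_sections_sequential(chunks, num_sections=2)
+        # Each returned section is a string like:
+        #   "=== Section 1 ===" + newline + "Segment 1 (00:00:00 - 00:00:00): First fact."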
+ """ + total = len(chunks) + if total == 0: + return [] + + section_size = max(1, total // num_sections) + formatted_sections: list[str] = [] + + for i in range(num_sections): + start_idx = i * section_size + end_idx = (i + 1) * section_size if i < num_sections - 1 else total + indexed_chunks = [(j, chunks[j]) for j in range(start_idx, end_idx)] + section_text = format_section_chunks(indexed_chunks, i + 1) + if section_text: + formatted_sections.append(section_text) + + return formatted_sections + + +def chunks_to_sections_doc_balanced(chunks: list[dict], num_sections: int = 1) -> list[str]: + """Split chunks so each section has proportional doc representation. + + Falls back to sequential when there is only one document. + + Args: + chunks: Chunk dicts with ``'doc_id'`` fields. + num_sections: How many sections to produce. + + Returns: + List of formatted section strings. + """ + if not chunks: + return [] + + grouped = group_chunks_by_doc(chunks) + if len(grouped) <= 1: + return chunks_to_sections_sequential(chunks, num_sections) + + chunk_sizes = {doc_id: max(1, math.ceil(len(entries) / num_sections)) for doc_id, entries in grouped.items()} + + sections: list[list[tuple[int, dict]]] = [] + for part_idx in range(num_sections): + part_entries: list[tuple[int, dict]] = [] + for doc_id, entries in grouped.items(): + chunk_size = chunk_sizes[doc_id] + start = part_idx * chunk_size + end = min(len(entries), start + chunk_size) + if start < len(entries): + part_entries.extend(entries[start:end]) + if part_entries: + sections.append(part_entries) + + formatted_sections: list[str] = [] + for i, indexed_chunks in enumerate(sections): + section_text = format_section_chunks(indexed_chunks, i + 1) + if section_text: + formatted_sections.append(section_text) + + return formatted_sections + + +def chunks_to_sections_interleaved(chunks: list[dict], num_sections: int = 1) -> list[str]: + """Split chunks with round-robin interleaving across documents. + + Falls back to sequential when there is only one document. + + Args: + chunks: Chunk dicts with ``'doc_id'`` fields. + num_sections: How many sections to produce. + + Returns: + List of formatted section strings. + """ + if not chunks: + return [] + + grouped = group_chunks_by_doc(chunks) + if len(grouped) <= 1: + return chunks_to_sections_sequential(chunks, num_sections) + + doc_iterators = {doc_id: deque(entries) for doc_id, entries in grouped.items()} + doc_order = list(grouped.keys()) + interleaved: list[tuple[int, dict]] = [] + + while True: + added = False + for doc_id in doc_order: + doc_queue = doc_iterators[doc_id] + if doc_queue: + interleaved.append(doc_queue.popleft()) + added = True + if not added: + break + + if not interleaved: + return [] + + total = len(interleaved) + section_size = max(1, total // num_sections) + formatted_sections: list[str] = [] + + for i in range(num_sections): + start_idx = i * section_size + end_idx = (i + 1) * section_size if i < num_sections - 1 else total + indexed_chunks = interleaved[start_idx:end_idx] + section_text = format_section_chunks(indexed_chunks, i + 1) + if section_text: + formatted_sections.append(section_text) + + return formatted_sections + + +def chunks_to_sections_structured( + chunks: list[dict], + num_sections: int = 1, + strategy: Literal["sequential", "doc_balanced", "interleaved"] = "sequential", +) -> list[str]: + """Split chunks into sections using the specified strategy. + + Args: + chunks: Chunk dicts. + num_sections: How many sections to produce. 
+ strategy: ``"sequential"``, ``"doc_balanced"``, or ``"interleaved"``. + + Returns: + List of formatted section strings. + """ + if strategy == "doc_balanced": + return chunks_to_sections_doc_balanced(chunks, num_sections) + if strategy == "interleaved": + return chunks_to_sections_interleaved(chunks, num_sections) + return chunks_to_sections_sequential(chunks, num_sections) + + +# --------------------------------------------------------------------------- +# Sentence chunking +# --------------------------------------------------------------------------- + + +def _ensure_nltk_punkt() -> None: + """Download NLTK punkt tokeniser data if not already present.""" + for resource in ("tokenizers/punkt", "tokenizers/punkt_tab"): + try: + nltk.data.find(resource) + except LookupError: + nltk.download(resource.split("/")[-1], quiet=True) + + +def text_to_sentence_chunks( + text: str, + sentences_per_chunk: int = 5, + doc_id: str | None = None, + doc_path: str | None = None, + chunk_id_offset: int = 0, +) -> list[dict]: + """Chunk *text* into groups of sentences with metadata. + + Args: + text: Input text to chunk. + sentences_per_chunk: Sentences per chunk. + doc_id: Optional document identifier for multi-doc bundles. + doc_path: Optional document path for multi-doc bundles. + chunk_id_offset: Offset for global chunk IDs when aggregating. + + Returns: + List of chunk dicts with keys ``text``, ``start``, ``end``, + ``sentence_count``, ``word_count``, ``chunk_id``, + ``doc_chunk_index``, and optionally ``doc_id`` / ``doc_path``. + """ + _ensure_nltk_punkt() + + paragraphs = re.split(r"\n\s*\n+", text) + paragraphs = [p.strip() for p in paragraphs if p.strip()] + + sentences: list[str] = [] + for paragraph in paragraphs: + sentences.extend(sent_tokenize(paragraph)) + + chunks: list[dict] = [] + word_position = 0 + doc_chunk_index = 0 + + for i in range(0, len(sentences), sentences_per_chunk): + chunk_sentences = sentences[i : i + sentences_per_chunk] + chunk_text = ". ".join(chunk_sentences) + if chunk_text and not chunk_text.endswith("."): + chunk_text += "." + + chunk_words = chunk_text.split() + start_word_pos = word_position + end_word_pos = word_position + len(chunk_words) + word_position = end_word_pos + doc_chunk_index += 1 + + chunk_data: dict = { + "text": chunk_text, + "start": start_word_pos, + "end": end_word_pos, + "sentence_count": len(chunk_sentences), + "word_count": len(chunk_words), + "chunk_id": chunk_id_offset + len(chunks) + 1, + "doc_chunk_index": doc_chunk_index, + } + + if doc_id is not None: + chunk_data["doc_id"] = doc_id + if doc_path is not None: + chunk_data["doc_path"] = doc_path + + chunks.append(chunk_data) + + return chunks + + +# --------------------------------------------------------------------------- +# Top-level directory loader +# --------------------------------------------------------------------------- + + +def load_text_files_from_directory( + input_dir: Path, + file_extensions: list[str] | None = None, + min_text_length: int = 0, + sentences_per_chunk: int = 5, + num_sections: int = 1, + num_files: int | None = None, + multi_doc: bool = False, + bundle_size: int = 2, + bundle_strategy: Literal["sequential", "doc_balanced", "interleaved"] = "sequential", + max_docs_per_bundle: int = 3, + multi_doc_manifest: Path | None = None, +) -> pd.DataFrame: + """Load text files from a directory into a seed DataFrame. + + Supports single-document mode (one row per file) and multi-document mode + (files grouped into bundles, one row per bundle). 
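+
+    Example (illustrative; the directory path is hypothetical)::
+
+        df = load_text_files_from_directory(Path("./docs"), sentences_per_chunk=5, num_sections=2)
+        df[["file_name", "chunks", "sections_structured"]].head()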
+ + Args: + input_dir: Root directory containing text files. + file_extensions: Allowed extensions (default ``[".txt", ".md", ".text", ""]``). + min_text_length: Minimum character count to include a document. + sentences_per_chunk: Sentences per chunk. + num_sections: Sections to split chunks into. + num_files: Cap on the number of files to process. + multi_doc: Enable multi-document bundling. + bundle_size: Documents per automatic bundle. + bundle_strategy: Section-building strategy. + max_docs_per_bundle: Hard cap on bundle size. + multi_doc_manifest: Path to a manifest defining explicit bundles. + + Returns: + DataFrame with columns ``file_name``, ``text``, ``chunks``, + ``sections_structured``, and (when multi-doc) ``bundle_id``, + ``bundle_members``, ``is_multi_doc``. + + Raises: + ValueError: If no text files or valid documents are found. + """ + if file_extensions is None: + file_extensions = [".txt", ".md", ".text", ""] + + all_file_paths: list[Path] = [] + for file_path in input_dir.rglob("*"): + if num_files is not None and len(all_file_paths) >= num_files: + break + if file_path.is_file() and file_matches_extensions(file_path, file_extensions): + try: + content = file_path.read_text(encoding="utf-8") + if min_text_length > 0 and len(content) < min_text_length: + continue + all_file_paths.append(file_path) + except Exception as e: + print(f"Warning: Could not read {file_path}: {e}") + continue + + if not all_file_paths: + raise ValueError(f"No text files found in {input_dir} with extensions {file_extensions}") + + resolved_input_dir = input_dir.resolve() + documents: list[dict] = [] + + if multi_doc: + documents = _load_multi_doc( + all_file_paths, + resolved_input_dir, + sentences_per_chunk, + num_sections, + bundle_size, + bundle_strategy, + max_docs_per_bundle, + multi_doc_manifest, + ) + else: + documents = _load_single_doc( + all_file_paths, + input_dir, + sentences_per_chunk, + num_sections, + bundle_strategy, + ) + + if not documents: + raise ValueError(f"No valid documents created from {input_dir}") + + df = pd.DataFrame(documents) + _print_load_stats(df, all_file_paths, multi_doc, min_text_length, bundle_strategy) + return df + + +# --------------------------------------------------------------------------- +# Internal loader helpers +# --------------------------------------------------------------------------- + + +def _load_single_doc( + file_paths: list[Path], + input_dir: Path, + sentences_per_chunk: int, + num_sections: int, + bundle_strategy: Literal["sequential", "doc_balanced", "interleaved"], +) -> list[dict]: + """Build one row per file.""" + documents: list[dict] = [] + for file_path in file_paths: + relative_path = file_path.relative_to(input_dir) + try: + content = file_path.read_text(encoding="utf-8") + except Exception as e: + print(f"Warning: Could not read {relative_path}: {e}") + continue + + chunks = text_to_sentence_chunks(content, sentences_per_chunk=sentences_per_chunk) + sections_structured = chunks_to_sections_structured(chunks, num_sections=num_sections, strategy=bundle_strategy) + documents.append( + { + "file_name": [str(relative_path)], + "text": content, + "chunks": chunks, + "sections_structured": sections_structured, + "bundle_id": "", + "bundle_members": [str(relative_path)], + "is_multi_doc": False, + } + ) + return documents + + +def _load_multi_doc( + file_paths: list[Path], + resolved_input_dir: Path, + sentences_per_chunk: int, + num_sections: int, + bundle_size: int, + bundle_strategy: Literal["sequential", "doc_balanced", 
"interleaved"], + max_docs_per_bundle: int, + multi_doc_manifest: Path | None, +) -> list[dict]: + """Build one row per bundle.""" + manifest_bundles = load_multi_doc_manifest(multi_doc_manifest) + bundles = build_bundles( + file_paths, + bundle_size=bundle_size, + max_docs_per_bundle=max_docs_per_bundle, + manifest_bundles=manifest_bundles, + input_dir=resolved_input_dir, + ) + + print(f"Multi-doc mode: Created {len(bundles)} bundles from {len(file_paths)} files") + documents: list[dict] = [] + + for bundle in bundles: + bundle_texts: list[str] = [] + bundle_chunks: list[dict] = [] + bundle_members: list[str] = [] + chunk_id_offset = 0 + + for file_path in bundle: + relative_path = file_path.relative_to(resolved_input_dir) + doc_id = str(relative_path) + bundle_members.append(doc_id) + + try: + content = file_path.read_text(encoding="utf-8") + except Exception as e: + print(f"Warning: Could not read {file_path}: {e}") + continue + + bundle_texts.append(content) + doc_chunks = text_to_sentence_chunks( + content, + sentences_per_chunk=sentences_per_chunk, + doc_id=doc_id, + doc_path=str(file_path), + chunk_id_offset=chunk_id_offset, + ) + bundle_chunks.extend(doc_chunks) + chunk_id_offset += len(doc_chunks) + + if not bundle_chunks: + continue + + combined_text = "\n\n=== Document Boundary ===\n\n".join(bundle_texts) + sections_structured = chunks_to_sections_structured( + bundle_chunks, num_sections=num_sections, strategy=bundle_strategy + ) + bid = build_bundle_id(bundle_members) + + documents.append( + { + "file_name": bundle_members, + "text": combined_text, + "chunks": bundle_chunks, + "sections_structured": sections_structured, + "bundle_id": bid, + "bundle_members": bundle_members, + "is_multi_doc": True, + } + ) + + return documents + + +def _print_load_stats( + df: pd.DataFrame, + all_file_paths: list[Path], + multi_doc: bool, + min_text_length: int, + bundle_strategy: str, +) -> None: + """Print statistics about the loaded data.""" + row_type = "bundle" if multi_doc else "document" + if multi_doc: + avg_docs = sum(len(m) for m in df["bundle_members"]) / len(df) if len(df) > 0 else 0 + print(f"Created {len(df)} bundles from {len(all_file_paths)} files") + print(f"Average documents per bundle: {avg_docs:.1f}") + else: + print(f"Loaded {len(df)} text files from directory") + + if min_text_length > 0: + print(f"Filtered to documents with at least {min_text_length} characters") + + total_chunks = sum(len(c) for c in df["chunks"]) + avg_chunks = total_chunks / len(df) if len(df) > 0 else 0 + print(f"Created {total_chunks} total chunks ({avg_chunks:.1f} chunks per {row_type})") + + total_sections = sum(len(s) for s in df["sections_structured"]) + avg_sections = total_sections / len(df) if len(df) > 0 else 0 + avg_chunks_per_section = total_chunks / total_sections if total_sections > 0 else 0 + print( + f"Organized into {total_sections} sections " + f"({avg_sections:.1f} sections per {row_type}, " + f"{avg_chunks_per_section:.1f} chunks per section)" + ) + print(f"Bundle strategy: {bundle_strategy}") diff --git a/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/models.py b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/models.py new file mode 100644 index 0000000..0dd1fbe --- /dev/null +++ b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/models.py @@ -0,0 +1,128 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +"""Pydantic models for structured LLM outputs in the retriever SDG pipeline. + +These models define the schemas for artifact extraction, QA generation, +and quality evaluation columns. +""" + +from __future__ import annotations + +from typing import Literal + +from pydantic import BaseModel, Field + +# --------------------------------------------------------------------------- +# Artifact extraction models +# --------------------------------------------------------------------------- + + +class ArtifactItem(BaseModel): + """A single artifact item with text, description, and importance.""" + + text: str = Field(description="The artifact text or name") + description: str = Field(description="Detailed description of the artifact") + importance: str = Field(description="Why this artifact is important") + + +class DocumentArtifacts(BaseModel): + """Semantic artifacts extracted from a document.""" + + key_concepts: list[ArtifactItem] = Field(default_factory=list, description="Key concepts in the document") + relationships: list[ArtifactItem] = Field(default_factory=list, description="Relationships between concepts") + themes: list[ArtifactItem] = Field(default_factory=list, description="Main themes") + entities: list[ArtifactItem] = Field(default_factory=list, description="Entities mentioned") + processes: list[ArtifactItem] = Field(default_factory=list, description="Processes described") + insights: list[ArtifactItem] = Field(default_factory=list, description="Key insights") + technical_terms: list[ArtifactItem] = Field(default_factory=list, description="Technical terms") + contextual_factors: list[ArtifactItem] = Field(default_factory=list, description="Contextual factors") + + +# --------------------------------------------------------------------------- +# QA generation models +# --------------------------------------------------------------------------- + + +class HopContext(BaseModel): + """Context for a single hop in a multi-hop question.""" + + hop_number: int = Field(description="The hop number (1-indexed)") + segment_ids: list[int] = Field(description="Segment IDs for this hop") + summary: str = Field(description="Summary of the supporting segments for this hop") + + +class QuestionAnswerPair(BaseModel): + """A single question-answer pair with metadata.""" + + question: str = Field( + description=("The question requiring understanding of contexts without explicitly referencing them"), + ) + answer: str = Field( + description=("Comprehensive answer from the contexts without explicitly referencing them"), + ) + question_complexity: int = Field(description="Numeric score from min_complexity to 5") + query_type: Literal["multi_hop", "structural", "contextual"] = Field( + description="Type of query, one of multi_hop, structural, or contextual", + ) + reasoning_type: Literal["factual", "relational", "inferential", "temporal", "procedural", "visual", "causal"] = ( + Field( + description=( + "Type of reasoning required, one of factual, relational, inferential, " + "temporal, procedural, visual, or causal" + ), + ) + ) + segment_ids: list[int] = Field( + description="List of segment IDs that are source material for this question", + ) + hop_count: int = Field( + description=("Number of hops (min_hops to max_hops) for multi_hop questions, or 1 for non-multi-hop"), + ) + hop_contexts: list[HopContext] = Field(description="Array of hop detail objects") + + +class QuestionAnswerPairs(BaseModel): + """Collection of question-answer pairs.""" + + pairs: 
list[QuestionAnswerPair] = Field(description="List of question-answer pairs") + + +# --------------------------------------------------------------------------- +# QA evaluation models +# --------------------------------------------------------------------------- + + +class QAEvaluationCriterion(BaseModel): + """Evaluation criterion with score and justification.""" + + score: int = Field(description="Score from 1-10") + justification: str = Field(description="Brief justification for the score") + + +class QAOverallEvaluation(BaseModel): + """Overall evaluation with score and assessment.""" + + score: float = Field(description="Overall score from 1-10") + assessment: str = Field(description="Final assessment of the QA pair") + + +class QAEvaluation(BaseModel): + """Evaluation of a single QA pair.""" + + relevance: QAEvaluationCriterion = Field(description="Relevance of question to context") + accuracy: QAEvaluationCriterion = Field(description="Factual accuracy of answer") + context_support: QAEvaluationCriterion = Field( + description="How well answer is supported by context", + ) + clarity: QAEvaluationCriterion = Field(description="Clarity and unambiguity of question") + overall: QAOverallEvaluation = Field(description="Overall evaluation") + improvements: str = Field(description="Suggestions for improving this QA pair") + + +class QAPairEvaluations(BaseModel): + """Evaluations for all QA pairs in a document.""" + + evaluations: list[QAEvaluation] = Field( + description="List of evaluations, one per QA pair, in the same order as the QA pairs", + ) diff --git a/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/pipeline.py b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/pipeline.py new file mode 100644 index 0000000..e1eacfa --- /dev/null +++ b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/pipeline.py @@ -0,0 +1,349 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Pipeline builder for the retriever SDG workflow. + +Assembles a four-column DataDesigner pipeline: + +1. ``document_artifacts`` -- LLM-based artifact extraction +2. ``qa_generation`` -- LLM-based QA pair generation +3. ``deduplicated_qa_pairs`` -- embedding-based deduplication (plugin column) +4. 
``qa_evaluations`` -- LLM-based quality evaluation +""" + +from __future__ import annotations + +import json +from pathlib import Path + +import data_designer.config as dd +import pandas as pd +from data_designer.config.default_model_settings import get_default_providers + +from data_designer_retrieval_sdg.config import RetrievalSdgDedupColumnConfig +from data_designer_retrieval_sdg.models import ( + DocumentArtifacts, + QAPairEvaluations, + QuestionAnswerPairs, +) +from data_designer_retrieval_sdg.prompts import ( + ARTIFACT_EXTRACTION_SYSTEM_PROMPT, + ARTIFACT_EXTRACTION_USER_PROMPT, + QA_EVALUATION_SYSTEM_PROMPT, + QA_EVALUATION_USER_PROMPT, + QA_GENERATION_SYSTEM_PROMPT, + QA_GENERATION_USER_PROMPT, +) + +# --------------------------------------------------------------------------- +# Model configuration +# --------------------------------------------------------------------------- + +DEFAULT_CHAT_MODEL = "nvidia/nemotron-3-nano-30b-a3b" +DEFAULT_EMBED_MODEL = "nvidia/llama-3.2-nv-embedqa-1b-v2" +DEFAULT_PROVIDER = "nvidia" + + +def custom_model_config( + artifact_extraction_model: str = DEFAULT_CHAT_MODEL, + artifact_extraction_provider: str = DEFAULT_PROVIDER, + qa_generation_model: str = DEFAULT_CHAT_MODEL, + qa_generation_provider: str = DEFAULT_PROVIDER, + quality_judge_model: str = DEFAULT_CHAT_MODEL, + quality_judge_provider: str = DEFAULT_PROVIDER, + embed_model: str = DEFAULT_EMBED_MODEL, + embed_provider: str = DEFAULT_PROVIDER, + max_parallel_requests_for_gen: int | None = None, +) -> tuple[list[dd.ModelConfig], dict[str, str]]: + """Configure the model suite for a generation job. + + Each pipeline role (artifact extraction, QA generation, quality judge, + embedding) can point at a different model+provider. When multiple roles + share the same ``(model, provider)`` pair a single ``ModelConfig`` is + created and the roles share its alias. + + Args: + artifact_extraction_model: Model name for artifact extraction. + artifact_extraction_provider: Provider for artifact extraction. + qa_generation_model: Model name for QA generation. + qa_generation_provider: Provider for QA generation. + quality_judge_model: Model name for quality judge. + quality_judge_provider: Provider for quality judge. + embed_model: Model name for embeddings. + embed_provider: Provider for embeddings. + max_parallel_requests_for_gen: Optional cap on parallel requests + for chat-completion models. + + Returns: + Tuple of ``(model_configs, role_aliases)`` where *role_aliases* + maps each role name to the ``ModelConfig`` alias it should reference. 
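+
+    Example (illustrative; with all defaults the three chat roles share a single
+    chat ``ModelConfig`` alongside the embedding config)::
+
+        configs, aliases = custom_model_config()
+        # len(configs) == 2
+        # aliases["qa_generation"] == aliases["quality_judge"] == "artifact_extraction"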
+ """ + configs: list[dd.ModelConfig] = [ + dd.ModelConfig( + alias="embed", + model=embed_model, + inference_parameters=dd.EmbeddingInferenceParams( + max_parallel_requests=8, + extra_body={"input_type": "query", "truncate": "NONE"}, + ), + provider=embed_provider, + ), + ] + role_aliases: dict[str, str] = {"embed": "embed"} + + chat_roles = [ + ("artifact_extraction", artifact_extraction_model, artifact_extraction_provider), + ("qa_generation", qa_generation_model, qa_generation_provider), + ("quality_judge", quality_judge_model, quality_judge_provider), + ] + + seen: dict[tuple[str, str], str] = {} + for role_name, model, provider in chat_roles: + key = (model, provider) + if key not in seen: + seen[key] = role_name + inference_kwargs: dict = { + "temperature": 0.6, + "top_p": 0.95, + "timeout": 120, + } + if max_parallel_requests_for_gen is not None: + inference_kwargs["max_parallel_requests"] = max_parallel_requests_for_gen + configs.append( + dd.ModelConfig( + alias=role_name, + model=model, + provider=provider, + inference_parameters=dd.ChatCompletionInferenceParams(**inference_kwargs), + ) + ) + role_aliases[role_name] = seen[key] + + return configs, role_aliases + + +# --------------------------------------------------------------------------- +# Model-provider helpers +# --------------------------------------------------------------------------- + + +def build_model_providers( + custom_provider_endpoint: str | None = None, + custom_provider_name: str = "custom", + custom_provider_type: str = "openai", + custom_provider_api_key: str | None = None, + model_providers_file: Path | None = None, +) -> tuple[list[dd.ModelProvider] | None, list[dd.ModelProvider]]: + """Build a list of custom ``ModelProvider`` objects from CLI flags / config. + + Inline flags define a single provider; the config file can define + multiple. When both are supplied the inline provider overwrites any + file entry with the same name. + + Custom providers are merged with Data Designer defaults so that built-in + providers remain available. + + Args: + custom_provider_endpoint: Base URL for an inline custom provider. + custom_provider_name: Name for the inline provider. + custom_provider_type: API format (default ``"openai"``). + custom_provider_api_key: API key or env-var name. + model_providers_file: Path to a YAML/JSON file with provider entries. + + Returns: + Tuple of ``(all_providers, custom_only_providers)``. + ``all_providers`` is ``None`` when no custom providers exist. 
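+
+    Example (illustrative; the endpoint, name, and key are hypothetical)::
+
+        all_providers, custom = build_model_providers(
+            custom_provider_endpoint="http://localhost:8000/v1",
+            custom_provider_name="local-vllm",
+            custom_provider_api_key="LOCAL_API_KEY",
+        )
+        # ``custom`` holds only the inline provider; ``all_providers`` also keeps
+        # the non-overridden Data Designer defaults.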
+ """ + import yaml + + custom: list[dd.ModelProvider] = [] + + if model_providers_file is not None: + raw = model_providers_file.read_text(encoding="utf-8") + if model_providers_file.suffix in (".yaml", ".yml"): + entries = yaml.safe_load(raw) + else: + entries = json.loads(raw) + + if not isinstance(entries, list): + raise ValueError(f"model-providers-file must contain a YAML/JSON list, got {type(entries).__name__}") + for entry in entries: + custom.append(dd.ModelProvider(**entry)) + + if custom_provider_endpoint is not None: + custom = [p for p in custom if p.name != custom_provider_name] + custom.append( + dd.ModelProvider( + name=custom_provider_name, + endpoint=custom_provider_endpoint, + provider_type=custom_provider_type, + api_key=custom_provider_api_key, + ) + ) + + if not custom: + return None, [] + + custom_names = {p.name for p in custom} + defaults = [p for p in get_default_providers() if p.name not in custom_names] + return defaults + custom, custom + + +# --------------------------------------------------------------------------- +# Pipeline builder +# --------------------------------------------------------------------------- + +DEFAULT_QUERY_COUNTS: dict[str, int] = {"multi_hop": 3, "structural": 2, "contextual": 2} +DEFAULT_REASONING_COUNTS: dict[str, int] = { + "factual": 1, + "relational": 1, + "inferential": 1, + "temporal": 1, + "procedural": 1, + "causal": 1, + "visual": 1, +} + + +def build_qa_generation_pipeline( + seed_dataset: pd.DataFrame, + start_index: int = 0, + end_index: int = 199, + max_artifacts_per_type: int = 2, + num_pairs: int = 5, + query_counts: dict[str, int] | None = None, + min_hops: int = 2, + max_hops: int = 3, + reasoning_counts: dict[str, int] | None = None, + min_complexity: int = 4, + max_parallel_requests_for_gen: int | None = None, + artifact_extraction_model: str = DEFAULT_CHAT_MODEL, + artifact_extraction_provider: str = DEFAULT_PROVIDER, + qa_generation_model: str = DEFAULT_CHAT_MODEL, + qa_generation_provider: str = DEFAULT_PROVIDER, + quality_judge_model: str = DEFAULT_CHAT_MODEL, + quality_judge_provider: str = DEFAULT_PROVIDER, + embed_model: str = DEFAULT_EMBED_MODEL, + embed_provider: str = DEFAULT_PROVIDER, +) -> dd.DataDesignerConfigBuilder: + """Build a four-column QA generation pipeline. + + The pipeline adds columns in order: + + 1. ``document_artifacts`` -- structured artifact extraction + 2. ``qa_generation`` -- QA pair generation from artifacts + sections + 3. ``deduplicated_qa_pairs`` -- embedding dedup (plugin) + 4. ``qa_evaluations`` -- quality scoring + + Args: + seed_dataset: DataFrame with ``file_name``, ``text``, ``chunks``, + ``sections_structured`` columns. + start_index: Start index (inclusive) for ordered index-range selection. + end_index: End index (inclusive) for ordered index-range selection. + max_artifacts_per_type: Max artifacts extracted per type. + num_pairs: QA pairs to generate per document. + query_counts: Distribution of query types. + min_hops: Minimum hops for multi-hop questions. + max_hops: Maximum hops for multi-hop questions. + reasoning_counts: Distribution of reasoning types. + min_complexity: Minimum complexity score. + max_parallel_requests_for_gen: Cap on parallel requests for chat models. + artifact_extraction_model: Model for artifact extraction. + artifact_extraction_provider: Provider for artifact extraction. + qa_generation_model: Model for QA generation. + qa_generation_provider: Provider for QA generation. + quality_judge_model: Model for quality judge. 
+ quality_judge_provider: Provider for quality judge. + embed_model: Model for embeddings. + embed_provider: Provider for embeddings. + + Returns: + Configured ``DataDesignerConfigBuilder`` ready for + ``DataDesigner.create()`` or ``.preview()``. + """ + if query_counts is None: + query_counts = dict(DEFAULT_QUERY_COUNTS) + if reasoning_counts is None: + reasoning_counts = dict(DEFAULT_REASONING_COUNTS) + + model_configs, role_aliases = custom_model_config( + artifact_extraction_model=artifact_extraction_model, + artifact_extraction_provider=artifact_extraction_provider, + qa_generation_model=qa_generation_model, + qa_generation_provider=qa_generation_provider, + quality_judge_model=quality_judge_model, + quality_judge_provider=quality_judge_provider, + embed_model=embed_model, + embed_provider=embed_provider, + max_parallel_requests_for_gen=max_parallel_requests_for_gen, + ) + + config_builder = dd.DataDesignerConfigBuilder(model_configs=model_configs) + + config_builder.with_seed_dataset( + dd.DataFrameSeedSource(df=seed_dataset), + sampling_strategy=dd.SamplingStrategy.ORDERED, + selection_strategy=dd.IndexRange(start=start_index, end=end_index), + ) + + # Column 1: artifact extraction + config_builder.add_column( + dd.LLMStructuredColumnConfig( + name="document_artifacts", + system_prompt=ARTIFACT_EXTRACTION_SYSTEM_PROMPT, + prompt=ARTIFACT_EXTRACTION_USER_PROMPT.format( + max_artifacts_per_type=max_artifacts_per_type, + ), + output_format=DocumentArtifacts, + model_alias=role_aliases["artifact_extraction"], + ) + ) + + # Column 2: QA generation + config_builder.add_column( + dd.LLMStructuredColumnConfig( + name="qa_generation", + system_prompt=QA_GENERATION_SYSTEM_PROMPT, + prompt=QA_GENERATION_USER_PROMPT.format( + query_counts_multi_hop=query_counts.get("multi_hop", 0), + query_counts_structural=query_counts.get("structural", 0), + query_counts_contextual=query_counts.get("contextual", 0), + reasoning_counts_factual=reasoning_counts.get("factual", 0), + reasoning_counts_relational=reasoning_counts.get("relational", 0), + reasoning_counts_inferential=reasoning_counts.get("inferential", 0), + reasoning_counts_temporal=reasoning_counts.get("temporal", 0), + reasoning_counts_procedural=reasoning_counts.get("procedural", 0), + reasoning_counts_visual=reasoning_counts.get("visual", 0), + reasoning_counts_causal=reasoning_counts.get("causal", 0), + min_hops=min_hops, + max_hops=max_hops, + min_complexity=min_complexity, + num_pairs=num_pairs, + ), + output_format=QuestionAnswerPairs, + model_alias=role_aliases["qa_generation"], + ) + ) + + # Column 3: deduplication (plugin column) + config_builder.add_column( + RetrievalSdgDedupColumnConfig( + name="deduplicated_qa_pairs", + qa_pairs_column="qa_generation", + embedding_alias="embed", + dedupe_similarity_threshold=0.9, + ) + ) + + # Column 4: quality evaluation + config_builder.add_column( + dd.LLMStructuredColumnConfig( + name="qa_evaluations", + system_prompt=QA_EVALUATION_SYSTEM_PROMPT, + prompt=QA_EVALUATION_USER_PROMPT, + output_format=QAPairEvaluations, + model_alias=role_aliases["quality_judge"], + ) + ) + + return config_builder diff --git a/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/plugin.py b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/plugin.py new file mode 100644 index 0000000..aa19889 --- /dev/null +++ b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/plugin.py @@ -0,0 +1,12 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Data Designer plugin registration for the retrieval-sdg-dedup column type.""" + +from data_designer.plugins.plugin import Plugin, PluginType + +plugin = Plugin( + config_qualified_name="data_designer_retrieval_sdg.config.RetrievalSdgDedupColumnConfig", + impl_qualified_name="data_designer_retrieval_sdg.dedup.RetrievalSdgDedupColumnGenerator", + plugin_type=PluginType.COLUMN_GENERATOR, +) diff --git a/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/postprocess.py b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/postprocess.py new file mode 100644 index 0000000..1bbff7f --- /dev/null +++ b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/postprocess.py @@ -0,0 +1,375 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Post-processing utilities for generated retriever SDG data. + +Includes BEIR-format export, quality-based filtering, and a helper +for loading positive documents with modality metadata. +""" + +from __future__ import annotations + +import json +from collections import defaultdict +from pathlib import Path + +import numpy as np +import pandas as pd + +# --------------------------------------------------------------------------- +# BEIR-format post-processing +# --------------------------------------------------------------------------- + + +def postprocess_retriever_data( + generated_df: pd.DataFrame, +) -> tuple[pd.DataFrame, pd.DataFrame, dict[str, list[str]]]: + """Flatten generated data into BEIR-style queries, qrels, and splits. + + Args: + generated_df: DataFrame produced by the pipeline, containing + ``file_name``, ``deduplicated_qa_pairs`` (or ``qa_generation``), + and metadata columns. + + Returns: + Tuple of ``(queries_df, qrels_df, splits)`` where *splits* maps + modality names to lists of query IDs. 
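+
+    Example (illustrative)::
+
+        queries_df, qrels_df, splits = postprocess_retriever_data(generated_df)
+        list(qrels_df.columns)  # ["query-id", "corpus-id", "score"]
+        list(splits)            # ["text"]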
+ """ + print(f"Processing {len(generated_df)} generated records...") + + queries_data: list[dict] = [] + qrels_data: list[dict] = [] + splits: dict[str, list[str]] = defaultdict(list) + query_counter = 0 + reasoning_types: list[str] = [] + query_types: list[str] = [] + + for _, row in generated_df.iterrows(): + if "file_name" not in row: + print("Warning: Skipping row without file_name") + continue + + file_name = row["file_name"] + qa_pairs = _extract_qa_pairs(row, file_name) + if qa_pairs is None: + continue + + for qa_pair in qa_pairs: + parsed = _parse_qa_pair(qa_pair) + if not parsed["question"] or not isinstance(parsed["question"], str): + continue + + query_id = f"q{query_counter:08d}" + query_counter += 1 + reasoning_types.append(parsed["reasoning_type"]) + query_types.append(parsed["query_type"]) + + metadata = { + "query_type": parsed["query_type"], + "reasoning_type": parsed["reasoning_type"], + "question_complexity": parsed["question_complexity"], + "hop_count": parsed["hop_count"], + "segment_ids": parsed["segment_ids"], + "source_file": file_name, + "answer": parsed["answer"], + } + if parsed["hop_contexts"]: + metadata["hop_contexts"] = parsed["hop_contexts"] + + queries_data.append({"_id": query_id, "metadata": metadata, "text": parsed["question"]}) + qrels_data.append({"query-id": query_id, "corpus-id": file_name, "score": 1}) + splits["text"].append(query_id) + + queries_df = pd.DataFrame(queries_data) + qrels_df = pd.DataFrame(qrels_data) + + total_queries = len(queries_df) + if total_queries > 0: + print(f"\nGenerated {total_queries} queries from {len(generated_df)} documents") + _print_distribution("Reasoning type", reasoning_types, total_queries) + _print_distribution("Query type", query_types, total_queries) + else: + print("\nWarning: No queries generated!") + + return queries_df, qrels_df, dict(splits) + + +# --------------------------------------------------------------------------- +# Quality filtering +# --------------------------------------------------------------------------- + + +def filter_qa_pairs_by_quality( + generated_df: pd.DataFrame, + quality_threshold: float = 7.0, +) -> tuple[pd.DataFrame, list[dict]]: + """Filter deduplicated QA pairs using evaluation scores. + + Each pair's ``overall.score`` from the ``qa_evaluations`` column is + compared against *quality_threshold*. Rows with mismatched + evaluation/pair counts are skipped. + + Args: + generated_df: DataFrame with ``deduplicated_qa_pairs``, + ``qa_evaluations``, and ``file_name`` columns. + quality_threshold: Minimum overall quality score to retain a pair. + + Returns: + Tuple of ``(filtered_df, skipped_files)`` where *skipped_files* is + a list of ``{"file_name": ..., "reason": ...}`` dicts. 
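+
+    Example (illustrative)::
+
+        filtered_df, skipped = filter_qa_pairs_by_quality(generated_df, quality_threshold=8.0)
+        filtered_df[["question", "quality_score", "file_name"]].head()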
+ """ + print(f"Filtering QA pairs based on quality threshold: {quality_threshold}") + + total_pairs = 0 + filtered_pairs = 0 + all_filtered: list[dict] = [] + skipped_files: list[dict] = [] + + for _, row in generated_df.iterrows(): + file_name = row.get("file_name", "unknown") + dedup_pairs = _to_list(row.get("deduplicated_qa_pairs")) + if dedup_pairs is None: + print(f"Warning: Skipping {file_name} - deduplicated_qa_pairs is None") + continue + if not dedup_pairs: + print(f"Warning: Skipping {file_name} - no valid deduplicated pairs found") + continue + + scores = _extract_evaluation_scores(row.get("qa_evaluations")) + + if len(scores) != len(dedup_pairs): + reason = f"deduplicated_qa_pairs has {len(dedup_pairs)} items but qa_evaluations has {len(scores)} items" + print(f"Warning: Skipping {file_name} - data integrity error: {reason}") + skipped_files.append({"file_name": file_name, "reason": reason}) + continue + + for pair_idx, qa_pair in enumerate(dedup_pairs): + total_pairs += 1 + quality_score = scores[pair_idx] if pair_idx < len(scores) else 0 + if quality_score >= quality_threshold: + pair_dict = _qa_pair_to_dict(qa_pair) + pair_dict["file_name"] = file_name + pair_dict["quality_score"] = quality_score + all_filtered.append(pair_dict) + else: + filtered_pairs += 1 + + filtered_df = pd.DataFrame(all_filtered) + + print("\nQuality Filtering Results:") + print(f" Total QA pairs: {total_pairs}") + print(f" Filtered out (score < {quality_threshold}): {filtered_pairs}") + print(f" Remaining high-quality pairs: {len(filtered_df)}") + print(f" Files skipped due to data issues: {len(skipped_files)}") + retention = len(filtered_df) / total_pairs * 100 if total_pairs > 0 else 0 + print(f" Retention rate: {retention:.1f}%") + + return filtered_df, skipped_files + + +# --------------------------------------------------------------------------- +# Modality / BEIR loader +# --------------------------------------------------------------------------- + + +def load_positive_docs_with_modality( + test_tsv_path: Path, + corpus_jsonl_path: Path, + split_json_path: Path, + min_text_length: int = 0, +) -> tuple[pd.DataFrame, dict[str, str]]: + """Load positive documents and map them to their modalities. + + Args: + test_tsv_path: Path to ``qrels/test.tsv``. + corpus_jsonl_path: Path to ``corpus.jsonl``. + split_json_path: Path to ``split.json``. + min_text_length: Minimum text length to include a document. + + Returns: + Tuple of ``(positive_docs_df, doc_to_modality_final)``. 
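+
+    Example (illustrative; the BEIR directory layout shown is hypothetical)::
+
+        docs_df, doc_modality = load_positive_docs_with_modality(
+            Path("eval_beir/qrels/test.tsv"),
+            Path("eval_beir/corpus.jsonl"),
+            Path("eval_beir/split.json"),
+        )
+        # docs_df columns: doc_id, text, title, modality, group_id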
+ """ + qrels_df = pd.read_csv(test_tsv_path, sep="\t") + + with open(split_json_path, encoding="utf-8") as f: + splits = json.load(f) + + query_to_modality: dict[str, str] = {} + for modality, query_ids in splits.items(): + for query_id in query_ids: + query_to_modality[query_id] = modality + + doc_to_modality: dict[str, set[str]] = defaultdict(set) + for _, row in qrels_df.iterrows(): + query_id = row["query-id"] + corpus_id = row["corpus-id "] # trailing space in column name + if query_id in query_to_modality: + doc_to_modality[corpus_id].add(query_to_modality[query_id]) + + doc_to_modality_final: dict[str, str] = {} + for doc_id, modalities in doc_to_modality.items(): + if len(modalities) == 1: + doc_to_modality_final[doc_id] = next(iter(modalities)) + else: + modality_counts: dict[str, int] = defaultdict(int) + for _, r in qrels_df[qrels_df["corpus-id "] == doc_id].iterrows(): + qid = r["query-id"] + if qid in query_to_modality: + modality_counts[query_to_modality[qid]] += 1 + doc_to_modality_final[doc_id] = max(modality_counts, key=modality_counts.get) # type: ignore[arg-type] + + unique_group_ids = set(doc_to_modality_final.keys()) + corpus_docs_by_group: dict[str, dict] = {} + with open(corpus_jsonl_path, encoding="utf-8") as f: + for line in f: + doc = json.loads(line) + group_id = doc.get("group_id", doc["_id"]) + if group_id in unique_group_ids and group_id not in corpus_docs_by_group: + corpus_docs_by_group[group_id] = doc + + positive_docs_data: list[dict] = [] + for group_id, modality in doc_to_modality_final.items(): + if group_id in corpus_docs_by_group: + doc = corpus_docs_by_group[group_id] + positive_docs_data.append( + { + "doc_id": doc["_id"], + "text": doc["text"], + "title": doc.get("title", ""), + "modality": modality, + "group_id": group_id, + } + ) + + positive_docs_df = pd.DataFrame(positive_docs_data) + + if min_text_length > 0 and len(positive_docs_df) > 0: + original_count = len(positive_docs_df) + positive_docs_df = positive_docs_df[positive_docs_df["text"].str.len() >= min_text_length] + filtered_count = original_count - len(positive_docs_df) + if filtered_count > 0: + print(f"Filtered out {filtered_count} documents shorter than {min_text_length} characters") + + return positive_docs_df, doc_to_modality_final + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def _to_list(value: object) -> list | None: + """Coerce *value* to a Python list, handling numpy arrays.""" + if value is None: + return None + if isinstance(value, np.ndarray): + return value.tolist() + if isinstance(value, list): + return value + return None + + +def _extract_qa_pairs(row: pd.Series, file_name: object) -> list | None: + """Pull the QA pairs list from a generated row.""" + if "deduplicated_qa_pairs" in row and row["deduplicated_qa_pairs"] is not None: + pairs = row["deduplicated_qa_pairs"] + elif "qa_generation" in row: + qa_gen = row.get("qa_generation") + if qa_gen is None: + print(f"Warning: Skipping {file_name} - qa_generation is None") + return None + if isinstance(qa_gen, dict): + pairs = qa_gen.get("pairs", []) + else: + pairs = getattr(qa_gen, "pairs", []) + else: + print(f"Warning: Skipping {file_name} - no qa_generation or deduplicated_qa_pairs found") + return None + + pairs = _to_list(pairs) if not isinstance(pairs, list) else pairs + if not pairs: + print(f"Warning: Skipping {file_name} - no valid pairs found") + return None + return pairs 
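+
+
+# Illustrative shape of the dict produced by _parse_qa_pair below (values are
+# hypothetical):
+# {"question": "What is X?", "answer": "...", "query_type": "multi_hop",
+#  "reasoning_type": "factual", "question_complexity": 4, "segment_ids": [1, 2],
+#  "hop_count": 2, "hop_contexts": []}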
+ + +def _parse_qa_pair(qa_pair: object) -> dict: + """Normalise a QA pair (dict or Pydantic model) to a plain dict.""" + fields = ( + "question", + "answer", + "query_type", + "reasoning_type", + "question_complexity", + "segment_ids", + "hop_count", + "hop_contexts", + ) + defaults = ("", "", "", "", 0, [], 1, []) + + result: dict = {} + for field, default in zip(fields, defaults): + if isinstance(qa_pair, dict): + val = qa_pair.get(field, default) + else: + val = getattr(qa_pair, field, default) + if isinstance(val, np.ndarray): + val = val.tolist() + result[field] = val + return result + + +def _qa_pair_to_dict(qa_pair: object) -> dict: + """Convert a QA pair to a plain dict for DataFrame construction.""" + keys = ( + "question", + "answer", + "query_type", + "reasoning_type", + "question_complexity", + "segment_ids", + "hop_count", + "hop_contexts", + ) + if isinstance(qa_pair, dict): + return {k: qa_pair.get(k, None) for k in keys} + return {k: getattr(qa_pair, k, None) for k in keys} + + +def _extract_evaluation_scores(qa_evaluations: object) -> list[float]: + """Pull overall scores from the qa_evaluations object.""" + scores: list[float] = [] + if qa_evaluations is None: + return scores + + if isinstance(qa_evaluations, dict): + evaluations_list = qa_evaluations.get("evaluations", []) + else: + evaluations_list = getattr(qa_evaluations, "evaluations", []) + + if isinstance(evaluations_list, np.ndarray): + evaluations_list = evaluations_list.tolist() + + for eval_item in evaluations_list: + if isinstance(eval_item, dict): + overall = eval_item.get("overall", {}) + else: + overall = getattr(eval_item, "overall", None) + + if isinstance(overall, dict): + scores.append(overall.get("score", 0)) + elif overall is not None: + scores.append(getattr(overall, "score", 0)) + else: + scores.append(0) + + return scores + + +def _print_distribution(label: str, values: list[str], total: int) -> None: + """Print a frequency distribution to stdout.""" + print(f"\n{label} distribution:") + dist = pd.Series(values).value_counts() + for name, count in dist.items(): + pct = count / total * 100 + print(f" {name}: {count} queries ({pct:.1f}%)") diff --git a/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/prompts.py b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/prompts.py new file mode 100644 index 0000000..b3d4ba9 --- /dev/null +++ b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/prompts.py @@ -0,0 +1,298 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Prompt templates for the retriever SDG pipeline. + +All long-form system and user prompts are centralised here as module-level +constants so that ``pipeline.py`` stays concise and the prompts are easy to +review or override. +""" + +# --------------------------------------------------------------------------- +# Artifact extraction +# --------------------------------------------------------------------------- + +ARTIFACT_EXTRACTION_SYSTEM_PROMPT = "You are an expert at analyzing documents and extracting semantic artifacts." + +ARTIFACT_EXTRACTION_USER_PROMPT = """\ +Analyze the following content and extract semantic artifacts that would be \ +valuable for generating high-quality question-answer pairs. + +Note: The content may contain multiple documents bundled together \ +(separated by "=== Document Boundary ==="). 
\ +If multiple documents are present, identify cross-document relationships \ +and connections. + +CONTENT: +{{{{ text }}}} + +ARTIFACT TYPES TO EXTRACT: +- key_concepts: Core ideas and concepts discussed in the document(s) +- relationships: Connections and relationships between different concepts \ +(including cross-document relationships) +- themes: Overarching themes and topics +- entities: Specific entities, people, organizations, or items mentioned +- processes: Processes, workflows, or procedures described +- insights: Key insights, conclusions, or findings +- technical_terms: Technical terminology and specialized vocabulary +- contextual_factors: Contextual information that provides background + +INSTRUCTIONS: +1. Extract up to {max_artifacts_per_type} artifacts for each relevant type +2. Focus on the most significant and informative elements +3. Provide clear, concise descriptions for each artifact +4. Include context about why each artifact is important +5. Ensure artifacts are specific and actionable for Q&A generation +6. For multi-document bundles, pay special attention to relationships \ +and comparisons between documents +""" + +# --------------------------------------------------------------------------- +# QA generation +# --------------------------------------------------------------------------- + +QA_GENERATION_SYSTEM_PROMPT = ( + "You are an expert at extracting question and answer pairs from provided context/transcript/segments." +) + +QA_GENERATION_USER_PROMPT = """\ +You are an expert at extracting question and answer pairs from provided \ +context/transcript/segments. + +: +{{%- if document_artifacts.key_concepts %}} + +{{%- for item in document_artifacts.key_concepts %}} +- {{{{ item.text }}}}: {{{{ item.description }}}} +{{%- endfor %}} + +{{%- endif %}} + +{{%- if document_artifacts.relationships %}} + +{{%- for item in document_artifacts.relationships %}} +- {{{{ item.text }}}}: {{{{ item.description }}}} +{{%- endfor %}} + +{{%- endif %}} + +{{%- if document_artifacts.themes %}} + +{{%- for item in document_artifacts.themes %}} +- {{{{ item.text }}}}: {{{{ item.description }}}} +{{%- endfor %}} + +{{%- endif %}} + +{{%- if document_artifacts.entities %}} + +{{%- for item in document_artifacts.entities %}} +- {{{{ item.text }}}}: {{{{ item.description }}}} +{{%- endfor %}} + +{{%- endif %}} + +{{%- if document_artifacts.processes %}} + +{{%- for item in document_artifacts.processes %}} +- {{{{ item.text }}}}: {{{{ item.description }}}} +{{%- endfor %}} + +{{%- endif %}} + +{{%- if document_artifacts.insights %}} + +{{%- for item in document_artifacts.insights %}} +- {{{{ item.text }}}}: {{{{ item.description }}}} +{{%- endfor %}} + +{{%- endif %}} + +{{%- if document_artifacts.technical_terms %}} + +{{%- for item in document_artifacts.technical_terms %}} +- {{{{ item.text }}}}: {{{{ item.description }}}} +{{%- endfor %}} + +{{%- endif %}} + +{{%- if document_artifacts.contextual_factors %}} + +{{%- for item in document_artifacts.contextual_factors %}} +- {{{{ item.text }}}}: {{{{ item.description }}}} +{{%- endfor %}} + +{{%- endif %}} + + +: +{{%- for section in sections_structured %}} +{{{{ section }}}} + +{{%- endfor %}} + + +Guidelines: +1. 
Generate questions with varying complexity levels between 1 (simple) and \ +5 (complex): + - All questions MUST require understanding connections between different \ +parts of the context/transcript/segments + - Questions should test deep understanding, not simple facts + - Do not mention the existence of a context/transcript in the generated \ +question like "in the transcript", "from the given context", or \ +"in Segment 148". Produce a natural, standalone question. + - Only use facts present in the provided context/transcript; if missing, \ +say you cannot generate a question. + - Example: "How does the speaker's initial explanation of X relate to \ +the later implementation of Y?" + +2. Question Types to Generate (for the "query_type" field - ONLY these 3 \ +values allowed): + - "multi_hop" ({query_counts_multi_hop} questions): Connect \ +{min_hops}-{max_hops} separated segments + - "structural" ({query_counts_structural} questions): Focus on \ +relationships between concepts + - "contextual" ({query_counts_contextual} questions): Require \ +surrounding context to understand + - Use the cross-part context snippets to connect evidence that lives \ +outside the current transcript section + +3. Reasoning Types to Include (for the "reasoning_type" field - ONLY these \ +7 values allowed): + - "factual" ({reasoning_counts_factual} questions): Ask for complex \ +facts that require synthesizing multiple pieces of information \ +(NOT simple lookups) + - "relational" ({reasoning_counts_relational} questions): Ask how data \ +points compare or correlate across different segments + - "inferential" ({reasoning_counts_inferential} questions): Ask about \ +conclusions or implications requiring synthesis + - "temporal" ({reasoning_counts_temporal} questions): Ask about changes \ +or events over time across segments + - "procedural" ({reasoning_counts_procedural} questions): Ask about \ +complex multi-step processes or guidelines + - "visual" ({reasoning_counts_visual} questions): Ask about visual \ +details requiring cross-reference + - "causal" ({reasoning_counts_causal} questions): Ask about cause-effect \ +chains spanning segments + + Example COMPLEX questions by reasoning type: + - Factual: "What is the total combined budget allocation across all \ +departmental initiatives mentioned, and how does it relate to the overall \ +fiscal year target?" + - Relational: "How does the performance metric achieved in Q2 compare to \ +both the initial baseline and the revised targets that were set?" + - Inferential: "Based on the challenges outlined and the proposed \ +solutions, what unstated assumptions underlie the strategic pivot?" + - Temporal: "How did the implementation timeline evolve from the initial \ +proposal through the mid-year review to the final execution phase?" + - Procedural: "What is the complete approval workflow including standard \ +requirements, exceptions, and escalation processes?" + - Visual: "How do the visual elements presented relate to the verbal \ +descriptions provided, and what discrepancies exist between them?" + - Causal: "What chain of events, starting from the initial decision, led \ +through various complications to the final outcome?" + +4. 
IMPORTANT - Orthogonal Distributions (query_type and reasoning_type are \ +SEPARATE fields): + - Each question must have BOTH a query_type \ +(multi_hop/structural/contextual) AND a reasoning_type \ +(factual/relational/inferential/temporal/procedural/visual/causal) + - These are TWO DIFFERENT fields - do NOT put reasoning types in the \ +query_type field! + - For example: A question can be query_type="multi_hop" with \ +reasoning_type="procedural" + - Ensure the final distribution matches both specified percentages + +5. **IMPORTANT - Segment Identification**: + - The content below contains segments formatted as \ +"Segment N (HH:MM:SS - HH:MM:SS): text" or \ +"Segment N [Doc: doc_id] (HH:MM:SS - HH:MM:SS): text" where N starts from 1 + - The "[Doc: doc_id]" tag indicates which document the segment belongs to \ +(for multi-document bundles) + - For each question-answer pair you generate, identify ALL segment numbers \ +FROM which the question is derived + - These segments are the source material that should be retrieved when \ +someone asks this question + - Record these segment numbers in the "segment_ids" field as a list of \ +integers (e.g., [1, 4, 8]) + - For multi-document bundles, prefer questions that span multiple \ +documents to maximize cross-document reasoning + - For multi-hop questions: + * The top-level "segment_ids" should be the UNION of all segment IDs \ +across all hops + * Each hop in "hop_contexts" should specify its own "segment_ids" list + * Example: If hop 1 uses [1, 3] and hop 2 uses [6, 8], then top-level \ +segment_ids should be [1, 3, 6, 8] + * For multi-document bundles, try to have different hops reference \ +different documents + +6. For Each Question: + - Must have complexity level {min_complexity} or higher + - Generate the question FROM the identified segments (these segments are \ +the source material) + - Multi-hop questions must specify hop_count ({min_hops}-{max_hops}) + - Provide hop_contexts: a list where each hop includes "hop_number", \ +"segment_ids" (the source segments for this hop), and "summary" \ +(a concise summary describing the supporting segments). + +7. Generate {num_pairs} distinct question and answer pairs. 
+ +The output should be a JSON object with a "pairs" field containing an array \ +of {num_pairs} objects, where each object contains: + - "question": the question, requiring understanding of the \ +contexts/transcripts/segments without explicitly referencing the \ +context/transcript/segments in the question + - "answer": comprehensive answer from the contexts/transcripts/segments \ +without explicitly referencing the context/transcript/segments in the answer + - "question_complexity": numeric score {min_complexity}-5 + - "query_type": MUST be exactly one of these three values: "multi_hop", \ +"structural", or "contextual" (NO other values allowed - do NOT use \ +reasoning types here) + - "reasoning_type": MUST be exactly one of these seven values: "factual", \ +"relational", "inferential", "temporal", "procedural", "visual", or \ +"causal" (this is DIFFERENT from query_type) + - "segment_ids": list of segment numbers (e.g., [1, 4, 8]) that are the \ +source material for this question (these should be retrieved when the \ +question is asked) + - "hop_count": number of hops ({min_hops}-{max_hops}) for multi_hop \ +questions, or 1 for non-multi-hop questions + - "hop_contexts": array of hop detail objects with "hop_number", \ +"segment_ids", "summary" + +CRITICAL: "query_type" and "reasoning_type" are TWO SEPARATE FIELDS with \ +different allowed values. Do NOT mix them up: + - query_type can ONLY be: "multi_hop", "structural", "contextual" + - reasoning_type can ONLY be: "factual", "relational", "inferential", \ +"temporal", "procedural", "visual", "causal" +""" + +# --------------------------------------------------------------------------- +# QA evaluation +# --------------------------------------------------------------------------- + +QA_EVALUATION_SYSTEM_PROMPT = "You are an expert evaluator of question-answer pairs." + +QA_EVALUATION_USER_PROMPT = """\ +You are an expert evaluator of question-answer pairs. + +You will evaluate multiple question-answer pairs from a document. + +{% for qa_pair in deduplicated_qa_pairs %} +=== QA Pair {{ loop.index }} === + +QUESTION: {{ qa_pair.question }} + +ANSWER: {{ qa_pair.answer }} + +CONTEXT (Relevant Segment IDs): {{ qa_pair.segment_ids }} + +{% endfor %} + + +{% for chunk in chunks %} +- Segment {{ chunk.chunk_id }}: {{ chunk.text }} +{% endfor %} + + +Evaluate EACH of the {{ deduplicated_qa_pairs | length }} QA pairs above. +""" diff --git a/plugins/data-designer-retrieval-sdg/tests/test_convert.py b/plugins/data-designer-retrieval-sdg/tests/test_convert.py new file mode 100644 index 0000000..40bedf8 --- /dev/null +++ b/plugins/data-designer-retrieval-sdg/tests/test_convert.py @@ -0,0 +1,182 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import json +from pathlib import Path + +import pandas as pd + +from data_designer_retrieval_sdg.convert import ( + UnionFind, + build_corpus_and_mappings, + create_train_val_test_split, + extract_base_filename, + file_tuple_in_set, + filter_mismatched_records, + generate_eval_set, + generate_training_set, + get_corpus_id, + get_file_identifier, + load_generated_json_files, + merge_groups_union_find, + normalize_file_name, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def test_get_corpus_id_deterministic() -> None: + assert get_corpus_id("hello") == get_corpus_id("hello") + assert get_corpus_id("hello") != get_corpus_id("world") + assert get_corpus_id("hello").startswith("d_") + + +def test_extract_base_filename() -> None: + assert extract_base_filename("path/to/file.txt") == "file" + assert extract_base_filename("README") == "README" + + +def test_normalize_file_name() -> None: + assert normalize_file_name("file.txt") == ["file.txt"] + assert normalize_file_name(["a.txt", "b.txt"]) == ["a.txt", "b.txt"] + assert normalize_file_name(42) == ["42"] + + +def test_get_file_identifier_single() -> None: + assert get_file_identifier(["path/to/doc.txt"]) == "doc" + + +def test_get_file_identifier_multi() -> None: + ident = get_file_identifier(["a.txt", "b.txt"]) + assert len(ident) == 16 # MD5 truncated + + +def test_file_tuple_in_set() -> None: + s = {("a.txt",), ("b.txt", "c.txt")} + assert file_tuple_in_set(["a.txt"], s) is True + assert file_tuple_in_set(["b.txt", "c.txt"], s) is True + assert file_tuple_in_set(["d.txt"], s) is False + + +# --------------------------------------------------------------------------- +# filter_mismatched_records +# --------------------------------------------------------------------------- + + +def test_filter_mismatched_records() -> None: + records = [ + {"file_name": "ok", "deduplicated_qa_pairs": [1], "qa_evaluations": {"evaluations": [1]}}, + {"file_name": "bad", "deduplicated_qa_pairs": [1, 2], "qa_evaluations": {"evaluations": [1]}}, + ] + filtered, dropped = filter_mismatched_records(records) + assert len(filtered) == 1 + assert dropped == 1 + + +# --------------------------------------------------------------------------- +# build_corpus_and_mappings +# --------------------------------------------------------------------------- + + +def test_build_corpus_and_mappings() -> None: + df = pd.DataFrame( + [ + { + "file_name": ["a.txt"], + "chunks": [{"chunk_id": 1, "text": "hello"}, {"chunk_id": 2, "text": "world"}], + } + ] + ) + corpus, mapping = build_corpus_and_mappings(df) + assert len(corpus) == 2 + assert ("a", 1) in mapping + assert mapping[("a", 1)] == "hello" + + +# --------------------------------------------------------------------------- +# create_train_val_test_split +# --------------------------------------------------------------------------- + + +def test_split_basic() -> None: + rows = [{"file_name": [f"f{i}.txt"], "question": f"Q{i}"} for i in range(10)] + df = pd.DataFrame(rows) + train, val, test = create_train_val_test_split(df, train_ratio=0.6, val_ratio=0.2, seed=42) + assert len(train) + len(val) + len(test) == 10 + + +# --------------------------------------------------------------------------- +# UnionFind +# --------------------------------------------------------------------------- + + +def test_union_find() -> None: + uf = UnionFind() + uf.union("a", "b") + 
uf.union("b", "c") + assert uf.find("a") == uf.find("c") + assert uf.find("d") != uf.find("a") + + +def test_merge_groups_union_find() -> None: + groups = {"g1": ["a", "b"], "g2": ["b", "c"]} + merged = merge_groups_union_find(groups) + assert len(merged) == 1 + members = list(merged.values())[0] + assert set(members) == {"a", "b", "c"} + + +# --------------------------------------------------------------------------- +# load_generated_json_files +# --------------------------------------------------------------------------- + + +def test_load_from_single_file(tmp_path: Path) -> None: + data = [ + { + "file_name": "doc.txt", + "deduplicated_qa_pairs": [{"question": "Q"}], + "qa_evaluations": {"evaluations": [{"overall": {"score": 8}}]}, + } + ] + p = tmp_path / "data.json" + p.write_text(json.dumps(data)) + df = load_generated_json_files(str(p)) + assert len(df) == 1 + assert df.iloc[0]["file_name"] == ["doc.txt"] + + +def test_load_from_directory(tmp_path: Path) -> None: + for i in range(2): + data = [{"file_name": f"d{i}.txt", "deduplicated_qa_pairs": [], "qa_evaluations": {"evaluations": []}}] + (tmp_path / f"generated_batch{i}.json").write_text(json.dumps(data)) + df = load_generated_json_files(str(tmp_path)) + assert len(df) == 2 + + +# --------------------------------------------------------------------------- +# generate_training_set / generate_eval_set +# --------------------------------------------------------------------------- + + +def test_generate_training_set(tmp_path: Path) -> None: + corpus = {"hello": "d_abc"} + chunk_mapping = {("doc", 1): "hello"} + df = pd.DataFrame([{"file_name": ["doc.txt"], "question": "Q?", "segment_ids": [1]}]) + generate_training_set(corpus, chunk_mapping, df, str(tmp_path), "my_corpus") + train_path = tmp_path / "train.json" + assert train_path.exists() + payload = json.loads(train_path.read_text()) + assert len(payload["data"]) == 1 + + +def test_generate_eval_set(tmp_path: Path) -> None: + corpus = {"hello": "d_abc"} + chunk_mapping = {("doc", 1): "hello"} + df = pd.DataFrame([{"file_name": ["doc.txt"], "question": "Q?", "segment_ids": [1]}]) + generate_eval_set(corpus, chunk_mapping, df, str(tmp_path), eval_only=True) + assert (tmp_path / "corpus.jsonl").exists() + assert (tmp_path / "queries.jsonl").exists() + assert (tmp_path / "qrels" / "test.tsv").exists() diff --git a/plugins/data-designer-retrieval-sdg/tests/test_dedup.py b/plugins/data-designer-retrieval-sdg/tests/test_dedup.py new file mode 100644 index 0000000..933fce8 --- /dev/null +++ b/plugins/data-designer-retrieval-sdg/tests/test_dedup.py @@ -0,0 +1,74 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for the deduplication logic (pure numpy, no LLM needed).""" + +from data_designer_retrieval_sdg.config import RetrievalSdgDedupColumnConfig + + +def _make_generator(): + """Instantiate the generator with minimal wiring for dedupe_qa_pairs. + + We only need the config for threshold; the embedder is not used in + dedupe_qa_pairs itself. 
+ """ + from unittest.mock import MagicMock + + from data_designer_retrieval_sdg.dedup import RetrievalSdgDedupColumnGenerator + + config = RetrievalSdgDedupColumnConfig( + name="dedup", + qa_pairs_column="qa", + embedding_alias="embed", + dedupe_similarity_threshold=0.9, + ) + gen = object.__new__(RetrievalSdgDedupColumnGenerator) + gen._config = config + gen._resource_provider = MagicMock() + return gen + + +def test_dedupe_empty() -> None: + gen = _make_generator() + assert gen.dedupe_qa_pairs([]) == [] + + +def test_dedupe_no_duplicates() -> None: + gen = _make_generator() + embeddings = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]] + kept = gen.dedupe_qa_pairs(embeddings) + assert kept == [0, 1, 2] + + +def test_dedupe_identical_vectors() -> None: + gen = _make_generator() + embeddings = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]] + kept = gen.dedupe_qa_pairs(embeddings) + assert 0 in kept + assert 1 not in kept + assert 2 in kept + + +def test_dedupe_near_threshold() -> None: + gen = _make_generator() + v1 = [1.0, 0.0] + v2 = [0.95, 0.3122] # cosine sim ≈ 0.95 > 0.9 + v3 = [0.0, 1.0] + kept = gen.dedupe_qa_pairs([v1, v2, v3]) + assert 0 in kept + assert 1 not in kept + assert 2 in kept + + +def test_dedupe_single_element() -> None: + gen = _make_generator() + kept = gen.dedupe_qa_pairs([[1.0, 0.0]]) + assert kept == [0] + + +def test_config_column_type() -> None: + cfg = RetrievalSdgDedupColumnConfig(name="dedup", qa_pairs_column="qa", embedding_alias="embed") + assert cfg.column_type == "retrieval-sdg-dedup" + assert cfg.required_columns == ["qa"] + assert cfg.side_effect_columns == [] + assert cfg.get_column_emoji() == "🔍" diff --git a/plugins/data-designer-retrieval-sdg/tests/test_ingest.py b/plugins/data-designer-retrieval-sdg/tests/test_ingest.py new file mode 100644 index 0000000..5c87473 --- /dev/null +++ b/plugins/data-designer-retrieval-sdg/tests/test_ingest.py @@ -0,0 +1,145 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from pathlib import Path + +import pytest + +from data_designer_retrieval_sdg.ingest import ( + build_bundle_id, + build_bundles, + chunks_to_sections_structured, + file_matches_extensions, + is_traditional_extension, + load_text_files_from_directory, + text_to_sentence_chunks, +) + +# --------------------------------------------------------------------------- +# is_traditional_extension +# --------------------------------------------------------------------------- + + +def test_traditional_extensions() -> None: + assert is_traditional_extension(".txt") is True + assert is_traditional_extension(".md") is True + assert is_traditional_extension(".json") is True + assert is_traditional_extension(".mp3") is True + + +def test_non_traditional_extensions() -> None: + assert is_traditional_extension("") is False + assert is_traditional_extension(".com_publication_2001") is False + assert is_traditional_extension(".averylongextension123") is False + + +# --------------------------------------------------------------------------- +# file_matches_extensions +# --------------------------------------------------------------------------- + + +def test_file_matches_extensions_standard() -> None: + assert file_matches_extensions(Path("doc.txt"), [".txt", ".md"]) is True + assert file_matches_extensions(Path("doc.py"), [".txt", ".md"]) is False + + +def test_file_matches_extensions_no_ext() -> None: + assert file_matches_extensions(Path("README"), [""]) is True + + +# --------------------------------------------------------------------------- +# text_to_sentence_chunks +# --------------------------------------------------------------------------- + + +def test_text_to_sentence_chunks_basic() -> None: + text = "First sentence. Second sentence. Third sentence. Fourth sentence. Fifth sentence. Sixth sentence." + chunks = text_to_sentence_chunks(text, sentences_per_chunk=3) + assert len(chunks) == 2 + assert chunks[0]["chunk_id"] == 1 + assert chunks[1]["chunk_id"] == 2 + assert chunks[0]["sentence_count"] == 3 + + +def test_text_to_sentence_chunks_with_doc_id() -> None: + chunks = text_to_sentence_chunks("Hello world. 
Goodbye.", sentences_per_chunk=5, doc_id="doc1") + assert len(chunks) == 1 + assert chunks[0]["doc_id"] == "doc1" + + +def test_text_to_sentence_chunks_empty() -> None: + assert text_to_sentence_chunks("") == [] + + +# --------------------------------------------------------------------------- +# Section strategies +# --------------------------------------------------------------------------- + + +def test_chunks_to_sections_sequential() -> None: + chunks = [{"text": f"chunk {i}", "chunk_id": i} for i in range(1, 7)] + sections = chunks_to_sections_structured(chunks, num_sections=2, strategy="sequential") + assert len(sections) == 2 + assert "Section 1" in sections[0] + assert "Section 2" in sections[1] + + +def test_chunks_to_sections_empty() -> None: + assert chunks_to_sections_structured([], num_sections=2) == [] + + +# --------------------------------------------------------------------------- +# build_bundles +# --------------------------------------------------------------------------- + + +def test_build_bundles_sequential(tmp_path: Path) -> None: + files = [tmp_path / f"f{i}.txt" for i in range(4)] + for f in files: + f.write_text("content") + bundles = build_bundles(files, bundle_size=2, max_docs_per_bundle=3) + assert len(bundles) == 2 + assert len(bundles[0]) == 2 + + +def test_build_bundles_exceeds_max(tmp_path: Path) -> None: + files = [tmp_path / f"f{i}.txt" for i in range(4)] + for f in files: + f.write_text("content") + with pytest.raises(ValueError, match="exceeds max_docs_per_bundle"): + build_bundles(files, bundle_size=4, max_docs_per_bundle=2) + + +# --------------------------------------------------------------------------- +# build_bundle_id +# --------------------------------------------------------------------------- + + +def test_build_bundle_id_deterministic() -> None: + a = build_bundle_id(["a.txt", "b.txt"]) + b = build_bundle_id(["b.txt", "a.txt"]) + assert a == b + + +def test_build_bundle_id_empty() -> None: + assert build_bundle_id([]) == "" + + +# --------------------------------------------------------------------------- +# load_text_files_from_directory +# --------------------------------------------------------------------------- + + +def test_load_text_files_single_doc(tmp_path: Path) -> None: + (tmp_path / "a.txt").write_text("Hello world. This is a test. Another sentence.") + (tmp_path / "b.txt").write_text("Foo bar. Baz quux. Something else.") + df = load_text_files_from_directory(tmp_path, sentences_per_chunk=2) + assert len(df) == 2 + assert "file_name" in df.columns + assert "chunks" in df.columns + assert "sections_structured" in df.columns + + +def test_load_text_files_no_files(tmp_path: Path) -> None: + with pytest.raises(ValueError, match="No text files found"): + load_text_files_from_directory(tmp_path) diff --git a/plugins/data-designer-retrieval-sdg/tests/test_models.py b/plugins/data-designer-retrieval-sdg/tests/test_models.py new file mode 100644 index 0000000..0a46c80 --- /dev/null +++ b/plugins/data-designer-retrieval-sdg/tests/test_models.py @@ -0,0 +1,81 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from data_designer_retrieval_sdg.models import ( + ArtifactItem, + DocumentArtifacts, + HopContext, + QAEvaluation, + QAEvaluationCriterion, + QAOverallEvaluation, + QAPairEvaluations, + QuestionAnswerPair, + QuestionAnswerPairs, +) + + +def test_artifact_item_round_trip() -> None: + item = ArtifactItem(text="concept", description="a concept", importance="high") + assert item.text == "concept" + data = item.model_dump() + assert ArtifactItem.model_validate(data) == item + + +def test_document_artifacts_defaults() -> None: + artifacts = DocumentArtifacts() + assert artifacts.key_concepts == [] + assert artifacts.technical_terms == [] + + +def test_question_answer_pair() -> None: + pair = QuestionAnswerPair( + question="What?", + answer="This.", + question_complexity=4, + query_type="multi_hop", + reasoning_type="factual", + segment_ids=[1, 3], + hop_count=2, + hop_contexts=[ + HopContext(hop_number=1, segment_ids=[1], summary="first"), + HopContext(hop_number=2, segment_ids=[3], summary="second"), + ], + ) + assert pair.query_type == "multi_hop" + assert len(pair.hop_contexts) == 2 + + +def test_question_answer_pairs_container() -> None: + pairs = QuestionAnswerPairs( + pairs=[ + QuestionAnswerPair( + question="Q1", + answer="A1", + question_complexity=4, + query_type="structural", + reasoning_type="relational", + segment_ids=[2], + hop_count=1, + hop_contexts=[], + ) + ] + ) + assert len(pairs.pairs) == 1 + + +def test_qa_evaluation_round_trip() -> None: + evl = QAEvaluation( + relevance=QAEvaluationCriterion(score=8, justification="relevant"), + accuracy=QAEvaluationCriterion(score=9, justification="accurate"), + context_support=QAEvaluationCriterion(score=7, justification="supported"), + clarity=QAEvaluationCriterion(score=8, justification="clear"), + overall=QAOverallEvaluation(score=8.0, assessment="good"), + improvements="none", + ) + data = evl.model_dump() + assert QAEvaluation.model_validate(data).overall.score == 8.0 + + +def test_qa_pair_evaluations() -> None: + evals = QAPairEvaluations(evaluations=[]) + assert evals.evaluations == [] diff --git a/plugins/data-designer-retrieval-sdg/tests/test_plugin.py b/plugins/data-designer-retrieval-sdg/tests/test_plugin.py new file mode 100644 index 0000000..89855a2 --- /dev/null +++ b/plugins/data-designer-retrieval-sdg/tests/test_plugin.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from data_designer.engine.testing.utils import assert_valid_plugin + +from data_designer_retrieval_sdg.plugin import plugin + + +def test_valid_plugin() -> None: + assert_valid_plugin(plugin) diff --git a/plugins/data-designer-retrieval-sdg/tests/test_postprocess.py b/plugins/data-designer-retrieval-sdg/tests/test_postprocess.py new file mode 100644 index 0000000..f73b24d --- /dev/null +++ b/plugins/data-designer-retrieval-sdg/tests/test_postprocess.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import pandas as pd + +from data_designer_retrieval_sdg.postprocess import filter_qa_pairs_by_quality, postprocess_retriever_data + +# --------------------------------------------------------------------------- +# postprocess_retriever_data +# --------------------------------------------------------------------------- + + +def test_postprocess_basic() -> None: + df = pd.DataFrame( + [ + { + "file_name": ["doc.txt"], + "deduplicated_qa_pairs": [ + { + "question": "What is X?", + "answer": "X is Y.", + "query_type": "structural", + "reasoning_type": "factual", + "question_complexity": 4, + "segment_ids": [1], + "hop_count": 1, + "hop_contexts": [], + } + ], + } + ] + ) + queries_df, qrels_df, splits = postprocess_retriever_data(df) + assert len(queries_df) == 1 + assert queries_df.iloc[0]["text"] == "What is X?" + assert len(qrels_df) == 1 + assert "text" in splits + + +def test_postprocess_skips_missing() -> None: + df = pd.DataFrame([{"file_name": ["x.txt"]}]) + queries_df, _, _ = postprocess_retriever_data(df) + assert len(queries_df) == 0 + + +# --------------------------------------------------------------------------- +# filter_qa_pairs_by_quality +# --------------------------------------------------------------------------- + + +def test_filter_by_quality() -> None: + df = pd.DataFrame( + [ + { + "file_name": ["a.txt"], + "deduplicated_qa_pairs": [ + {"question": "Q1", "answer": "A1"}, + {"question": "Q2", "answer": "A2"}, + ], + "qa_evaluations": { + "evaluations": [ + {"overall": {"score": 9.0}}, + {"overall": {"score": 3.0}}, + ] + }, + } + ] + ) + filtered_df, skipped = filter_qa_pairs_by_quality(df, quality_threshold=7.0) + assert len(filtered_df) == 1 + assert filtered_df.iloc[0]["question"] == "Q1" + assert skipped == [] + + +def test_filter_skips_mismatched() -> None: + df = pd.DataFrame( + [ + { + "file_name": ["bad.txt"], + "deduplicated_qa_pairs": [{"question": "Q1", "answer": "A1"}], + "qa_evaluations": {"evaluations": []}, + } + ] + ) + filtered_df, skipped = filter_qa_pairs_by_quality(df, quality_threshold=5.0) + assert len(filtered_df) == 0 + assert len(skipped) == 1 diff --git a/pyproject.toml b/pyproject.toml index 296d95b..8eef2c1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,7 @@ ignore = [ ] [tool.ruff.lint.isort] -known-first-party = ["ddp", "data_designer_template"] +known-first-party = ["ddp", "data_designer_template", "data_designer_retrieval_sdg"] [tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "all" diff --git a/uv.lock b/uv.lock index bf6a01c..24942e9 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.10" resolution-markers = [ "python_full_version >= '3.12'", @@ -10,6 +10,7 @@ resolution-markers = [ [manifest] members = [ "data-designer-plugins-workspace", + "data-designer-retrieval-sdg", "data-designer-template", "ddp", ] @@ -427,6 +428,25 @@ name = "data-designer-plugins-workspace" version = "0.0.0" source = { virtual = "." 
} +[[package]] +name = "data-designer-retrieval-sdg" +version = "0.1.0" +source = { editable = "plugins/data-designer-retrieval-sdg" } +dependencies = [ + { name = "data-designer" }, + { name = "nltk" }, + { name = "pyarrow" }, + { name = "pyyaml" }, +] + +[package.metadata] +requires-dist = [ + { name = "data-designer", specifier = ">=0.5.7" }, + { name = "nltk", specifier = ">=3.9.2" }, + { name = "pyarrow", specifier = ">=14.0" }, + { name = "pyyaml", specifier = ">=6.0" }, +] + [[package]] name = "data-designer-template" version = "0.1.0" @@ -710,6 +730,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899, upload-time = "2025-03-05T20:05:00.369Z" }, ] +[[package]] +name = "joblib" +version = "1.5.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/41/f2/d34e8b3a08a9cc79a50b2208a93dce981fe615b64d5a4d4abee421d898df/joblib-1.5.3.tar.gz", hash = "sha256:8561a3269e6801106863fd0d6d84bb737be9e7631e33aaed3fb9ce5953688da3", size = 331603, upload-time = "2025-12-15T08:41:46.427Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7b/91/984aca2ec129e2757d1e4e3c81c3fcda9d0f85b74670a094cc443d9ee949/joblib-1.5.3-py3-none-any.whl", hash = "sha256:5fc3c5039fc5ca8c0276333a188bbd59d6b7ab37fe6632daa76bc7f9ec18e713", size = 309071, upload-time = "2025-12-15T08:41:44.973Z" }, +] + [[package]] name = "json-repair" version = "0.59.4" @@ -1137,6 +1166,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9e/c9/b2622292ea83fbb4ec318f5b9ab867d0a28ab43c5717bb85b0a5f6b3b0a4/networkx-3.6.1-py3-none-any.whl", hash = "sha256:d47fbf302e7d9cbbb9e2555a0d267983d2aa476bac30e90dfbe5669bd57f3762", size = 2068504, upload-time = "2025-12-08T17:02:38.159Z" }, ] +[[package]] +name = "nltk" +version = "3.9.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "joblib" }, + { name = "regex" }, + { name = "tqdm" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/74/a1/b3b4adf15585a5bc4c357adde150c01ebeeb642173ded4d871e89468767c/nltk-3.9.4.tar.gz", hash = "sha256:ed03bc098a40481310320808b2db712d95d13ca65b27372f8a403949c8b523d0", size = 2946864, upload-time = "2026-03-24T06:13:40.641Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9d/91/04e965f8e717ba0ab4bdca5c112deeab11c9e750d94c4d4602f050295d39/nltk-3.9.4-py3-none-any.whl", hash = "sha256:f2fa301c3a12718ce4a0e9305c5675299da5ad9e26068218b69d692fda84828f", size = 1552087, upload-time = "2026-03-24T06:13:38.47Z" }, +] + [[package]] name = "numpy" version = "2.2.6" From 9f7645f472676fcbe059304dffbcbcfe4d98c341 Mon Sep 17 00:00:00 2001 From: Steve Han Date: Thu, 30 Apr 2026 10:06:42 -0400 Subject: [PATCH 2/6] refactor(data-designer-retrieval-sdg): two entry points and async dedup Restructure the plugin per PR #1 review feedback (Johnny Greco, Nabin Mulepati). The single PyPI package now registers two entry points, removes the manual ThreadPoolExecutor in favor of DataDesigner's new async engine, and replaces the hand-rolled DataFrame seed loader with a FileSystemSeedReader subclass. - Register two data_designer.plugins entry points in one package: * embedding-dedup (column generator) - generic cosine-similarity dedup of any list-valued column. 
Implements native agenerate() so the column engages DATA_DESIGNER_ASYNC_ENGINE=1 cell-level concurrency. * document-chunker (seed reader) - FileSystemSeedReader subclass that sentence-chunks files and emits structured sections. - Generalize the dedup column config (source_column, items_key, text_field, model_alias, similarity_threshold, column_type "embedding-dedup"); drop ThreadPoolExecutor; single batched embedding call per row in both sync and async paths. - Move reusable chunking/section/bundling helpers from ingest.py into chunking.py and delete ingest.py - file discovery, manifest building, and DataFrame construction now belong to the framework. - Update pipeline.py to take a DocumentChunkerSeedSource directly (no more DataFrameSeedSource wrapping) and the renamed EmbeddingDedupColumnConfig. - Refactor cli.py to build the seed source from CLI flags, drop the manual ETA helper, and let DataDesigner's progress logger surface progress. Per-batch JSON output for resumability is preserved. - Add @oliverholworthy to the plugin CODEOWNERS. - Refresh auto-derived metadata (docs/catalog.md now lists both entry points; .github/CODEOWNERS regenerated). - Tests: validate both plugin entries, add agenerate() async test for dedup, rename test_ingest -> test_chunking, add test_seed_reader. Local CI is green: lint, isolated-venv test (70 tests pass: 59 retrieval-sdg + 11 template), validate (3 plugins OK), check. Verified with sync and async (DATA_DESIGNER_ASYNC_ENGINE=1) end-to-end smoke runs against examples/sample_texts. Made-with: Cursor Signed-off-by: Steve Han --- .github/CODEOWNERS | 2 +- docs/catalog.md | 3 +- .../data-designer-retrieval-sdg/CODEOWNERS | 2 +- plugins/data-designer-retrieval-sdg/README.md | 105 ++- .../pyproject.toml | 5 +- .../data_designer_retrieval_sdg/__init__.py | 25 +- .../data_designer_retrieval_sdg/chunking.py | 369 +++++++++ .../src/data_designer_retrieval_sdg/cli.py | 202 +++-- .../src/data_designer_retrieval_sdg/config.py | 40 +- .../src/data_designer_retrieval_sdg/dedup.py | 153 ++-- .../src/data_designer_retrieval_sdg/ingest.py | 724 ------------------ .../data_designer_retrieval_sdg/pipeline.py | 54 +- .../src/data_designer_retrieval_sdg/plugin.py | 12 - .../data_designer_retrieval_sdg/plugins.py | 27 + .../seed_reader.py | 242 ++++++ .../seed_source.py | 66 ++ .../tests/test_chunking.py | 109 +++ .../tests/test_dedup.py | 145 +++- .../tests/test_ingest.py | 145 ---- .../tests/test_plugin.py | 12 +- .../tests/test_seed_reader.py | 119 +++ 21 files changed, 1379 insertions(+), 1182 deletions(-) create mode 100644 plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/chunking.py delete mode 100644 plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/ingest.py delete mode 100644 plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/plugin.py create mode 100644 plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/plugins.py create mode 100644 plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/seed_reader.py create mode 100644 plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/seed_source.py create mode 100644 plugins/data-designer-retrieval-sdg/tests/test_chunking.py delete mode 100644 plugins/data-designer-retrieval-sdg/tests/test_ingest.py create mode 100644 plugins/data-designer-retrieval-sdg/tests/test_seed_reader.py diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 135062f..c48ee7f 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -7,5 +7,5 @@ 
/.github/ @NVIDIA-NeMo/data_designer_reviewers # Plugins -/plugins/data-designer-retrieval-sdg/ @NVIDIA-NeMo/data_designer_reviewers @shan-nvidia +/plugins/data-designer-retrieval-sdg/ @NVIDIA-NeMo/data_designer_reviewers @shan-nvidia @oliverholworthy /plugins/data-designer-template/ @NVIDIA-NeMo/data_designer_reviewers diff --git a/docs/catalog.md b/docs/catalog.md index 4d6d7b1..7991a36 100644 --- a/docs/catalog.md +++ b/docs/catalog.md @@ -4,5 +4,6 @@ Auto-generated from plugin metadata. Do not edit manually. | Plugin | Version | Column Type | Description | |--------|---------|-------------|-------------| -| data-designer-retrieval-sdg | 0.1.0 | `retrieval-sdg-dedup` | Multi-step retriever SDG pipeline (artifact extraction, QA generation, dedup, evaluation) with Automodel-compatible data conversion; registers a retrieval-sdg-dedup column plugin | +| data-designer-retrieval-sdg | 0.1.0 | `document-chunker` | Retriever SDG toolkit: registers the embedding-dedup column generator and document-chunker seed reader, plus a multi-step QA generation pipeline, CLI, and Automodel-compatible data conversion | +| data-designer-retrieval-sdg | 0.1.0 | `embedding-dedup` | Retriever SDG toolkit: registers the embedding-dedup column generator and document-chunker seed reader, plus a multi-step QA generation pipeline, CLI, and Automodel-compatible data conversion | | data-designer-template | 0.1.0 | `text-transform` | Template Data Designer plugin — text transform column generator | diff --git a/plugins/data-designer-retrieval-sdg/CODEOWNERS b/plugins/data-designer-retrieval-sdg/CODEOWNERS index 3c525c6..4c971ba 100644 --- a/plugins/data-designer-retrieval-sdg/CODEOWNERS +++ b/plugins/data-designer-retrieval-sdg/CODEOWNERS @@ -1,3 +1,3 @@ # Owner(s) of this plugin — used to generate the root CODEOWNERS file. # GitHub accepts @username, @org/team, or email format. -* @NVIDIA-NeMo/data_designer_reviewers @shan-nvidia +* @NVIDIA-NeMo/data_designer_reviewers @shan-nvidia @oliverholworthy diff --git a/plugins/data-designer-retrieval-sdg/README.md b/plugins/data-designer-retrieval-sdg/README.md index be20051..426f39d 100644 --- a/plugins/data-designer-retrieval-sdg/README.md +++ b/plugins/data-designer-retrieval-sdg/README.md @@ -1,46 +1,53 @@ # data-designer-retrieval-sdg -Data Designer plugin for **retriever synthetic data generation**. Generates -multi-hop QA pairs from text documents and converts them into training -formats compatible with [Automodel](https://github.com/NVIDIA-NeMo/Automodel) -retriever finetuning. - -## Features - -- **Retrieval-sdg-dedup column plugin** — embedding-based QA-pair deduplication - registered as a `data_designer.plugins` entry point. -- **Four-column SDG pipeline** — artifact extraction → QA generation → - deduplication → quality evaluation, all orchestrated via DataDesigner. -- **Data conversion** — convert raw SDG output to NeMo Retriever training - format (`train.json`, `val.json`), BEIR evaluation format, and corpus - parquet with `merlin_metadata.json`. -- **CLI** — `data-designer-retrieval-sdg generate` and - `data-designer-retrieval-sdg convert` subcommands. +Data Designer toolkit for **retriever synthetic data generation**. The +package registers two `data_designer.plugins` entry points, ships a +ready-made multi-step QA generation pipeline, and exposes a CLI that +generates QA pairs and converts them into training formats compatible +with [Automodel](https://github.com/NVIDIA-NeMo/Automodel) retriever +finetuning. 
-## Installation +## Plugins + +The single PyPI package contributes two plugins to DataDesigner's +registries via `[project.entry-points."data_designer.plugins"]`: + +| Slug | Type | Purpose | +|------|------|---------| +| `embedding-dedup` | column generator | Generic cosine-similarity dedup of any list-valued column. Implements native `agenerate()` for the async engine. | +| `document-chunker` | seed reader | Sentence-chunks a directory of text files and emits structured sections, with optional multi-document bundling. | + +Both ship with the same `pip install data-designer-retrieval-sdg` and +become discoverable automatically through Python entry points. + +## Native async (`DATA_DESIGNER_ASYNC_ENGINE=1`) + +`embedding-dedup` implements `agenerate()` directly on top of +`model.agenerate_text_embeddings`, so the column participates in +DataDesigner's async cell-level scheduler whenever the env var is set: ```bash -pip install data-designer-retrieval-sdg +export DATA_DESIGNER_ASYNC_ENGINE=1 +data-designer-retrieval-sdg generate ... ``` -Or, for development inside the monorepo: +The async engine requires Python 3.11+; without the env var the package +runs on Python 3.10+ via the framework's sync bridge. + +## Installation ```bash -make sync # from the repo root +pip install data-designer-retrieval-sdg ``` -## Development setup - -When working inside the monorepo the CLI and library are installed into the -workspace virtual environment. Activate it before running commands: +For development inside the monorepo: ```bash make sync # install all packages into .venv source .venv/bin/activate # activate the virtual environment ``` -Alternatively, prefix any command with `uv run` to execute inside the venv -without activating it: +Or prefix any command with `uv run`: ```bash uv run data-designer-retrieval-sdg generate --help @@ -68,29 +75,51 @@ data-designer-retrieval-sdg convert ./generated_output \ ```python from data_designer_retrieval_sdg import ( + DocumentChunkerSeedSource, build_qa_generation_pipeline, - load_text_files_from_directory, ) -seed_df = load_text_files_from_directory(Path("./docs")) -config_builder = build_qa_generation_pipeline(seed_df) +seed_source = DocumentChunkerSeedSource( + path="./docs", + file_extensions=[".txt", ".md"], +) +config_builder = build_qa_generation_pipeline(seed_source) ``` -## Plugin column type +## Plugin configuration examples -The package registers the `retrieval-sdg-dedup` column type. 
Use it in a -DataDesigner pipeline to deduplicate QA pairs by embedding cosine -similarity: +### `embedding-dedup` column ```python -from data_designer_retrieval_sdg.config import RetrievalSdgDedupColumnConfig +from data_designer_retrieval_sdg.config import EmbeddingDedupColumnConfig config_builder.add_column( - RetrievalSdgDedupColumnConfig( + EmbeddingDedupColumnConfig( name="deduplicated_qa_pairs", - qa_pairs_column="qa_generation", - embedding_alias="embed", - dedupe_similarity_threshold=0.9, + source_column="qa_generation", # upstream column with the items + items_key="pairs", # key under the source column ("None" if the column is already a list) + text_field="question", # field on each item to embed + model_alias="embed", # registered embedding model alias + similarity_threshold=0.9, ) ) ``` + +### `document-chunker` seed reader + +```python +from data_designer_retrieval_sdg.seed_source import DocumentChunkerSeedSource + +seed_source = DocumentChunkerSeedSource( + path="./docs", + file_pattern="*", + recursive=True, + file_extensions=[".txt", ".md"], + sentences_per_chunk=5, + num_sections=1, + multi_doc=False, # set True for bundle-per-row mode +) +``` + +Output schema (one record per row): `file_name`, `text`, `chunks`, +`sections_structured`, `bundle_id`, `bundle_members`, `is_multi_doc`. diff --git a/plugins/data-designer-retrieval-sdg/pyproject.toml b/plugins/data-designer-retrieval-sdg/pyproject.toml index 42cf649..bc57a50 100644 --- a/plugins/data-designer-retrieval-sdg/pyproject.toml +++ b/plugins/data-designer-retrieval-sdg/pyproject.toml @@ -4,7 +4,7 @@ [project] name = "data-designer-retrieval-sdg" version = "0.1.0" -description = "Multi-step retriever SDG pipeline (artifact extraction, QA generation, dedup, evaluation) with Automodel-compatible data conversion; registers a retrieval-sdg-dedup column plugin" +description = "Retriever SDG toolkit: registers the embedding-dedup column generator and document-chunker seed reader, plus a multi-step QA generation pipeline, CLI, and Automodel-compatible data conversion" requires-python = ">=3.10" dependencies = [ "data-designer>=0.5.7", @@ -24,7 +24,8 @@ classifiers = [ ] [project.entry-points."data_designer.plugins"] -retrieval-sdg-dedup = "data_designer_retrieval_sdg.plugin:plugin" +embedding-dedup = "data_designer_retrieval_sdg.plugins:embedding_dedup_plugin" +document-chunker = "data_designer_retrieval_sdg.plugins:document_chunker_plugin" [project.scripts] data-designer-retrieval-sdg = "data_designer_retrieval_sdg.cli:main" diff --git a/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/__init__.py b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/__init__.py index c187c03..fc5c14d 100644 --- a/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/__init__.py +++ b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/__init__.py @@ -1,33 +1,34 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -"""Data Designer plugin for retriever synthetic data generation. +"""Data Designer plugins and pipeline for retriever synthetic data generation. -Provides a multi-step pipeline that generates QA pairs from text documents -for retriever finetuning, plus utilities for converting raw SDG output into -Automodel-compatible training formats. 
+The package registers two ``data_designer.plugins`` entry points: -Public API: +- ``embedding-dedup``: generic embedding-cosine-similarity column generator. +- ``document-chunker``: filesystem seed reader that loads text files, + sentence-chunks them, and emits structured sections. -- :func:`build_qa_generation_pipeline` -- build the four-column DD pipeline -- :func:`load_text_files_from_directory` -- load and chunk text files -- :func:`postprocess_retriever_data` -- flatten to BEIR format -- :func:`filter_qa_pairs_by_quality` -- quality-based filtering -- :func:`load_positive_docs_with_modality` -- load BEIR docs with modality +It also ships a ready-made four-column QA generation pipeline, a CLI for +running the pipeline end-to-end (``generate``) and exporting to NeMo +Retriever / BEIR formats (``convert``), and reusable post-processing +helpers. """ -from data_designer_retrieval_sdg.ingest import load_text_files_from_directory +from data_designer_retrieval_sdg.config import EmbeddingDedupColumnConfig from data_designer_retrieval_sdg.pipeline import build_qa_generation_pipeline from data_designer_retrieval_sdg.postprocess import ( filter_qa_pairs_by_quality, load_positive_docs_with_modality, postprocess_retriever_data, ) +from data_designer_retrieval_sdg.seed_source import DocumentChunkerSeedSource __all__ = [ + "DocumentChunkerSeedSource", + "EmbeddingDedupColumnConfig", "build_qa_generation_pipeline", "filter_qa_pairs_by_quality", "load_positive_docs_with_modality", - "load_text_files_from_directory", "postprocess_retriever_data", ] diff --git a/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/chunking.py b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/chunking.py new file mode 100644 index 0000000..f8dffa2 --- /dev/null +++ b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/chunking.py @@ -0,0 +1,369 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Text chunking, section-building, and multi-document bundling helpers. + +These pure utilities are shared by the document-chunker seed reader and +exposed for direct use in tests. They contain no DataDesigner-specific +state: file IO is performed by the seed reader, while this module focuses +on shaping text into chunks/sections and grouping files into bundles. +""" + +from __future__ import annotations + +import hashlib +import json +import math +import re +from collections import defaultdict, deque +from pathlib import Path +from typing import Literal + +import nltk +from nltk.tokenize import sent_tokenize + + +def load_multi_doc_manifest(manifest_path: Path | None) -> list[list[str]]: + """Load a multi-doc manifest file. + + Supports JSON or YAML format:: + + [["doc1.txt", "doc2.txt"], ["doc3.txt"]] + {"bundles": [{"docs": ["doc1.txt", "doc2.txt"]}]} + + Args: + manifest_path: Path to the manifest file, or ``None``. + + Returns: + List of bundles, each a list of file-path strings. 
+ """ + import yaml + + if not manifest_path: + return [] + + try: + manifest_text = manifest_path.read_text(encoding="utf-8") + except OSError as exc: + print(f"Warning: Unable to read multi_doc_manifest at {manifest_path}: {exc}") + return [] + + data = None + try: + data = json.loads(manifest_text) + except json.JSONDecodeError: + try: + data = yaml.safe_load(manifest_text) + except yaml.YAMLError as exc: + print(f"Warning: Failed to parse multi_doc_manifest: {exc}") + return [] + + if isinstance(data, dict) and "bundles" in data: + data = data["bundles"] + + bundles: list[list[str]] = [] + if isinstance(data, list): + for entry in data: + if isinstance(entry, dict) and "docs" in entry: + docs = entry["docs"] + elif isinstance(entry, list): + docs = entry + else: + docs = [] + clean_docs = [str(doc) for doc in docs if doc] + if clean_docs: + bundles.append(clean_docs) + else: + print("Warning: multi_doc_manifest must be a list or dict with 'bundles'") + + return bundles + + +def build_bundle_id(bundle_members: list[str]) -> str: + """Generate a stable bundle ID from member identifiers. + + Args: + bundle_members: List of member paths (relative or absolute). + + Returns: + MD5 hex digest of sorted, normalised members. + """ + if not bundle_members: + return "" + normalized = "||".join(sorted(str(member) for member in bundle_members)) + return hashlib.md5(normalized.encode()).hexdigest() + + +def build_bundles( + file_paths: list[Path], + bundle_size: int = 2, + max_docs_per_bundle: int = 3, + manifest_bundles: list[list[str]] | None = None, + input_dir: Path | None = None, +) -> list[list[Path]]: + """Group file paths into document bundles. + + Manifest-defined bundles take priority. Remaining documents are grouped + sequentially according to ``bundle_size``. + + Args: + file_paths: All candidate file paths. + bundle_size: Documents per automatic bundle. + max_docs_per_bundle: Hard cap on bundle size. + manifest_bundles: Pre-defined bundles from a manifest file. + input_dir: Root directory for resolving relative manifest paths. + + Returns: + List of bundles, each a list of resolved ``Path`` objects. + + Raises: + ValueError: If any bundle exceeds ``max_docs_per_bundle``. + """ + if not file_paths: + return [] + + resolved_paths = [path.resolve() for path in file_paths] + seen: set[Path] = set() + bundles: list[list[Path]] = [] + + if manifest_bundles: + for entry in manifest_bundles: + resolved_bundle: list[Path] = [] + for raw_doc in entry: + candidate = Path(raw_doc) + if not candidate.is_absolute() and input_dir: + candidate = (input_dir / raw_doc).resolve() + candidate = candidate.resolve() + if candidate in resolved_paths and candidate not in seen: + resolved_bundle.append(candidate) + seen.add(candidate) + if resolved_bundle: + bundles.append(resolved_bundle) + + remaining = [p for p in resolved_paths if p not in seen] + for start in range(0, len(remaining), bundle_size): + bundle = remaining[start : start + bundle_size] + if bundle: + bundles.append(bundle) + + for i, bundle in enumerate(bundles): + if len(bundle) > max_docs_per_bundle: + raise ValueError( + f"Bundle {i} has {len(bundle)} documents, which exceeds " + f"max_docs_per_bundle={max_docs_per_bundle}. " + f"Either reduce the bundle size in your manifest or increase max_docs_per_bundle." 
+ ) + + return [b for b in bundles if b] + + +def group_chunks_by_doc(chunks: list[dict]) -> dict[str, list[tuple[int, dict]]]: + """Group chunks by their ``doc_id`` field.""" + grouped: dict[str, list[tuple[int, dict]]] = defaultdict(list) + for idx, chunk in enumerate(chunks): + doc_id = chunk.get("doc_id", "default") + grouped[doc_id].append((idx, chunk)) + return dict(grouped) + + +def format_section_chunks(indexed_chunks: list[tuple[int, dict]], section_number: int) -> str: + """Render a list of indexed chunks into a section string.""" + section_lines: list[str] = [] + for _, chunk in indexed_chunks: + text = chunk.get("text", "").strip() + if not text: + continue + segment_id = chunk.get("chunk_id", 1) + doc_id = chunk.get("doc_id", "") + start_time = "00:00:00" + end_time = "00:00:00" + if doc_id: + segment_info = f"Segment {segment_id} [Doc: {doc_id}] ({start_time} - {end_time}): {text}" + else: + segment_info = f"Segment {segment_id} ({start_time} - {end_time}): {text}" + section_lines.append(segment_info) + + if section_lines: + return f"=== Section {section_number} ===\n" + "\n".join(section_lines) + return "" + + +def chunks_to_sections_sequential(chunks: list[dict], num_sections: int = 1) -> list[str]: + """Split chunks sequentially into ``num_sections`` formatted sections.""" + total = len(chunks) + if total == 0: + return [] + + section_size = max(1, total // num_sections) + formatted_sections: list[str] = [] + + for i in range(num_sections): + start_idx = i * section_size + end_idx = (i + 1) * section_size if i < num_sections - 1 else total + indexed_chunks = [(j, chunks[j]) for j in range(start_idx, end_idx)] + section_text = format_section_chunks(indexed_chunks, i + 1) + if section_text: + formatted_sections.append(section_text) + + return formatted_sections + + +def chunks_to_sections_doc_balanced(chunks: list[dict], num_sections: int = 1) -> list[str]: + """Split chunks so each section has proportional doc representation.""" + if not chunks: + return [] + + grouped = group_chunks_by_doc(chunks) + if len(grouped) <= 1: + return chunks_to_sections_sequential(chunks, num_sections) + + chunk_sizes = {doc_id: max(1, math.ceil(len(entries) / num_sections)) for doc_id, entries in grouped.items()} + + sections: list[list[tuple[int, dict]]] = [] + for part_idx in range(num_sections): + part_entries: list[tuple[int, dict]] = [] + for doc_id, entries in grouped.items(): + chunk_size = chunk_sizes[doc_id] + start = part_idx * chunk_size + end = min(len(entries), start + chunk_size) + if start < len(entries): + part_entries.extend(entries[start:end]) + if part_entries: + sections.append(part_entries) + + formatted_sections: list[str] = [] + for i, indexed_chunks in enumerate(sections): + section_text = format_section_chunks(indexed_chunks, i + 1) + if section_text: + formatted_sections.append(section_text) + + return formatted_sections + + +def chunks_to_sections_interleaved(chunks: list[dict], num_sections: int = 1) -> list[str]: + """Split chunks with round-robin interleaving across documents.""" + if not chunks: + return [] + + grouped = group_chunks_by_doc(chunks) + if len(grouped) <= 1: + return chunks_to_sections_sequential(chunks, num_sections) + + doc_iterators = {doc_id: deque(entries) for doc_id, entries in grouped.items()} + doc_order = list(grouped.keys()) + interleaved: list[tuple[int, dict]] = [] + + while True: + added = False + for doc_id in doc_order: + doc_queue = doc_iterators[doc_id] + if doc_queue: + interleaved.append(doc_queue.popleft()) + added = True + if 
not added: + break + + if not interleaved: + return [] + + total = len(interleaved) + section_size = max(1, total // num_sections) + formatted_sections: list[str] = [] + + for i in range(num_sections): + start_idx = i * section_size + end_idx = (i + 1) * section_size if i < num_sections - 1 else total + indexed_chunks = interleaved[start_idx:end_idx] + section_text = format_section_chunks(indexed_chunks, i + 1) + if section_text: + formatted_sections.append(section_text) + + return formatted_sections + + +def chunks_to_sections_structured( + chunks: list[dict], + num_sections: int = 1, + strategy: Literal["sequential", "doc_balanced", "interleaved"] = "sequential", +) -> list[str]: + """Split chunks into sections using the specified strategy.""" + if strategy == "doc_balanced": + return chunks_to_sections_doc_balanced(chunks, num_sections) + if strategy == "interleaved": + return chunks_to_sections_interleaved(chunks, num_sections) + return chunks_to_sections_sequential(chunks, num_sections) + + +def ensure_nltk_punkt() -> None: + """Download NLTK punkt tokeniser data if not already present.""" + for resource in ("tokenizers/punkt", "tokenizers/punkt_tab"): + try: + nltk.data.find(resource) + except LookupError: + nltk.download(resource.split("/")[-1], quiet=True) + + +def text_to_sentence_chunks( + text: str, + sentences_per_chunk: int = 5, + doc_id: str | None = None, + doc_path: str | None = None, + chunk_id_offset: int = 0, +) -> list[dict]: + """Chunk ``text`` into groups of sentences with metadata. + + Args: + text: Input text to chunk. + sentences_per_chunk: Sentences per chunk. + doc_id: Optional document identifier for multi-doc bundles. + doc_path: Optional document path for multi-doc bundles. + chunk_id_offset: Offset for global chunk IDs when aggregating. + + Returns: + List of chunk dicts with keys ``text``, ``start``, ``end``, + ``sentence_count``, ``word_count``, ``chunk_id``, + ``doc_chunk_index``, and optionally ``doc_id`` / ``doc_path``. + """ + ensure_nltk_punkt() + + paragraphs = re.split(r"\n\s*\n+", text) + paragraphs = [p.strip() for p in paragraphs if p.strip()] + + sentences: list[str] = [] + for paragraph in paragraphs: + sentences.extend(sent_tokenize(paragraph)) + + chunks: list[dict] = [] + word_position = 0 + doc_chunk_index = 0 + + for i in range(0, len(sentences), sentences_per_chunk): + chunk_sentences = sentences[i : i + sentences_per_chunk] + chunk_text = ". ".join(chunk_sentences) + if chunk_text and not chunk_text.endswith("."): + chunk_text += "." 
+ + chunk_words = chunk_text.split() + start_word_pos = word_position + end_word_pos = word_position + len(chunk_words) + word_position = end_word_pos + doc_chunk_index += 1 + + chunk_data: dict = { + "text": chunk_text, + "start": start_word_pos, + "end": end_word_pos, + "sentence_count": len(chunk_sentences), + "word_count": len(chunk_words), + "chunk_id": chunk_id_offset + len(chunks) + 1, + "doc_chunk_index": doc_chunk_index, + } + + if doc_id is not None: + chunk_data["doc_id"] = doc_id + if doc_path is not None: + chunk_data["doc_path"] = doc_path + + chunks.append(chunk_data) + + return chunks diff --git a/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/cli.py b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/cli.py index e298eed..19f1258 100644 --- a/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/cli.py +++ b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/cli.py @@ -4,41 +4,64 @@ """CLI entry points for the data-designer-retrieval-sdg package. Provides two subcommands: + - ``generate`` -- run the full SDG pipeline on a directory of text files - ``convert`` -- convert raw SDG output to Automodel-compatible formats + +The ``generate`` subcommand drives a per-batch loop so each batch's output +is checkpointed to its own JSON file (resumable across crashes). The +batching wraps DataDesigner's native ``IndexRange`` selection strategy +applied to a :class:`DocumentChunkerSeedSource`; the framework owns +discovery, chunking, and async cell scheduling (when +``DATA_DESIGNER_ASYNC_ENGINE=1`` is set). """ from __future__ import annotations import argparse import sys -import time from pathlib import Path import data_designer.config as dd +from data_designer.engine.resources.seed_reader import SeedReaderError +from data_designer.engine.secret_resolver import PlaintextResolver from data_designer.interface import DataDesigner from data_designer.logging import LoggerConfig, LoggingConfig, OutputConfig, configure_logging from data_designer_retrieval_sdg.convert import run_conversion -from data_designer_retrieval_sdg.ingest import load_text_files_from_directory from data_designer_retrieval_sdg.pipeline import build_model_providers, build_qa_generation_pipeline +from data_designer_retrieval_sdg.seed_reader import DocumentChunkerSeedReader +from data_designer_retrieval_sdg.seed_source import DocumentChunkerSeedSource -def _format_duration(seconds: float) -> str: - """Format a duration in seconds to a human-readable string.""" - seconds = max(0, int(seconds)) - if seconds < 60: - return f"{seconds}s" - minutes, secs = divmod(seconds, 60) - if minutes < 60: - return f"{minutes}m {secs}s" - hours, minutes = divmod(minutes, 60) - return f"{hours}h {minutes}m" +def _build_seed_source(args: argparse.Namespace) -> DocumentChunkerSeedSource: + """Construct a :class:`DocumentChunkerSeedSource` from CLI arguments.""" + return DocumentChunkerSeedSource( + path=str(args.input_dir), + file_pattern=args.file_pattern, + recursive=args.recursive, + file_extensions=args.file_extensions, + min_text_length=args.min_text_length, + sentences_per_chunk=args.sentences_per_chunk, + num_sections=args.num_sections, + num_files=args.num_files, + multi_doc=args.multi_doc, + bundle_size=args.bundle_size, + bundle_strategy=args.bundle_strategy, + max_docs_per_bundle=args.max_docs_per_bundle, + multi_doc_manifest=str(args.multi_doc_manifest) if args.multi_doc_manifest else None, + ) -# 
--------------------------------------------------------------------------- -# ``generate`` subcommand -# --------------------------------------------------------------------------- +def _count_seed_records(seed_source: DocumentChunkerSeedSource) -> int: + """Probe the seed reader for the total number of records it will produce. + + Builds and attaches a temporary reader so the manifest is materialised + once for batch math without reading any file contents. + """ + reader = DocumentChunkerSeedReader() + reader.attach(seed_source, PlaintextResolver()) + return reader.get_seed_dataset_size() def _add_generate_parser(subparsers: argparse._SubParsersAction) -> None: @@ -51,18 +74,27 @@ def _add_generate_parser(subparsers: argparse._SubParsersAction) -> None: p.add_argument("--input-dir", type=Path, required=True, help="Directory containing text files") p.add_argument("--output-dir", type=Path, required=True, help="Directory to save generated output") + p.add_argument("--file-pattern", default="*", help="Filename glob (basenames only)") + p.add_argument("--no-recursive", dest="recursive", action="store_false", help="Disable recursive search") + p.set_defaults(recursive=True) + p.add_argument( + "--file-extensions", + nargs="+", + default=[".txt", ".md", ".text"], + help="Allowed file extensions (use empty string '' to match files without extensions)", + ) p.add_argument("--min-text-length", type=int, default=50, help="Minimum document text length") p.add_argument("--sentences-per-chunk", type=int, default=5, help="Sentences per chunk") p.add_argument("--num-sections", type=int, default=1, help="Sections to divide chunks into") + p.add_argument("--num-files", type=int, default=None, help="Max files to process") p.add_argument("--max-artifacts-per-type", type=int, default=2, help="Max artifacts per type") p.add_argument("--num-pairs", type=int, default=7, help="QA pairs per document") p.add_argument("--min-hops", type=int, default=2, help="Min hops for multi-hop questions") p.add_argument("--max-hops", type=int, default=4, help="Max hops for multi-hop questions") p.add_argument("--min-complexity", type=int, default=4, help="Min question complexity") + p.add_argument("--similarity-threshold", type=float, default=0.9, help="Cosine threshold for QA-pair dedup") p.add_argument("--preview", action="store_true", help="Preview without full generation") - p.add_argument("--file-extensions", nargs="+", default=None, help="File extensions to include") p.add_argument("--artifact-path", type=Path, default=Path("./artifacts"), help="DD artifact path") - p.add_argument("--num-files", type=int, default=None, help="Max files to process") p.add_argument("--batch-size", type=int, default=200, help="Records per batch") p.add_argument("--start-batch-index", type=int, default=0, help="Batch index to start from") p.add_argument("--end-batch-index", type=int, default=-1, help="Batch index to end at (exclusive)") @@ -74,7 +106,7 @@ def _add_generate_parser(subparsers: argparse._SubParsersAction) -> None: "--bundle-strategy", choices=["sequential", "doc_balanced", "interleaved"], default="sequential", - help="Segment splitting strategy", + help="Section splitting strategy", ) g.add_argument("--max-docs-per-bundle", type=int, default=3, help="Max docs per bundle") g.add_argument("--multi-doc-manifest", type=Path, default=None, help="Manifest for explicit bundles") @@ -105,29 +137,6 @@ def _add_generate_parser(subparsers: argparse._SubParsersAction) -> None: def _run_generate(args: argparse.Namespace) -> None: """Execute 
the ``generate`` subcommand.""" - file_extensions = args.file_extensions or [".txt", ".md", ".text", ""] - - print(f"Loading text files from {args.input_dir}...") - if args.multi_doc: - print(f"Multi-doc mode enabled: bundle_size={args.bundle_size}, strategy={args.bundle_strategy}") - - text_files_df = load_text_files_from_directory( - input_dir=args.input_dir, - file_extensions=file_extensions, - min_text_length=args.min_text_length, - sentences_per_chunk=args.sentences_per_chunk, - num_sections=args.num_sections, - num_files=args.num_files, - multi_doc=args.multi_doc, - bundle_size=args.bundle_size, - bundle_strategy=args.bundle_strategy, - max_docs_per_bundle=args.max_docs_per_bundle, - multi_doc_manifest=args.multi_doc_manifest, - ) - - row_type = "bundles" if args.multi_doc else "text files" - print(f"\nLoaded {len(text_files_df)} {row_type}") - configure_logging( LoggingConfig( logger_configs=[LoggerConfig(name="data_designer", level=args.log_level)], @@ -136,6 +145,16 @@ def _run_generate(args: argparse.Namespace) -> None: ) ) + seed_source = _build_seed_source(args) + try: + total_records = _count_seed_records(seed_source) + except SeedReaderError as exc: + print(f"Error: {exc}", file=sys.stderr) + sys.exit(1) + + row_type = "bundles" if args.multi_doc else "text files" + print(f"Discovered {total_records} {row_type} under {args.input_dir}") + model_providers, custom_providers = build_model_providers( custom_provider_endpoint=args.custom_provider_endpoint, custom_provider_name=args.custom_provider_name, @@ -149,40 +168,49 @@ def _run_generate(args: argparse.Namespace) -> None: args.output_dir.mkdir(parents=True, exist_ok=True) - total_records = len(text_files_df) num_batches = (total_records + args.batch_size - 1) // args.batch_size actual_end_batch = num_batches if args.end_batch_index == -1 else min(args.end_batch_index, num_batches) - model_kwargs: dict = { - "max_parallel_requests_for_gen": args.max_parallel_requests_for_gen, - "artifact_extraction_model": args.artifact_extraction_model, - "artifact_extraction_provider": args.artifact_extraction_provider, - "qa_generation_model": args.qa_generation_model, - "qa_generation_provider": args.qa_generation_provider, - "quality_judge_model": args.quality_judge_model, - "quality_judge_provider": args.quality_judge_provider, - "embed_model": args.embed_model, - "embed_provider": args.embed_provider, - } - + pipeline_kwargs = _pipeline_kwargs(args) _print_model_config(args, custom_providers) if args.preview: - _run_preview(data_designer, text_files_df, total_records, args, model_kwargs) + _run_preview(data_designer, seed_source, total_records, args, pipeline_kwargs) return _run_batches( data_designer, - text_files_df, + seed_source, total_records, num_batches, args.start_batch_index, actual_end_batch, args, - model_kwargs, + pipeline_kwargs, ) +def _pipeline_kwargs(args: argparse.Namespace) -> dict: + """Collect pipeline-builder keyword arguments shared between preview and batch runs.""" + return { + "max_artifacts_per_type": args.max_artifacts_per_type, + "num_pairs": args.num_pairs, + "min_hops": args.min_hops, + "max_hops": args.max_hops, + "min_complexity": args.min_complexity, + "similarity_threshold": args.similarity_threshold, + "max_parallel_requests_for_gen": args.max_parallel_requests_for_gen, + "artifact_extraction_model": args.artifact_extraction_model, + "artifact_extraction_provider": args.artifact_extraction_provider, + "qa_generation_model": args.qa_generation_model, + "qa_generation_provider": 
args.qa_generation_provider, + "quality_judge_model": args.quality_judge_model, + "quality_judge_provider": args.quality_judge_provider, + "embed_model": args.embed_model, + "embed_provider": args.embed_provider, + } + + def _print_model_config(args: argparse.Namespace, custom_providers: list) -> None: """Print model configuration to stdout.""" print("\nModel configuration:") @@ -198,45 +226,37 @@ def _print_model_config(args: argparse.Namespace, custom_providers: list) -> Non def _run_preview( data_designer: DataDesigner, - text_files_df: object, + seed_source: DocumentChunkerSeedSource, total_records: int, args: argparse.Namespace, - model_kwargs: dict, + pipeline_kwargs: dict, ) -> None: """Run a single-record preview of the pipeline.""" config_builder = build_qa_generation_pipeline( - seed_dataset=text_files_df, + seed_source=seed_source, start_index=0, end_index=min(args.batch_size - 1, total_records - 1), - max_artifacts_per_type=args.max_artifacts_per_type, - num_pairs=args.num_pairs, - min_hops=args.min_hops, - max_hops=args.max_hops, - min_complexity=args.min_complexity, - **model_kwargs, + **pipeline_kwargs, ) print("\nPreviewing generation...") try: preview_result = data_designer.preview(config_builder, num_records=1) preview_result.display_sample_record() - except Exception as e: + except Exception as e: # noqa: BLE001 - preview is best-effort UX print(f"Preview error: {e}") def _run_batches( data_designer: DataDesigner, - text_files_df: object, + seed_source: DocumentChunkerSeedSource, total_records: int, num_batches: int, start_batch: int, end_batch: int, args: argparse.Namespace, - model_kwargs: dict, + pipeline_kwargs: dict, ) -> None: - """Process the pipeline in batches.""" - total_batches_to_run = end_batch - start_batch - batch_times: list[float] = [] - + """Process the pipeline in batches, writing one JSON per batch.""" print(f"\nTotal records: {total_records}") print(f"Batch size: {args.batch_size}") print(f"Total batches: {num_batches}") @@ -252,18 +272,11 @@ def _run_batches( print(f"Processing batch {batch_idx}/{num_batches - 1} (records {start_idx}-{end_idx})") print(f"{'=' * 60}") - batch_start = time.monotonic() - config_builder = build_qa_generation_pipeline( - seed_dataset=text_files_df, + seed_source=seed_source, start_index=start_idx, end_index=end_idx, - max_artifacts_per_type=args.max_artifacts_per_type, - num_pairs=args.num_pairs, - min_hops=args.min_hops, - max_hops=args.max_hops, - min_complexity=args.min_complexity, - **model_kwargs, + **pipeline_kwargs, ) input_basename = args.input_dir.name @@ -273,31 +286,13 @@ def _run_batches( output_filename = f"generated_batch{batch_idx}_{start_idx}_{end_idx}.json" generated_df.to_json(args.output_dir / output_filename, orient="records", indent=2) - - batch_elapsed = time.monotonic() - batch_start - batch_times.append(batch_elapsed) - - batches_done = batch_idx - start_batch + 1 - batches_remaining = end_batch - batch_idx - 1 - - print(f"Batch {batch_idx}/{num_batches - 1} done in {_format_duration(batch_elapsed)}") - print(f" Saved to {output_filename} ({len(generated_df)} records)") - if batches_remaining > 0: - avg_time = sum(batch_times) / len(batch_times) - eta = avg_time * batches_remaining - print(f" Progress: {batches_done}/{total_batches_to_run} batches") - print(f" ETA: ~{_format_duration(eta)} remaining") + print(f"Saved {output_filename} ({len(generated_df)} records)") print(f"\n{'=' * 60}") print(f"Generation complete! 
All batches saved to {args.output_dir}") print(f"Total batches processed: {end_batch - start_batch}") -# --------------------------------------------------------------------------- -# ``convert`` subcommand -# --------------------------------------------------------------------------- - - def _add_convert_parser(subparsers: argparse._SubParsersAction) -> None: """Register the ``convert`` subcommand.""" p = subparsers.add_parser( @@ -340,11 +335,6 @@ def _run_convert(args: argparse.Namespace) -> None: ) -# --------------------------------------------------------------------------- -# Main entry point -# --------------------------------------------------------------------------- - - def main() -> None: """CLI entry point for ``data-designer-retrieval-sdg``.""" parser = argparse.ArgumentParser( diff --git a/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/config.py b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/config.py index 38ae6ad..adaf186 100644 --- a/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/config.py +++ b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/config.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -"""Column configuration for the retrieval deduplication plugin.""" +"""Column configuration for the embedding-dedup plugin.""" from __future__ import annotations @@ -10,31 +10,41 @@ from data_designer.config.base import SingleColumnConfig -class RetrievalSdgDedupColumnConfig(SingleColumnConfig): - """Deduplicate QA pairs from a retrieval generation set via embedding similarity. +class EmbeddingDedupColumnConfig(SingleColumnConfig): + """Deduplicate items in a list-valued column via embedding cosine similarity. - This column reads QA pairs from a source column, embeds each question, - and removes near-duplicates whose cosine similarity exceeds a threshold. + The column reads a list of items from ``source_column``, embeds a chosen + text field on each item, computes pairwise cosine similarity, and greedily + drops items above ``similarity_threshold``. ``items_key`` selects whether + the source column is a wrapper dict (``data[source_column][items_key]``) + or a bare list (``items_key=None``). Args: - qa_pairs_column: Name of the upstream column containing QA pairs - with a ``pairs`` key. - embedding_alias: Model alias registered in the DataDesigner model + source_column: Name of the upstream column containing the items to + deduplicate. + items_key: Key under ``source_column`` that holds the list of items. + Set to ``None`` when ``source_column`` already evaluates to a list. + Defaults to ``"pairs"`` for compatibility with the QA-pair shape. + text_field: Attribute or dictionary key on each item that should be + embedded for similarity comparison. Defaults to ``"question"``. + model_alias: Model alias registered in the DataDesigner model registry to use for computing embeddings. column_type: Fixed literal identifying this column type. - dedupe_similarity_threshold: Cosine similarity threshold above which - two questions are considered duplicates. Defaults to ``0.9``. + similarity_threshold: Cosine similarity threshold above which two + items are considered duplicates. Defaults to ``0.9``. 
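A minimal configuration sketch for this column, mirroring how the bundled QA pipeline further below wires it (the field values are taken from that pipeline, not additional defaults)::

    EmbeddingDedupColumnConfig(
        name="deduplicated_qa_pairs",
        source_column="qa_generation",   # upstream column holding {"pairs": [...]}
        items_key="pairs",               # set to None if the column is already a bare list
        text_field="question",           # field embedded for similarity comparison
        model_alias="embed",
        similarity_threshold=0.9,
    )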
""" - qa_pairs_column: str - embedding_alias: str - column_type: Literal["retrieval-sdg-dedup"] = "retrieval-sdg-dedup" - dedupe_similarity_threshold: float = 0.9 + source_column: str + items_key: str | None = "pairs" + text_field: str = "question" + model_alias: str + column_type: Literal["embedding-dedup"] = "embedding-dedup" + similarity_threshold: float = 0.9 @property def required_columns(self) -> list[str]: """Columns that must be present before this column can run.""" - return [self.qa_pairs_column] + return [self.source_column] @property def side_effect_columns(self) -> list[str]: diff --git a/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/dedup.py b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/dedup.py index f9d6585..dac89f6 100644 --- a/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/dedup.py +++ b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/dedup.py @@ -1,68 +1,103 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -"""Column generator that deduplicates QA pairs via embedding cosine similarity.""" +"""Generic embedding-cosine-similarity dedup column generator. + +Implements both ``generate()`` (sync) and ``agenerate()`` (async-native) +so the column participates in DataDesigner's ``DATA_DESIGNER_ASYNC_ENGINE`` +scheduler when enabled, falling back to the sync bridge otherwise. +""" from __future__ import annotations import logging -from concurrent.futures import ThreadPoolExecutor +from typing import Any import numpy as np from data_designer.engine.column_generators.generators.base import ColumnGeneratorCellByCell -from data_designer_retrieval_sdg.config import RetrievalSdgDedupColumnConfig +from data_designer_retrieval_sdg.config import EmbeddingDedupColumnConfig logger = logging.getLogger(__name__) -class RetrievalSdgDedupColumnGenerator(ColumnGeneratorCellByCell[RetrievalSdgDedupColumnConfig]): - """Remove near-duplicate QA pairs using embedding cosine similarity. +class EmbeddingDedupColumnGenerator(ColumnGeneratorCellByCell[EmbeddingDedupColumnConfig]): + """Remove near-duplicate items from a list-valued column. + + For each row the generator: - For each cell the generator: - 1. Reads QA pairs from the configured source column. - 2. Embeds every question in parallel via the registered embedding model. - 3. Computes pairwise cosine similarity and greedily drops duplicates - whose similarity exceeds ``dedupe_similarity_threshold``. - 4. Returns the surviving pairs under the column name. + 1. Resolves the items list at ``data[source_column][items_key]`` + (or ``data[source_column]`` when ``items_key`` is ``None``). + 2. Pulls the text field from each item via :meth:`extract_text`. + 3. Embeds the texts in a single batched call to the embedding model. + 4. Computes pairwise cosine similarity and greedily drops items whose + similarity exceeds ``similarity_threshold``. + 5. Returns the surviving items under ``self.config.name``. """ @property def embedder(self): """Resolve the embedding model from the resource provider.""" return self.resource_provider.model_registry.get_model( - model_alias=self.config.embedding_alias, + model_alias=self.config.model_alias, ) - def embed_text(self, text: str) -> list[float]: - """Compute an embedding vector for *text* using the configured model. + def resolve_items(self, data: dict) -> list[Any]: + """Return the list of items to deduplicate from a row dict. 
Args: - text: Input string to embed. + data: Row dict containing the configured source column. Returns: - List of floats representing the embedding vector. + The list referenced by ``source_column`` and (optionally) + ``items_key``; an empty list if the source value is missing. + + Raises: + TypeError: If the resolved value is not a list. """ - vectors = self.embedder.generate_text_embeddings( - input_texts=[text], - encoding_format="float", - ) - return vectors[0] + value = data.get(self.config.source_column) + if self.config.items_key is not None: + if value is None: + return [] + value = value[self.config.items_key] if isinstance(value, dict) else getattr(value, self.config.items_key) + if value is None: + return [] + if not isinstance(value, list): + raise TypeError( + f"EmbeddingDedupColumnGenerator expected a list at " + f"{self.config.source_column!r}" + f"{f'[{self.config.items_key!r}]' if self.config.items_key else ''}, " + f"got {type(value).__name__}" + ) + return value + + def extract_text(self, item: Any) -> str: + """Pull the text field from an item. + + Supports dict items and Pydantic / attribute-style items. - def dedupe_qa_pairs(self, embeddings: list[list[float]]) -> list[int]: - """Return indices of QA pairs to keep after greedy deduplication. + Args: + item: One element of the resolved items list. + + Returns: + The text to embed for similarity comparison. + """ + field = self.config.text_field + if isinstance(item, dict): + return str(item.get(field, "")) + return str(getattr(item, field, "")) - Computes pairwise cosine similarity. For every pair above the - threshold the later item is dropped. + def dedupe_indices(self, embeddings: list[list[float]]) -> list[int]: + """Return indices to keep after greedy cosine-similarity dedup. Args: - embeddings: 2-D list of embedding vectors, one per QA pair. + embeddings: 2-D list of embedding vectors, one per item. Returns: Sorted list of integer indices to retain. Raises: - ValueError: If *embeddings* is not a 2-D structure. + ValueError: If ``embeddings`` is not a 2-D structure. """ if not embeddings: return [] @@ -77,7 +112,7 @@ def dedupe_qa_pairs(self, embeddings: list[list[float]]) -> list[int]: cosine_sim = np.clip(normalized @ normalized.T, -1.0, 1.0) - threshold = self.config.dedupe_similarity_threshold + threshold = self.config.similarity_threshold keep_indexes: list[int] = [] dropped = np.zeros(len(embeddings), dtype=bool) @@ -92,36 +127,44 @@ def dedupe_qa_pairs(self, embeddings: list[list[float]]) -> list[int]: return keep_indexes - def generate(self, data: dict) -> dict: - """Deduplicate QA pairs for a single record. - - Args: - data: Row dict containing at least the ``qa_pairs_column``. - - Returns: - Updated row dict with the deduplicated pairs stored under - ``self.config.name``. 
- """ - logger.debug("Deduplicating QA pairs from column: %s", self.config.qa_pairs_column) - - qa_pairs: list = data[self.config.qa_pairs_column]["pairs"] - max_parallel = self.embedder.max_parallel_requests - workers = max(1, max_parallel or 1) - - with ThreadPoolExecutor(max_workers=workers) as executor: - embeddings = list(executor.map(self.embed_text, [qa["question"] for qa in qa_pairs])) - - retained_indexes = self.dedupe_qa_pairs(embeddings) - dropped = len(qa_pairs) - len(retained_indexes) + def log_dedup_outcome(self, kept: int, total: int) -> None: + """Log dedup statistics at info or debug level.""" + dropped = total - kept if dropped > 0: logger.info( - "Dedup: retained %d of %d QA pairs (%d duplicates removed)", - len(retained_indexes), - len(qa_pairs), + "Dedup: retained %d of %d items (%d duplicates removed)", + kept, + total, dropped, ) else: - logger.debug("Dedup: retained all %d QA pairs (no duplicates)", len(qa_pairs)) + logger.debug("Dedup: retained all %d items (no duplicates)", total) - retained_qa_pairs = [qa_pairs[i] for i in retained_indexes] - return data | {self.config.name: retained_qa_pairs} + def generate(self, data: dict) -> dict: + """Synchronous dedup for a single row using the embedding model.""" + items = self.resolve_items(data) + if not items: + return data | {self.config.name: []} + + texts = [self.extract_text(item) for item in items] + embeddings = self.embedder.generate_text_embeddings(input_texts=texts, encoding_format="float") + retained_indexes = self.dedupe_indices(embeddings) + self.log_dedup_outcome(len(retained_indexes), len(items)) + return data | {self.config.name: [items[i] for i in retained_indexes]} + + async def agenerate(self, data: dict) -> dict: + """Async dedup using ``model.agenerate_text_embeddings``. + + Drives the cell-level concurrency the async engine enables when + ``DATA_DESIGNER_ASYNC_ENGINE=1``; the framework's sync bridge runs + this from synchronous callers transparently. + """ + items = self.resolve_items(data) + if not items: + return data | {self.config.name: []} + + texts = [self.extract_text(item) for item in items] + embeddings = await self.embedder.agenerate_text_embeddings(input_texts=texts, encoding_format="float") + retained_indexes = self.dedupe_indices(embeddings) + self.log_dedup_outcome(len(retained_indexes), len(items)) + return data | {self.config.name: [items[i] for i in retained_indexes]} diff --git a/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/ingest.py b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/ingest.py deleted file mode 100644 index c80fd77..0000000 --- a/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/ingest.py +++ /dev/null @@ -1,724 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -"""Text ingestion, chunking, and section-building utilities. - -This module handles loading text files from a directory, chunking them by -sentence boundaries, and organising chunks into sections using various -strategies (sequential, doc-balanced, interleaved). It supports both -single-document and multi-document (bundled) modes. 
-""" - -from __future__ import annotations - -import hashlib -import math -import re -from collections import defaultdict, deque -from pathlib import Path -from typing import Literal - -import nltk -import pandas as pd -from nltk.tokenize import sent_tokenize - -# --------------------------------------------------------------------------- -# File-matching helpers -# --------------------------------------------------------------------------- - - -def is_traditional_extension(suffix: str) -> bool: - """Check whether *suffix* looks like a real file extension. - - Traditional extensions are short (1-10 chars), start with a period, and - contain only alphanumeric characters. For example ``.txt``, ``.md``, - ``.json`` are traditional, whereas - ``.com_publication_2001-08_user-programmable`` is not. - - Args: - suffix: The file suffix (including leading ``'.'``). - - Returns: - ``True`` when the suffix matches the traditional pattern. - """ - if not suffix or not suffix.startswith("."): - return False - ext_part = suffix[1:] - return len(ext_part) <= 10 and ext_part.replace("_", "").isalnum() - - -def file_matches_extensions(file_path: Path, file_extensions: list[str]) -> bool: - """Decide whether *file_path* has one of the allowed extensions. - - Files whose suffix is not *traditional* (see - :func:`is_traditional_extension`) are treated as having no extension - and matched against ``""`` in *file_extensions*. - - Args: - file_path: Path to the file. - file_extensions: Allowed extensions, e.g. ``[".txt", ".md", ""]``. - - Returns: - ``True`` when the file matches. - """ - suffix = file_path.suffix.lower() - if is_traditional_extension(suffix): - return suffix in file_extensions - return "" in file_extensions - - -# --------------------------------------------------------------------------- -# Multi-document bundling helpers -# --------------------------------------------------------------------------- - - -def load_multi_doc_manifest(manifest_path: Path | None) -> list[list[str]]: - """Load a multi-doc manifest file. - - Supports JSON or YAML format:: - - [["doc1.txt", "doc2.txt"], ["doc3.txt"]] - {"bundles": [{"docs": ["doc1.txt", "doc2.txt"]}]} - - Args: - manifest_path: Path to the manifest file, or ``None``. - - Returns: - List of bundles, each a list of file-path strings. - """ - import json - - import yaml - - if not manifest_path: - return [] - - try: - manifest_text = manifest_path.read_text(encoding="utf-8") - except Exception as exc: - print(f"Warning: Unable to read multi_doc_manifest at {manifest_path}: {exc}") - return [] - - data = None - try: - data = json.loads(manifest_text) - except json.JSONDecodeError: - try: - data = yaml.safe_load(manifest_text) - except Exception as exc: - print(f"Warning: Failed to parse multi_doc_manifest: {exc}") - return [] - - if isinstance(data, dict) and "bundles" in data: - data = data["bundles"] - - bundles: list[list[str]] = [] - if isinstance(data, list): - for entry in data: - if isinstance(entry, dict) and "docs" in entry: - docs = entry["docs"] - elif isinstance(entry, list): - docs = entry - else: - docs = [] - clean_docs = [str(doc) for doc in docs if doc] - if clean_docs: - bundles.append(clean_docs) - else: - print("Warning: multi_doc_manifest must be a list or dict with 'bundles'") - - return bundles - - -def build_bundle_id(bundle_members: list[str]) -> str: - """Generate a unique bundle ID from member paths. - - Args: - bundle_members: List of file paths in the bundle. - - Returns: - MD5 hex digest of sorted, normalised paths. 
- """ - if not bundle_members: - return "" - normalized = "||".join(sorted(str(Path(member).resolve()) for member in bundle_members)) - return hashlib.md5(normalized.encode()).hexdigest() - - -def build_bundles( - file_paths: list[Path], - bundle_size: int = 2, - max_docs_per_bundle: int = 3, - manifest_bundles: list[list[str]] | None = None, - input_dir: Path | None = None, -) -> list[list[Path]]: - """Group file paths into document bundles. - - Manifest-defined bundles take priority. Remaining documents are grouped - sequentially according to *bundle_size*. - - Args: - file_paths: All candidate file paths. - bundle_size: Documents per automatic bundle. - max_docs_per_bundle: Hard cap on bundle size. - manifest_bundles: Pre-defined bundles from a manifest file. - input_dir: Root directory for resolving relative manifest paths. - - Returns: - List of bundles, each a list of resolved ``Path`` objects. - - Raises: - ValueError: If any bundle exceeds *max_docs_per_bundle*. - """ - if not file_paths: - return [] - - resolved_paths = [path.resolve() for path in file_paths] - seen: set[Path] = set() - bundles: list[list[Path]] = [] - - if manifest_bundles: - for entry in manifest_bundles: - resolved_bundle: list[Path] = [] - for raw_doc in entry: - candidate = Path(raw_doc) - if not candidate.is_absolute() and input_dir: - candidate = (input_dir / raw_doc).resolve() - candidate = candidate.resolve() - if candidate in resolved_paths and candidate not in seen: - resolved_bundle.append(candidate) - seen.add(candidate) - if resolved_bundle: - bundles.append(resolved_bundle) - - remaining = [p for p in resolved_paths if p not in seen] - for start in range(0, len(remaining), bundle_size): - bundle = remaining[start : start + bundle_size] - if bundle: - bundles.append(bundle) - - for i, bundle in enumerate(bundles): - if len(bundle) > max_docs_per_bundle: - raise ValueError( - f"Bundle {i} has {len(bundle)} documents, which exceeds " - f"max_docs_per_bundle={max_docs_per_bundle}. " - f"Either reduce the bundle size in your manifest or increase max_docs_per_bundle." - ) - - return [b for b in bundles if b] - - -# --------------------------------------------------------------------------- -# Section-building strategies -# --------------------------------------------------------------------------- - - -def group_chunks_by_doc(chunks: list[dict]) -> dict[str, list[tuple[int, dict]]]: - """Group chunks by their ``doc_id`` field. - - Args: - chunks: Chunk dicts, each optionally containing ``'doc_id'``. - - Returns: - Mapping from ``doc_id`` to ``(global_index, chunk)`` pairs. - """ - grouped: dict[str, list[tuple[int, dict]]] = defaultdict(list) - for idx, chunk in enumerate(chunks): - doc_id = chunk.get("doc_id", "default") - grouped[doc_id].append((idx, chunk)) - return dict(grouped) - - -def format_section_chunks(indexed_chunks: list[tuple[int, dict]], section_number: int) -> str: - """Render a list of indexed chunks into a section string. - - Args: - indexed_chunks: ``(global_index, chunk)`` tuples. - section_number: Section ordinal for the header. - - Returns: - Formatted section text, or ``""`` if no content. 
- """ - section_lines: list[str] = [] - for _, chunk in indexed_chunks: - text = chunk.get("text", "").strip() - if not text: - continue - segment_id = chunk.get("chunk_id", 1) - doc_id = chunk.get("doc_id", "") - start_time = "00:00:00" - end_time = "00:00:00" - if doc_id: - segment_info = f"Segment {segment_id} [Doc: {doc_id}] ({start_time} - {end_time}): {text}" - else: - segment_info = f"Segment {segment_id} ({start_time} - {end_time}): {text}" - section_lines.append(segment_info) - - if section_lines: - return f"=== Section {section_number} ===\n" + "\n".join(section_lines) - return "" - - -def chunks_to_sections_sequential(chunks: list[dict], num_sections: int = 1) -> list[str]: - """Split chunks sequentially into *num_sections* sections. - - Args: - chunks: Chunk dicts in document order. - num_sections: How many sections to produce. - - Returns: - List of formatted section strings. - """ - total = len(chunks) - if total == 0: - return [] - - section_size = max(1, total // num_sections) - formatted_sections: list[str] = [] - - for i in range(num_sections): - start_idx = i * section_size - end_idx = (i + 1) * section_size if i < num_sections - 1 else total - indexed_chunks = [(j, chunks[j]) for j in range(start_idx, end_idx)] - section_text = format_section_chunks(indexed_chunks, i + 1) - if section_text: - formatted_sections.append(section_text) - - return formatted_sections - - -def chunks_to_sections_doc_balanced(chunks: list[dict], num_sections: int = 1) -> list[str]: - """Split chunks so each section has proportional doc representation. - - Falls back to sequential when there is only one document. - - Args: - chunks: Chunk dicts with ``'doc_id'`` fields. - num_sections: How many sections to produce. - - Returns: - List of formatted section strings. - """ - if not chunks: - return [] - - grouped = group_chunks_by_doc(chunks) - if len(grouped) <= 1: - return chunks_to_sections_sequential(chunks, num_sections) - - chunk_sizes = {doc_id: max(1, math.ceil(len(entries) / num_sections)) for doc_id, entries in grouped.items()} - - sections: list[list[tuple[int, dict]]] = [] - for part_idx in range(num_sections): - part_entries: list[tuple[int, dict]] = [] - for doc_id, entries in grouped.items(): - chunk_size = chunk_sizes[doc_id] - start = part_idx * chunk_size - end = min(len(entries), start + chunk_size) - if start < len(entries): - part_entries.extend(entries[start:end]) - if part_entries: - sections.append(part_entries) - - formatted_sections: list[str] = [] - for i, indexed_chunks in enumerate(sections): - section_text = format_section_chunks(indexed_chunks, i + 1) - if section_text: - formatted_sections.append(section_text) - - return formatted_sections - - -def chunks_to_sections_interleaved(chunks: list[dict], num_sections: int = 1) -> list[str]: - """Split chunks with round-robin interleaving across documents. - - Falls back to sequential when there is only one document. - - Args: - chunks: Chunk dicts with ``'doc_id'`` fields. - num_sections: How many sections to produce. - - Returns: - List of formatted section strings. 
- """ - if not chunks: - return [] - - grouped = group_chunks_by_doc(chunks) - if len(grouped) <= 1: - return chunks_to_sections_sequential(chunks, num_sections) - - doc_iterators = {doc_id: deque(entries) for doc_id, entries in grouped.items()} - doc_order = list(grouped.keys()) - interleaved: list[tuple[int, dict]] = [] - - while True: - added = False - for doc_id in doc_order: - doc_queue = doc_iterators[doc_id] - if doc_queue: - interleaved.append(doc_queue.popleft()) - added = True - if not added: - break - - if not interleaved: - return [] - - total = len(interleaved) - section_size = max(1, total // num_sections) - formatted_sections: list[str] = [] - - for i in range(num_sections): - start_idx = i * section_size - end_idx = (i + 1) * section_size if i < num_sections - 1 else total - indexed_chunks = interleaved[start_idx:end_idx] - section_text = format_section_chunks(indexed_chunks, i + 1) - if section_text: - formatted_sections.append(section_text) - - return formatted_sections - - -def chunks_to_sections_structured( - chunks: list[dict], - num_sections: int = 1, - strategy: Literal["sequential", "doc_balanced", "interleaved"] = "sequential", -) -> list[str]: - """Split chunks into sections using the specified strategy. - - Args: - chunks: Chunk dicts. - num_sections: How many sections to produce. - strategy: ``"sequential"``, ``"doc_balanced"``, or ``"interleaved"``. - - Returns: - List of formatted section strings. - """ - if strategy == "doc_balanced": - return chunks_to_sections_doc_balanced(chunks, num_sections) - if strategy == "interleaved": - return chunks_to_sections_interleaved(chunks, num_sections) - return chunks_to_sections_sequential(chunks, num_sections) - - -# --------------------------------------------------------------------------- -# Sentence chunking -# --------------------------------------------------------------------------- - - -def _ensure_nltk_punkt() -> None: - """Download NLTK punkt tokeniser data if not already present.""" - for resource in ("tokenizers/punkt", "tokenizers/punkt_tab"): - try: - nltk.data.find(resource) - except LookupError: - nltk.download(resource.split("/")[-1], quiet=True) - - -def text_to_sentence_chunks( - text: str, - sentences_per_chunk: int = 5, - doc_id: str | None = None, - doc_path: str | None = None, - chunk_id_offset: int = 0, -) -> list[dict]: - """Chunk *text* into groups of sentences with metadata. - - Args: - text: Input text to chunk. - sentences_per_chunk: Sentences per chunk. - doc_id: Optional document identifier for multi-doc bundles. - doc_path: Optional document path for multi-doc bundles. - chunk_id_offset: Offset for global chunk IDs when aggregating. - - Returns: - List of chunk dicts with keys ``text``, ``start``, ``end``, - ``sentence_count``, ``word_count``, ``chunk_id``, - ``doc_chunk_index``, and optionally ``doc_id`` / ``doc_path``. - """ - _ensure_nltk_punkt() - - paragraphs = re.split(r"\n\s*\n+", text) - paragraphs = [p.strip() for p in paragraphs if p.strip()] - - sentences: list[str] = [] - for paragraph in paragraphs: - sentences.extend(sent_tokenize(paragraph)) - - chunks: list[dict] = [] - word_position = 0 - doc_chunk_index = 0 - - for i in range(0, len(sentences), sentences_per_chunk): - chunk_sentences = sentences[i : i + sentences_per_chunk] - chunk_text = ". ".join(chunk_sentences) - if chunk_text and not chunk_text.endswith("."): - chunk_text += "." 
- - chunk_words = chunk_text.split() - start_word_pos = word_position - end_word_pos = word_position + len(chunk_words) - word_position = end_word_pos - doc_chunk_index += 1 - - chunk_data: dict = { - "text": chunk_text, - "start": start_word_pos, - "end": end_word_pos, - "sentence_count": len(chunk_sentences), - "word_count": len(chunk_words), - "chunk_id": chunk_id_offset + len(chunks) + 1, - "doc_chunk_index": doc_chunk_index, - } - - if doc_id is not None: - chunk_data["doc_id"] = doc_id - if doc_path is not None: - chunk_data["doc_path"] = doc_path - - chunks.append(chunk_data) - - return chunks - - -# --------------------------------------------------------------------------- -# Top-level directory loader -# --------------------------------------------------------------------------- - - -def load_text_files_from_directory( - input_dir: Path, - file_extensions: list[str] | None = None, - min_text_length: int = 0, - sentences_per_chunk: int = 5, - num_sections: int = 1, - num_files: int | None = None, - multi_doc: bool = False, - bundle_size: int = 2, - bundle_strategy: Literal["sequential", "doc_balanced", "interleaved"] = "sequential", - max_docs_per_bundle: int = 3, - multi_doc_manifest: Path | None = None, -) -> pd.DataFrame: - """Load text files from a directory into a seed DataFrame. - - Supports single-document mode (one row per file) and multi-document mode - (files grouped into bundles, one row per bundle). - - Args: - input_dir: Root directory containing text files. - file_extensions: Allowed extensions (default ``[".txt", ".md", ".text", ""]``). - min_text_length: Minimum character count to include a document. - sentences_per_chunk: Sentences per chunk. - num_sections: Sections to split chunks into. - num_files: Cap on the number of files to process. - multi_doc: Enable multi-document bundling. - bundle_size: Documents per automatic bundle. - bundle_strategy: Section-building strategy. - max_docs_per_bundle: Hard cap on bundle size. - multi_doc_manifest: Path to a manifest defining explicit bundles. - - Returns: - DataFrame with columns ``file_name``, ``text``, ``chunks``, - ``sections_structured``, and (when multi-doc) ``bundle_id``, - ``bundle_members``, ``is_multi_doc``. - - Raises: - ValueError: If no text files or valid documents are found. 
- """ - if file_extensions is None: - file_extensions = [".txt", ".md", ".text", ""] - - all_file_paths: list[Path] = [] - for file_path in input_dir.rglob("*"): - if num_files is not None and len(all_file_paths) >= num_files: - break - if file_path.is_file() and file_matches_extensions(file_path, file_extensions): - try: - content = file_path.read_text(encoding="utf-8") - if min_text_length > 0 and len(content) < min_text_length: - continue - all_file_paths.append(file_path) - except Exception as e: - print(f"Warning: Could not read {file_path}: {e}") - continue - - if not all_file_paths: - raise ValueError(f"No text files found in {input_dir} with extensions {file_extensions}") - - resolved_input_dir = input_dir.resolve() - documents: list[dict] = [] - - if multi_doc: - documents = _load_multi_doc( - all_file_paths, - resolved_input_dir, - sentences_per_chunk, - num_sections, - bundle_size, - bundle_strategy, - max_docs_per_bundle, - multi_doc_manifest, - ) - else: - documents = _load_single_doc( - all_file_paths, - input_dir, - sentences_per_chunk, - num_sections, - bundle_strategy, - ) - - if not documents: - raise ValueError(f"No valid documents created from {input_dir}") - - df = pd.DataFrame(documents) - _print_load_stats(df, all_file_paths, multi_doc, min_text_length, bundle_strategy) - return df - - -# --------------------------------------------------------------------------- -# Internal loader helpers -# --------------------------------------------------------------------------- - - -def _load_single_doc( - file_paths: list[Path], - input_dir: Path, - sentences_per_chunk: int, - num_sections: int, - bundle_strategy: Literal["sequential", "doc_balanced", "interleaved"], -) -> list[dict]: - """Build one row per file.""" - documents: list[dict] = [] - for file_path in file_paths: - relative_path = file_path.relative_to(input_dir) - try: - content = file_path.read_text(encoding="utf-8") - except Exception as e: - print(f"Warning: Could not read {relative_path}: {e}") - continue - - chunks = text_to_sentence_chunks(content, sentences_per_chunk=sentences_per_chunk) - sections_structured = chunks_to_sections_structured(chunks, num_sections=num_sections, strategy=bundle_strategy) - documents.append( - { - "file_name": [str(relative_path)], - "text": content, - "chunks": chunks, - "sections_structured": sections_structured, - "bundle_id": "", - "bundle_members": [str(relative_path)], - "is_multi_doc": False, - } - ) - return documents - - -def _load_multi_doc( - file_paths: list[Path], - resolved_input_dir: Path, - sentences_per_chunk: int, - num_sections: int, - bundle_size: int, - bundle_strategy: Literal["sequential", "doc_balanced", "interleaved"], - max_docs_per_bundle: int, - multi_doc_manifest: Path | None, -) -> list[dict]: - """Build one row per bundle.""" - manifest_bundles = load_multi_doc_manifest(multi_doc_manifest) - bundles = build_bundles( - file_paths, - bundle_size=bundle_size, - max_docs_per_bundle=max_docs_per_bundle, - manifest_bundles=manifest_bundles, - input_dir=resolved_input_dir, - ) - - print(f"Multi-doc mode: Created {len(bundles)} bundles from {len(file_paths)} files") - documents: list[dict] = [] - - for bundle in bundles: - bundle_texts: list[str] = [] - bundle_chunks: list[dict] = [] - bundle_members: list[str] = [] - chunk_id_offset = 0 - - for file_path in bundle: - relative_path = file_path.relative_to(resolved_input_dir) - doc_id = str(relative_path) - bundle_members.append(doc_id) - - try: - content = file_path.read_text(encoding="utf-8") - except 
Exception as e: - print(f"Warning: Could not read {file_path}: {e}") - continue - - bundle_texts.append(content) - doc_chunks = text_to_sentence_chunks( - content, - sentences_per_chunk=sentences_per_chunk, - doc_id=doc_id, - doc_path=str(file_path), - chunk_id_offset=chunk_id_offset, - ) - bundle_chunks.extend(doc_chunks) - chunk_id_offset += len(doc_chunks) - - if not bundle_chunks: - continue - - combined_text = "\n\n=== Document Boundary ===\n\n".join(bundle_texts) - sections_structured = chunks_to_sections_structured( - bundle_chunks, num_sections=num_sections, strategy=bundle_strategy - ) - bid = build_bundle_id(bundle_members) - - documents.append( - { - "file_name": bundle_members, - "text": combined_text, - "chunks": bundle_chunks, - "sections_structured": sections_structured, - "bundle_id": bid, - "bundle_members": bundle_members, - "is_multi_doc": True, - } - ) - - return documents - - -def _print_load_stats( - df: pd.DataFrame, - all_file_paths: list[Path], - multi_doc: bool, - min_text_length: int, - bundle_strategy: str, -) -> None: - """Print statistics about the loaded data.""" - row_type = "bundle" if multi_doc else "document" - if multi_doc: - avg_docs = sum(len(m) for m in df["bundle_members"]) / len(df) if len(df) > 0 else 0 - print(f"Created {len(df)} bundles from {len(all_file_paths)} files") - print(f"Average documents per bundle: {avg_docs:.1f}") - else: - print(f"Loaded {len(df)} text files from directory") - - if min_text_length > 0: - print(f"Filtered to documents with at least {min_text_length} characters") - - total_chunks = sum(len(c) for c in df["chunks"]) - avg_chunks = total_chunks / len(df) if len(df) > 0 else 0 - print(f"Created {total_chunks} total chunks ({avg_chunks:.1f} chunks per {row_type})") - - total_sections = sum(len(s) for s in df["sections_structured"]) - avg_sections = total_sections / len(df) if len(df) > 0 else 0 - avg_chunks_per_section = total_chunks / total_sections if total_sections > 0 else 0 - print( - f"Organized into {total_sections} sections " - f"({avg_sections:.1f} sections per {row_type}, " - f"{avg_chunks_per_section:.1f} chunks per section)" - ) - print(f"Bundle strategy: {bundle_strategy}") diff --git a/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/pipeline.py b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/pipeline.py index e1eacfa..b7b7760 100644 --- a/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/pipeline.py +++ b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/pipeline.py @@ -17,10 +17,9 @@ from pathlib import Path import data_designer.config as dd -import pandas as pd from data_designer.config.default_model_settings import get_default_providers -from data_designer_retrieval_sdg.config import RetrievalSdgDedupColumnConfig +from data_designer_retrieval_sdg.config import EmbeddingDedupColumnConfig from data_designer_retrieval_sdg.models import ( DocumentArtifacts, QAPairEvaluations, @@ -34,10 +33,7 @@ QA_GENERATION_SYSTEM_PROMPT, QA_GENERATION_USER_PROMPT, ) - -# --------------------------------------------------------------------------- -# Model configuration -# --------------------------------------------------------------------------- +from data_designer_retrieval_sdg.seed_source import DocumentChunkerSeedSource DEFAULT_CHAT_MODEL = "nvidia/nemotron-3-nano-30b-a3b" DEFAULT_EMBED_MODEL = "nvidia/llama-3.2-nv-embedqa-1b-v2" @@ -75,7 +71,7 @@ def custom_model_config( for chat-completion models. 
Returns: - Tuple of ``(model_configs, role_aliases)`` where *role_aliases* + Tuple of ``(model_configs, role_aliases)`` where ``role_aliases`` maps each role name to the ``ModelConfig`` alias it should reference. """ configs: list[dd.ModelConfig] = [ @@ -122,11 +118,6 @@ def custom_model_config( return configs, role_aliases -# --------------------------------------------------------------------------- -# Model-provider helpers -# --------------------------------------------------------------------------- - - def build_model_providers( custom_provider_endpoint: str | None = None, custom_provider_name: str = "custom", @@ -138,10 +129,8 @@ def build_model_providers( Inline flags define a single provider; the config file can define multiple. When both are supplied the inline provider overwrites any - file entry with the same name. - - Custom providers are merged with Data Designer defaults so that built-in - providers remain available. + file entry with the same name. Custom providers are merged with Data + Designer defaults so that built-in providers remain available. Args: custom_provider_endpoint: Base URL for an inline custom provider. @@ -151,8 +140,8 @@ def build_model_providers( model_providers_file: Path to a YAML/JSON file with provider entries. Returns: - Tuple of ``(all_providers, custom_only_providers)``. - ``all_providers`` is ``None`` when no custom providers exist. + Tuple of ``(all_providers, custom_only_providers)``. ``all_providers`` + is ``None`` when no custom providers exist. """ import yaml @@ -189,10 +178,6 @@ def build_model_providers( return defaults + custom, custom -# --------------------------------------------------------------------------- -# Pipeline builder -# --------------------------------------------------------------------------- - DEFAULT_QUERY_COUNTS: dict[str, int] = {"multi_hop": 3, "structural": 2, "contextual": 2} DEFAULT_REASONING_COUNTS: dict[str, int] = { "factual": 1, @@ -206,7 +191,7 @@ def build_model_providers( def build_qa_generation_pipeline( - seed_dataset: pd.DataFrame, + seed_source: DocumentChunkerSeedSource, start_index: int = 0, end_index: int = 199, max_artifacts_per_type: int = 2, @@ -216,6 +201,7 @@ def build_qa_generation_pipeline( max_hops: int = 3, reasoning_counts: dict[str, int] | None = None, min_complexity: int = 4, + similarity_threshold: float = 0.9, max_parallel_requests_for_gen: int | None = None, artifact_extraction_model: str = DEFAULT_CHAT_MODEL, artifact_extraction_provider: str = DEFAULT_PROVIDER, @@ -236,8 +222,9 @@ def build_qa_generation_pipeline( 4. ``qa_evaluations`` -- quality scoring Args: - seed_dataset: DataFrame with ``file_name``, ``text``, ``chunks``, - ``sections_structured`` columns. + seed_source: Configured :class:`DocumentChunkerSeedSource` whose + output schema includes ``file_name``, ``text``, ``chunks``, + ``sections_structured``. start_index: Start index (inclusive) for ordered index-range selection. end_index: End index (inclusive) for ordered index-range selection. max_artifacts_per_type: Max artifacts extracted per type. @@ -247,6 +234,7 @@ def build_qa_generation_pipeline( max_hops: Maximum hops for multi-hop questions. reasoning_counts: Distribution of reasoning types. min_complexity: Minimum complexity score. + similarity_threshold: Cosine similarity threshold for QA-pair dedup. max_parallel_requests_for_gen: Cap on parallel requests for chat models. artifact_extraction_model: Model for artifact extraction. artifact_extraction_provider: Provider for artifact extraction. 
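For orientation, a minimal sketch of how the seed source and pipeline builder fit together, mirroring the CLI's ``generate`` path; the ``DataDesigner`` constructor arguments are elided here (the CLI supplies providers and an artifact path), and the input directory is a placeholder::

    from data_designer.interface import DataDesigner
    from data_designer_retrieval_sdg.pipeline import build_qa_generation_pipeline
    from data_designer_retrieval_sdg.seed_source import DocumentChunkerSeedSource

    # Describe where the documents live and how to chunk/section them.
    seed_source = DocumentChunkerSeedSource(
        path="./docs",            # placeholder directory of text files
        sentences_per_chunk=5,
        num_sections=1,
    )

    # Build the four-column pipeline over the first 200 seed records.
    config_builder = build_qa_generation_pipeline(
        seed_source=seed_source,
        start_index=0,
        end_index=199,
        similarity_threshold=0.9,
    )

    # Preview one record before a full batch run, as the CLI does.
    designer = DataDesigner(...)  # constructed as in _run_generate above
    preview = designer.preview(config_builder, num_records=1)
    preview.display_sample_record()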
@@ -281,12 +269,11 @@ def build_qa_generation_pipeline( config_builder = dd.DataDesignerConfigBuilder(model_configs=model_configs) config_builder.with_seed_dataset( - dd.DataFrameSeedSource(df=seed_dataset), + seed_source, sampling_strategy=dd.SamplingStrategy.ORDERED, selection_strategy=dd.IndexRange(start=start_index, end=end_index), ) - # Column 1: artifact extraction config_builder.add_column( dd.LLMStructuredColumnConfig( name="document_artifacts", @@ -299,7 +286,6 @@ def build_qa_generation_pipeline( ) ) - # Column 2: QA generation config_builder.add_column( dd.LLMStructuredColumnConfig( name="qa_generation", @@ -325,17 +311,17 @@ def build_qa_generation_pipeline( ) ) - # Column 3: deduplication (plugin column) config_builder.add_column( - RetrievalSdgDedupColumnConfig( + EmbeddingDedupColumnConfig( name="deduplicated_qa_pairs", - qa_pairs_column="qa_generation", - embedding_alias="embed", - dedupe_similarity_threshold=0.9, + source_column="qa_generation", + items_key="pairs", + text_field="question", + model_alias="embed", + similarity_threshold=similarity_threshold, ) ) - # Column 4: quality evaluation config_builder.add_column( dd.LLMStructuredColumnConfig( name="qa_evaluations", diff --git a/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/plugin.py b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/plugin.py deleted file mode 100644 index aa19889..0000000 --- a/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/plugin.py +++ /dev/null @@ -1,12 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -"""Data Designer plugin registration for the retrieval-sdg-dedup column type.""" - -from data_designer.plugins.plugin import Plugin, PluginType - -plugin = Plugin( - config_qualified_name="data_designer_retrieval_sdg.config.RetrievalSdgDedupColumnConfig", - impl_qualified_name="data_designer_retrieval_sdg.dedup.RetrievalSdgDedupColumnGenerator", - plugin_type=PluginType.COLUMN_GENERATOR, -) diff --git a/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/plugins.py b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/plugins.py new file mode 100644 index 0000000..26606f1 --- /dev/null +++ b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/plugins.py @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Data Designer plugin registrations exported by this package. + +Two ``data_designer.plugins`` entry points are wired here: + +- :data:`embedding_dedup_plugin` -- generic embedding-cosine-similarity + deduplication column generator (``column_type="embedding-dedup"``). +- :data:`document_chunker_plugin` -- filesystem seed reader that loads + text files, chunks them by sentence, and emits structured sections + (``seed_type="document-chunker"``). 
+""" + +from data_designer.plugins.plugin import Plugin, PluginType + +embedding_dedup_plugin = Plugin( + config_qualified_name="data_designer_retrieval_sdg.config.EmbeddingDedupColumnConfig", + impl_qualified_name="data_designer_retrieval_sdg.dedup.EmbeddingDedupColumnGenerator", + plugin_type=PluginType.COLUMN_GENERATOR, +) + +document_chunker_plugin = Plugin( + config_qualified_name="data_designer_retrieval_sdg.seed_source.DocumentChunkerSeedSource", + impl_qualified_name="data_designer_retrieval_sdg.seed_reader.DocumentChunkerSeedReader", + plugin_type=PluginType.SEED_READER, +) diff --git a/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/seed_reader.py b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/seed_reader.py new file mode 100644 index 0000000..e58f211 --- /dev/null +++ b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/seed_reader.py @@ -0,0 +1,242 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Filesystem seed reader that loads, chunks, and sections text files. + +Implements the framework's :class:`FileSystemSeedReader` contract: a cheap +``build_manifest`` that lists discovered files (or bundles), and an +expensive ``hydrate_row`` that reads file contents and produces the +chunked output rows. +""" + +from __future__ import annotations + +import json +import logging +from pathlib import Path, PurePosixPath +from typing import Any, ClassVar + +from data_designer.engine.resources.seed_reader import ( + FileSystemSeedReader, + SeedReaderError, + SeedReaderFileSystemContext, +) + +from data_designer_retrieval_sdg.chunking import ( + build_bundle_id, + build_bundles, + chunks_to_sections_structured, + load_multi_doc_manifest, + text_to_sentence_chunks, +) +from data_designer_retrieval_sdg.seed_source import DocumentChunkerSeedSource + +logger = logging.getLogger(__name__) + + +def _path_matches_extensions(relative_path: str, extensions: list[str] | None) -> bool: + """Return ``True`` when ``relative_path`` passes extension filtering. + + When ``extensions`` is ``None``, no filtering is applied. A literal + empty string ``""`` in the list matches files whose basename contains + no dot (i.e. no extension). + """ + if not extensions: + return True + ext_set = {e.lower() for e in extensions} + suffix = PurePosixPath(relative_path).suffix.lower() + if suffix in ext_set: + return True + if "" in ext_set and "." not in PurePosixPath(relative_path).name: + return True + return False + + +class DocumentChunkerSeedReader(FileSystemSeedReader[DocumentChunkerSeedSource]): + """Sentence-chunk text files into a DataDesigner seed dataset. + + Output schema (one record per row): + + - ``file_name``: ``list[str]`` of relative paths (always a list, + even in single-doc mode, for downstream uniformity). + - ``text``: combined document text. In multi-doc mode documents are + joined with ``"\\n\\n=== Document Boundary ===\\n\\n"`` separators. + - ``chunks``: ``list[dict]`` of sentence chunks with metadata. + - ``sections_structured``: ``list[str]`` of formatted section blocks. + - ``bundle_id``: stable hash of the bundle members (single-doc rows + have an empty string). + - ``bundle_members``: ``list[str]`` of relative paths (mirrors + ``file_name``; preserved for backward compatibility). + - ``is_multi_doc``: ``True`` when ``DocumentChunkerSeedSource.multi_doc`` + is enabled, ``False`` otherwise. 
+ """ + + output_columns: ClassVar[list[str] | None] = [ + "file_name", + "text", + "chunks", + "sections_structured", + "bundle_id", + "bundle_members", + "is_multi_doc", + ] + + def build_manifest(self, *, context: SeedReaderFileSystemContext) -> list[dict[str, Any]]: + """Discover files (and bundles) under ``context.root_path``. + + In single-doc mode each row references one file. In multi-doc + mode each row references a bundle of files; the bundle membership + is JSON-encoded in ``bundle_members_json`` so the manifest stays + a flat string-only schema (DuckDB-friendly). + """ + matched_paths = self.get_matching_relative_paths( + context=context, + file_pattern=self.source.file_pattern, + recursive=self.source.recursive, + ) + matched_paths = [p for p in matched_paths if _path_matches_extensions(p, self.source.file_extensions)] + + if self.source.num_files is not None: + matched_paths = matched_paths[: self.source.num_files] + + if not matched_paths: + raise SeedReaderError( + f"No files matched extensions {self.source.file_extensions!r} under {context.root_path}" + ) + + if self.source.multi_doc: + return self._build_multi_doc_manifest(matched_paths, context) + return [{"bundle_members_json": json.dumps([p])} for p in matched_paths] + + def hydrate_row( + self, + *, + manifest_row: dict[str, Any], + context: SeedReaderFileSystemContext, + ) -> dict[str, Any] | list[dict[str, Any]]: + """Read file contents for the manifest row and emit a chunked record. + + Returns an empty list when no file in the row passes + ``min_text_length`` or no chunks are produced (the row is dropped). + """ + members: list[str] = json.loads(manifest_row["bundle_members_json"]) + is_multi_doc = self.source.multi_doc + + if not is_multi_doc: + record = self._hydrate_single(members[0], context) + return [record] if record else [] + + record = self._hydrate_bundle(members, context) + return [record] if record else [] + + def _build_multi_doc_manifest( + self, + matched_paths: list[str], + context: SeedReaderFileSystemContext, + ) -> list[dict[str, Any]]: + manifest_path = Path(self.source.multi_doc_manifest) if self.source.multi_doc_manifest else None + manifest_bundles = load_multi_doc_manifest(manifest_path) + + absolute_paths = [context.root_path / rel for rel in matched_paths] + bundles = build_bundles( + absolute_paths, + bundle_size=self.source.bundle_size, + max_docs_per_bundle=self.source.max_docs_per_bundle, + manifest_bundles=manifest_bundles, + input_dir=context.root_path, + ) + if not bundles: + raise SeedReaderError(f"build_bundles produced no bundles from {context.root_path}") + + manifest: list[dict[str, Any]] = [] + for bundle_paths in bundles: + relative_members = [str(p.relative_to(context.root_path)) for p in bundle_paths] + manifest.append({"bundle_members_json": json.dumps(relative_members)}) + return manifest + + def _read_file(self, relative_path: str, context: SeedReaderFileSystemContext) -> str | None: + """Read a single file, returning ``None`` when it is too short or unreadable.""" + absolute_path = context.root_path / relative_path + try: + with context.fs.open(relative_path, "r", encoding="utf-8") as handle: + content = handle.read() + except (OSError, UnicodeDecodeError) as exc: + logger.warning("Skipping %s: %s", absolute_path, exc) + return None + + if self.source.min_text_length > 0 and len(content) < self.source.min_text_length: + return None + return content + + def _hydrate_single( + self, + relative_path: str, + context: SeedReaderFileSystemContext, + ) -> dict[str, Any] | 
None: + content = self._read_file(relative_path, context) + if content is None: + return None + + chunks = text_to_sentence_chunks(content, sentences_per_chunk=self.source.sentences_per_chunk) + if not chunks: + return None + + sections = chunks_to_sections_structured( + chunks, + num_sections=self.source.num_sections, + strategy=self.source.bundle_strategy, + ) + return { + "file_name": [relative_path], + "text": content, + "chunks": chunks, + "sections_structured": sections, + "bundle_id": "", + "bundle_members": [relative_path], + "is_multi_doc": False, + } + + def _hydrate_bundle( + self, + relative_members: list[str], + context: SeedReaderFileSystemContext, + ) -> dict[str, Any] | None: + bundle_texts: list[str] = [] + bundle_chunks: list[dict[str, Any]] = [] + bundle_members: list[str] = [] + chunk_id_offset = 0 + + for relative_path in relative_members: + content = self._read_file(relative_path, context) + if content is None: + continue + bundle_members.append(relative_path) + bundle_texts.append(content) + doc_chunks = text_to_sentence_chunks( + content, + sentences_per_chunk=self.source.sentences_per_chunk, + doc_id=relative_path, + doc_path=str(context.root_path / relative_path), + chunk_id_offset=chunk_id_offset, + ) + bundle_chunks.extend(doc_chunks) + chunk_id_offset += len(doc_chunks) + + if not bundle_chunks: + return None + + combined_text = "\n\n=== Document Boundary ===\n\n".join(bundle_texts) + sections = chunks_to_sections_structured( + bundle_chunks, + num_sections=self.source.num_sections, + strategy=self.source.bundle_strategy, + ) + return { + "file_name": bundle_members, + "text": combined_text, + "chunks": bundle_chunks, + "sections_structured": sections, + "bundle_id": build_bundle_id(bundle_members), + "bundle_members": bundle_members, + "is_multi_doc": True, + } diff --git a/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/seed_source.py b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/seed_source.py new file mode 100644 index 0000000..4a7dee6 --- /dev/null +++ b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/seed_source.py @@ -0,0 +1,66 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Seed source configuration for the document-chunker plugin.""" + +from __future__ import annotations + +from typing import Literal + +from data_designer.config.base import ConfigBase +from data_designer.config.seed_source import FileSystemSeedSource +from pydantic import Field + + +class DocumentChunkerSeedSource(FileSystemSeedSource, ConfigBase): + """Load text files, sentence-chunk them, and build structured sections. + + Subclasses :class:`FileSystemSeedSource` (so the framework owns + directory discovery, glob matching, and DuckDB registration) and + :class:`ConfigBase` (required by ``assert_valid_plugin``). This + config layers chunking and multi-document bundling parameters on top. + + Inherited fields: + path: Directory containing source text files. + file_pattern: Filename glob (basenames only). Defaults to ``"*"``. + recursive: Whether to descend into subdirectories. + + Args: + file_extensions: Optional list of allowed file extensions (e.g. + ``[".txt", ".md"]``). Filtered after glob matching against + ``file_pattern``. ``None`` disables extension filtering. + min_text_length: Minimum character count to keep a document. + sentences_per_chunk: Sentences grouped into a single chunk. 
+ num_sections: Sections to organise chunks into per row. + num_files: Cap on the number of files to load (``None`` = no cap). + multi_doc: If true, group files into multi-document bundles + (one row per bundle) instead of one row per file. + bundle_size: Documents per automatic bundle. + bundle_strategy: ``"sequential"`` / ``"doc_balanced"`` / + ``"interleaved"``; controls how chunks across documents are + split into sections. + max_docs_per_bundle: Hard cap on bundle size. + multi_doc_manifest: Optional path to a JSON/YAML manifest + defining explicit bundles; falls back to automatic bundling + for any files not listed. + """ + + seed_type: Literal["document-chunker"] = "document-chunker" + + file_extensions: list[str] | None = Field( + default=None, + description=( + "Optional list of allowed file extensions (e.g. ['.txt', '.md']). " + "Filtered after glob matching against file_pattern." + ), + ) + min_text_length: int = Field(default=0, ge=0) + sentences_per_chunk: int = Field(default=5, ge=1) + num_sections: int = Field(default=1, ge=1) + num_files: int | None = Field(default=None, ge=1) + + multi_doc: bool = False + bundle_size: int = Field(default=2, ge=1) + bundle_strategy: Literal["sequential", "doc_balanced", "interleaved"] = "sequential" + max_docs_per_bundle: int = Field(default=3, ge=1) + multi_doc_manifest: str | None = None diff --git a/plugins/data-designer-retrieval-sdg/tests/test_chunking.py b/plugins/data-designer-retrieval-sdg/tests/test_chunking.py new file mode 100644 index 0000000..d44341e --- /dev/null +++ b/plugins/data-designer-retrieval-sdg/tests/test_chunking.py @@ -0,0 +1,109 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for the chunking, section, and bundling helpers.""" + +from pathlib import Path + +import pytest + +from data_designer_retrieval_sdg.chunking import ( + build_bundle_id, + build_bundles, + chunks_to_sections_structured, + text_to_sentence_chunks, +) + + +def test_text_to_sentence_chunks_basic() -> None: + text = "First sentence. Second sentence. Third sentence. Fourth sentence. Fifth sentence. Sixth sentence." + chunks = text_to_sentence_chunks(text, sentences_per_chunk=3) + assert len(chunks) == 2 + assert chunks[0]["chunk_id"] == 1 + assert chunks[1]["chunk_id"] == 2 + assert chunks[0]["sentence_count"] == 3 + + +def test_text_to_sentence_chunks_with_doc_id() -> None: + chunks = text_to_sentence_chunks("Hello world. 
Goodbye.", sentences_per_chunk=5, doc_id="doc1") + assert len(chunks) == 1 + assert chunks[0]["doc_id"] == "doc1" + + +def test_text_to_sentence_chunks_empty() -> None: + assert text_to_sentence_chunks("") == [] + + +def test_chunks_to_sections_sequential() -> None: + chunks = [{"text": f"chunk {i}", "chunk_id": i} for i in range(1, 7)] + sections = chunks_to_sections_structured(chunks, num_sections=2, strategy="sequential") + assert len(sections) == 2 + assert "Section 1" in sections[0] + assert "Section 2" in sections[1] + + +def test_chunks_to_sections_empty() -> None: + assert chunks_to_sections_structured([], num_sections=2) == [] + + +def test_chunks_to_sections_doc_balanced_falls_back_to_sequential_for_single_doc() -> None: + chunks = [{"text": f"chunk {i}", "chunk_id": i, "doc_id": "only"} for i in range(1, 5)] + sections = chunks_to_sections_structured(chunks, num_sections=2, strategy="doc_balanced") + assert len(sections) == 2 + + +def test_chunks_to_sections_doc_balanced_multi_doc() -> None: + chunks = [ + {"text": "a1", "chunk_id": 1, "doc_id": "a"}, + {"text": "a2", "chunk_id": 2, "doc_id": "a"}, + {"text": "b1", "chunk_id": 3, "doc_id": "b"}, + {"text": "b2", "chunk_id": 4, "doc_id": "b"}, + ] + sections = chunks_to_sections_structured(chunks, num_sections=2, strategy="doc_balanced") + assert len(sections) == 2 + for section in sections: + assert "[Doc: a]" in section + assert "[Doc: b]" in section + + +def test_chunks_to_sections_interleaved_multi_doc() -> None: + chunks = [ + {"text": "a1", "chunk_id": 1, "doc_id": "a"}, + {"text": "a2", "chunk_id": 2, "doc_id": "a"}, + {"text": "b1", "chunk_id": 3, "doc_id": "b"}, + ] + sections = chunks_to_sections_structured(chunks, num_sections=1, strategy="interleaved") + assert len(sections) == 1 + assert "[Doc: a]" in sections[0] + assert "[Doc: b]" in sections[0] + + +def test_build_bundles_sequential(tmp_path: Path) -> None: + files = [tmp_path / f"f{i}.txt" for i in range(4)] + for f in files: + f.write_text("content") + bundles = build_bundles(files, bundle_size=2, max_docs_per_bundle=3) + assert len(bundles) == 2 + assert len(bundles[0]) == 2 + + +def test_build_bundles_exceeds_max(tmp_path: Path) -> None: + files = [tmp_path / f"f{i}.txt" for i in range(4)] + for f in files: + f.write_text("content") + with pytest.raises(ValueError, match="exceeds max_docs_per_bundle"): + build_bundles(files, bundle_size=4, max_docs_per_bundle=2) + + +def test_build_bundles_empty() -> None: + assert build_bundles([], bundle_size=2, max_docs_per_bundle=3) == [] + + +def test_build_bundle_id_deterministic() -> None: + a = build_bundle_id(["a.txt", "b.txt"]) + b = build_bundle_id(["b.txt", "a.txt"]) + assert a == b + + +def test_build_bundle_id_empty() -> None: + assert build_bundle_id([]) == "" diff --git a/plugins/data-designer-retrieval-sdg/tests/test_dedup.py b/plugins/data-designer-retrieval-sdg/tests/test_dedup.py index 933fce8..4d0fff6 100644 --- a/plugins/data-designer-retrieval-sdg/tests/test_dedup.py +++ b/plugins/data-designer-retrieval-sdg/tests/test_dedup.py @@ -1,74 +1,153 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 -"""Tests for the deduplication logic (pure numpy, no LLM needed).""" +"""Tests for the embedding-dedup column generator.""" -from data_designer_retrieval_sdg.config import RetrievalSdgDedupColumnConfig +import asyncio +from unittest.mock import AsyncMock, MagicMock +from data_designer_retrieval_sdg.config import EmbeddingDedupColumnConfig +from data_designer_retrieval_sdg.dedup import EmbeddingDedupColumnGenerator -def _make_generator(): - """Instantiate the generator with minimal wiring for dedupe_qa_pairs. - We only need the config for threshold; the embedder is not used in - dedupe_qa_pairs itself. - """ - from unittest.mock import MagicMock - - from data_designer_retrieval_sdg.dedup import RetrievalSdgDedupColumnGenerator - - config = RetrievalSdgDedupColumnConfig( +def _make_generator( + *, + source_column: str = "qa", + items_key: str | None = "pairs", + text_field: str = "question", + threshold: float = 0.9, +) -> EmbeddingDedupColumnGenerator: + """Instantiate the generator with minimal wiring for unit-level tests.""" + config = EmbeddingDedupColumnConfig( name="dedup", - qa_pairs_column="qa", - embedding_alias="embed", - dedupe_similarity_threshold=0.9, + source_column=source_column, + items_key=items_key, + text_field=text_field, + model_alias="embed", + similarity_threshold=threshold, ) - gen = object.__new__(RetrievalSdgDedupColumnGenerator) + gen = object.__new__(EmbeddingDedupColumnGenerator) gen._config = config gen._resource_provider = MagicMock() return gen -def test_dedupe_empty() -> None: +def test_dedupe_indices_empty() -> None: gen = _make_generator() - assert gen.dedupe_qa_pairs([]) == [] + assert gen.dedupe_indices([]) == [] -def test_dedupe_no_duplicates() -> None: +def test_dedupe_indices_no_duplicates() -> None: gen = _make_generator() embeddings = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]] - kept = gen.dedupe_qa_pairs(embeddings) - assert kept == [0, 1, 2] + assert gen.dedupe_indices(embeddings) == [0, 1, 2] -def test_dedupe_identical_vectors() -> None: +def test_dedupe_indices_identical_vectors() -> None: gen = _make_generator() embeddings = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]] - kept = gen.dedupe_qa_pairs(embeddings) + kept = gen.dedupe_indices(embeddings) assert 0 in kept assert 1 not in kept assert 2 in kept -def test_dedupe_near_threshold() -> None: +def test_dedupe_indices_near_threshold() -> None: gen = _make_generator() v1 = [1.0, 0.0] - v2 = [0.95, 0.3122] # cosine sim ≈ 0.95 > 0.9 + v2 = [0.95, 0.3122] v3 = [0.0, 1.0] - kept = gen.dedupe_qa_pairs([v1, v2, v3]) + kept = gen.dedupe_indices([v1, v2, v3]) assert 0 in kept assert 1 not in kept assert 2 in kept -def test_dedupe_single_element() -> None: +def test_dedupe_indices_single_element() -> None: + gen = _make_generator() + assert gen.dedupe_indices([[1.0, 0.0]]) == [0] + + +def test_resolve_items_with_items_key() -> None: + gen = _make_generator(items_key="pairs") + items = gen.resolve_items({"qa": {"pairs": [{"question": "x"}]}}) + assert items == [{"question": "x"}] + + +def test_resolve_items_without_items_key() -> None: + gen = _make_generator(items_key=None) + items = gen.resolve_items({"qa": [{"question": "x"}]}) + assert items == [{"question": "x"}] + + +def test_resolve_items_missing_source_returns_empty_list() -> None: + gen = _make_generator(items_key=None) + assert gen.resolve_items({}) == [] + + +def test_extract_text_dict_and_attribute() -> None: + gen = _make_generator(text_field="question") + assert gen.extract_text({"question": "hello"}) == "hello" + + 
class Item: + question = "world" + + assert gen.extract_text(Item()) == "world" + + +def test_generate_calls_embedder_once_with_all_texts() -> None: gen = _make_generator() - kept = gen.dedupe_qa_pairs([[1.0, 0.0]]) - assert kept == [0] + embedder = MagicMock() + embedder.generate_text_embeddings.return_value = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]] + gen.resource_provider.model_registry.get_model.return_value = embedder + row = {"qa": {"pairs": [{"question": "a"}, {"question": "b"}, {"question": "c"}]}} + out = gen.generate(row) -def test_config_column_type() -> None: - cfg = RetrievalSdgDedupColumnConfig(name="dedup", qa_pairs_column="qa", embedding_alias="embed") - assert cfg.column_type == "retrieval-sdg-dedup" - assert cfg.required_columns == ["qa"] + embedder.generate_text_embeddings.assert_called_once() + call_kwargs = embedder.generate_text_embeddings.call_args.kwargs + assert call_kwargs["input_texts"] == ["a", "b", "c"] + assert call_kwargs["encoding_format"] == "float" + assert out["dedup"] == [{"question": "a"}, {"question": "c"}] + + +def test_agenerate_uses_async_embedder() -> None: + gen = _make_generator() + embedder = MagicMock() + embedder.agenerate_text_embeddings = AsyncMock(return_value=[[1.0, 0.0], [1.0, 0.0]]) + embedder.generate_text_embeddings = MagicMock() + gen.resource_provider.model_registry.get_model.return_value = embedder + + row = {"qa": {"pairs": [{"question": "a"}, {"question": "b"}]}} + out = asyncio.run(gen.agenerate(row)) + + embedder.agenerate_text_embeddings.assert_awaited_once() + embedder.generate_text_embeddings.assert_not_called() + assert out["dedup"] == [{"question": "a"}] + + +def test_agenerate_empty_items_short_circuits() -> None: + gen = _make_generator() + embedder = MagicMock() + embedder.agenerate_text_embeddings = AsyncMock() + gen.resource_provider.model_registry.get_model.return_value = embedder + + out = asyncio.run(gen.agenerate({"qa": {"pairs": []}})) + + embedder.agenerate_text_embeddings.assert_not_awaited() + assert out["dedup"] == [] + + +def test_config_round_trip() -> None: + cfg = EmbeddingDedupColumnConfig( + name="dedup", + source_column="qa_generation", + model_alias="embed", + ) + assert cfg.column_type == "embedding-dedup" + assert cfg.required_columns == ["qa_generation"] assert cfg.side_effect_columns == [] assert cfg.get_column_emoji() == "🔍" + assert cfg.items_key == "pairs" + assert cfg.text_field == "question" + assert cfg.similarity_threshold == 0.9 diff --git a/plugins/data-designer-retrieval-sdg/tests/test_ingest.py b/plugins/data-designer-retrieval-sdg/tests/test_ingest.py deleted file mode 100644 index 5c87473..0000000 --- a/plugins/data-designer-retrieval-sdg/tests/test_ingest.py +++ /dev/null @@ -1,145 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 - -from pathlib import Path - -import pytest - -from data_designer_retrieval_sdg.ingest import ( - build_bundle_id, - build_bundles, - chunks_to_sections_structured, - file_matches_extensions, - is_traditional_extension, - load_text_files_from_directory, - text_to_sentence_chunks, -) - -# --------------------------------------------------------------------------- -# is_traditional_extension -# --------------------------------------------------------------------------- - - -def test_traditional_extensions() -> None: - assert is_traditional_extension(".txt") is True - assert is_traditional_extension(".md") is True - assert is_traditional_extension(".json") is True - assert is_traditional_extension(".mp3") is True - - -def test_non_traditional_extensions() -> None: - assert is_traditional_extension("") is False - assert is_traditional_extension(".com_publication_2001") is False - assert is_traditional_extension(".averylongextension123") is False - - -# --------------------------------------------------------------------------- -# file_matches_extensions -# --------------------------------------------------------------------------- - - -def test_file_matches_extensions_standard() -> None: - assert file_matches_extensions(Path("doc.txt"), [".txt", ".md"]) is True - assert file_matches_extensions(Path("doc.py"), [".txt", ".md"]) is False - - -def test_file_matches_extensions_no_ext() -> None: - assert file_matches_extensions(Path("README"), [""]) is True - - -# --------------------------------------------------------------------------- -# text_to_sentence_chunks -# --------------------------------------------------------------------------- - - -def test_text_to_sentence_chunks_basic() -> None: - text = "First sentence. Second sentence. Third sentence. Fourth sentence. Fifth sentence. Sixth sentence." - chunks = text_to_sentence_chunks(text, sentences_per_chunk=3) - assert len(chunks) == 2 - assert chunks[0]["chunk_id"] == 1 - assert chunks[1]["chunk_id"] == 2 - assert chunks[0]["sentence_count"] == 3 - - -def test_text_to_sentence_chunks_with_doc_id() -> None: - chunks = text_to_sentence_chunks("Hello world. 
Goodbye.", sentences_per_chunk=5, doc_id="doc1") - assert len(chunks) == 1 - assert chunks[0]["doc_id"] == "doc1" - - -def test_text_to_sentence_chunks_empty() -> None: - assert text_to_sentence_chunks("") == [] - - -# --------------------------------------------------------------------------- -# Section strategies -# --------------------------------------------------------------------------- - - -def test_chunks_to_sections_sequential() -> None: - chunks = [{"text": f"chunk {i}", "chunk_id": i} for i in range(1, 7)] - sections = chunks_to_sections_structured(chunks, num_sections=2, strategy="sequential") - assert len(sections) == 2 - assert "Section 1" in sections[0] - assert "Section 2" in sections[1] - - -def test_chunks_to_sections_empty() -> None: - assert chunks_to_sections_structured([], num_sections=2) == [] - - -# --------------------------------------------------------------------------- -# build_bundles -# --------------------------------------------------------------------------- - - -def test_build_bundles_sequential(tmp_path: Path) -> None: - files = [tmp_path / f"f{i}.txt" for i in range(4)] - for f in files: - f.write_text("content") - bundles = build_bundles(files, bundle_size=2, max_docs_per_bundle=3) - assert len(bundles) == 2 - assert len(bundles[0]) == 2 - - -def test_build_bundles_exceeds_max(tmp_path: Path) -> None: - files = [tmp_path / f"f{i}.txt" for i in range(4)] - for f in files: - f.write_text("content") - with pytest.raises(ValueError, match="exceeds max_docs_per_bundle"): - build_bundles(files, bundle_size=4, max_docs_per_bundle=2) - - -# --------------------------------------------------------------------------- -# build_bundle_id -# --------------------------------------------------------------------------- - - -def test_build_bundle_id_deterministic() -> None: - a = build_bundle_id(["a.txt", "b.txt"]) - b = build_bundle_id(["b.txt", "a.txt"]) - assert a == b - - -def test_build_bundle_id_empty() -> None: - assert build_bundle_id([]) == "" - - -# --------------------------------------------------------------------------- -# load_text_files_from_directory -# --------------------------------------------------------------------------- - - -def test_load_text_files_single_doc(tmp_path: Path) -> None: - (tmp_path / "a.txt").write_text("Hello world. This is a test. Another sentence.") - (tmp_path / "b.txt").write_text("Foo bar. Baz quux. Something else.") - df = load_text_files_from_directory(tmp_path, sentences_per_chunk=2) - assert len(df) == 2 - assert "file_name" in df.columns - assert "chunks" in df.columns - assert "sections_structured" in df.columns - - -def test_load_text_files_no_files(tmp_path: Path) -> None: - with pytest.raises(ValueError, match="No text files found"): - load_text_files_from_directory(tmp_path) diff --git a/plugins/data-designer-retrieval-sdg/tests/test_plugin.py b/plugins/data-designer-retrieval-sdg/tests/test_plugin.py index 89855a2..d256f80 100644 --- a/plugins/data-designer-retrieval-sdg/tests/test_plugin.py +++ b/plugins/data-designer-retrieval-sdg/tests/test_plugin.py @@ -1,10 +1,16 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 +"""Plugin-registration tests for both entry points.""" + from data_designer.engine.testing.utils import assert_valid_plugin -from data_designer_retrieval_sdg.plugin import plugin +from data_designer_retrieval_sdg.plugins import document_chunker_plugin, embedding_dedup_plugin + + +def test_embedding_dedup_plugin_valid() -> None: + assert_valid_plugin(embedding_dedup_plugin) -def test_valid_plugin() -> None: - assert_valid_plugin(plugin) +def test_document_chunker_plugin_valid() -> None: + assert_valid_plugin(document_chunker_plugin) diff --git a/plugins/data-designer-retrieval-sdg/tests/test_seed_reader.py b/plugins/data-designer-retrieval-sdg/tests/test_seed_reader.py new file mode 100644 index 0000000..e7ac829 --- /dev/null +++ b/plugins/data-designer-retrieval-sdg/tests/test_seed_reader.py @@ -0,0 +1,119 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for :class:`DocumentChunkerSeedReader`.""" + +from pathlib import Path + +import pytest +from data_designer.engine.resources.seed_reader import SeedReaderError +from data_designer.engine.secret_resolver import PlaintextResolver + +from data_designer_retrieval_sdg.seed_reader import DocumentChunkerSeedReader +from data_designer_retrieval_sdg.seed_source import DocumentChunkerSeedSource + + +def _attached_reader(source: DocumentChunkerSeedSource) -> DocumentChunkerSeedReader: + reader = DocumentChunkerSeedReader() + reader.attach(source, PlaintextResolver()) + return reader + + +def _write_sample_files(root: Path) -> None: + (root / "a.txt").write_text("First doc. Has two sentences.") + (root / "b.txt").write_text("Second doc. Has three sentences. Done.") + (root / "skip.bin").write_text("ignored") + nested = root / "nested" + nested.mkdir() + (nested / "c.md").write_text("Nested doc content. Another sentence.") + + +def test_single_doc_manifest_and_hydration(tmp_path: Path) -> None: + _write_sample_files(tmp_path) + source = DocumentChunkerSeedSource( + path=str(tmp_path), + file_extensions=[".txt", ".md"], + sentences_per_chunk=1, + ) + reader = _attached_reader(source) + + assert reader.get_seed_dataset_size() == 3 + + output_df = reader._get_output_dataframe() + assert sorted(output_df.columns) == sorted(DocumentChunkerSeedReader.output_columns) + assert len(output_df) == 3 + + first = output_df.iloc[0].to_dict() + assert first["is_multi_doc"] is False + assert isinstance(first["file_name"], list) + assert len(first["file_name"]) == 1 + assert first["bundle_members"] == first["file_name"] + assert first["bundle_id"] == "" + assert first["chunks"], "expected non-empty chunk list" + + +def test_extension_filtering(tmp_path: Path) -> None: + _write_sample_files(tmp_path) + source = DocumentChunkerSeedSource( + path=str(tmp_path), + file_extensions=[".md"], + ) + reader = _attached_reader(source) + assert reader.get_seed_dataset_size() == 1 + + +def test_min_text_length_drops_short_files(tmp_path: Path) -> None: + (tmp_path / "tiny.txt").write_text("hi.") + (tmp_path / "long.txt").write_text("This is a much longer document. It has many sentences. 
Good.") + source = DocumentChunkerSeedSource( + path=str(tmp_path), + file_extensions=[".txt"], + min_text_length=20, + ) + reader = _attached_reader(source) + output_df = reader._get_output_dataframe() + assert len(output_df) == 1 + assert output_df.iloc[0]["file_name"] == ["long.txt"] + + +def test_num_files_caps_manifest(tmp_path: Path) -> None: + for i in range(5): + (tmp_path / f"d{i}.txt").write_text(f"Content {i}. More text.") + source = DocumentChunkerSeedSource( + path=str(tmp_path), + file_extensions=[".txt"], + num_files=2, + ) + reader = _attached_reader(source) + assert reader.get_seed_dataset_size() == 2 + + +def test_no_matching_files_raises(tmp_path: Path) -> None: + (tmp_path / "ignored.bin").write_text("x") + source = DocumentChunkerSeedSource( + path=str(tmp_path), + file_extensions=[".txt"], + ) + reader = _attached_reader(source) + with pytest.raises(SeedReaderError): + reader.get_seed_dataset_size() + + +def test_multi_doc_bundles(tmp_path: Path) -> None: + for i in range(4): + (tmp_path / f"d{i}.txt").write_text(f"Doc {i}. Sentence two.") + source = DocumentChunkerSeedSource( + path=str(tmp_path), + file_extensions=[".txt"], + multi_doc=True, + bundle_size=2, + ) + reader = _attached_reader(source) + output_df = reader._get_output_dataframe() + + assert len(output_df) == 2 + for _, row in output_df.iterrows(): + assert row["is_multi_doc"] is True + assert len(row["bundle_members"]) == 2 + assert row["bundle_id"], "multi-doc rows must carry a non-empty bundle_id" + assert "=== Document Boundary ===" in row["text"] From dfa743bf56d0310d7dba38680347648839618c6d Mon Sep 17 00:00:00 2001 From: Steve Han Date: Thu, 30 Apr 2026 14:06:49 -0400 Subject: [PATCH 3/6] fix(data-designer-retrieval-sdg): route embedding-dedup through async LLM-wait semaphore Address PR review feedback that embedding-dedup column was bypassing the async scheduler's LLM-wait semaphore in DATA_DESIGNER_ASYNC_ENGINE mode. ColumnGeneratorCellByCell inherits is_llm_bound = False from the base ColumnGenerator, so build_llm_bound_lookup() in async_scheduler.py would skip _llm_wait_semaphore for this column and could fan out up to a full row group's worth of concurrent embedding requests at the endpoint. - Switch the base class to ColumnGeneratorWithModelRegistry so the generator reports is_llm_bound = True and gets the get_model() and get_model_config() helpers (mirrors how the framework's own EmbeddingCellGenerator is wired through ColumnGeneratorWithModel). - Pin the cell-by-cell strategy explicitly via get_generation_strategy(). - Cache the resolved ModelFacade via functools.cached_property so per-row dedup doesn't re-walk the model registry. - Override _validate() to fail fast at task construction with a DatasetGenerationError when the configured alias resolves to a non- embedding ModelConfig, instead of surfacing as an AttributeError from the facade or a 400 from the embeddings API on the first row. Tests added (TDD; verified RED before implementing): - is_llm_bound returns True - _validate accepts an embedding ModelConfig - _validate rejects a chat-completion ModelConfig with the offending alias name in the message - embedder is cached across accesses Local CI is green: 63/63 retrieval-sdg tests pass, ruff lint and format clean, ddp validate reports OK for all three plugins. End-to-end smoke run with DATA_DESIGNER_ASYNC_ENGINE=1 against examples/sample_texts confirms deduplicated_qa_pairs completes 3/3 cells, 0 failures. 
Made-with: Cursor Signed-off-by: Steve Han --- .../src/data_designer_retrieval_sdg/dedup.py | 53 ++++++++++++++++--- .../tests/test_dedup.py | 53 +++++++++++++++++++ 2 files changed, 98 insertions(+), 8 deletions(-) diff --git a/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/dedup.py b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/dedup.py index dac89f6..da99556 100644 --- a/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/dedup.py +++ b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/dedup.py @@ -10,18 +10,25 @@ from __future__ import annotations +import functools import logging from typing import Any import numpy as np -from data_designer.engine.column_generators.generators.base import ColumnGeneratorCellByCell +from data_designer.config.models import GenerationType +from data_designer.engine.column_generators.generators.base import ( + ColumnGeneratorWithModelRegistry, + GenerationStrategy, +) +from data_designer.engine.dataset_builders.errors import DatasetGenerationError +from data_designer.engine.models.facade import ModelFacade from data_designer_retrieval_sdg.config import EmbeddingDedupColumnConfig logger = logging.getLogger(__name__) -class EmbeddingDedupColumnGenerator(ColumnGeneratorCellByCell[EmbeddingDedupColumnConfig]): +class EmbeddingDedupColumnGenerator(ColumnGeneratorWithModelRegistry[EmbeddingDedupColumnConfig]): """Remove near-duplicate items from a list-valued column. For each row the generator: @@ -33,14 +40,44 @@ class EmbeddingDedupColumnGenerator(ColumnGeneratorCellByCell[EmbeddingDedupColu 4. Computes pairwise cosine similarity and greedily drops items whose similarity exceeds ``similarity_threshold``. 5. Returns the surviving items under ``self.config.name``. + + Extends :class:`ColumnGeneratorWithModelRegistry` so the column reports + ``is_llm_bound = True`` to the async scheduler. Without this, embedding + HTTP calls would bypass ``_llm_wait_semaphore`` and could fan out up to + a full row group's worth of concurrent requests at the embedding endpoint. """ - @property - def embedder(self): - """Resolve the embedding model from the resource provider.""" - return self.resource_provider.model_registry.get_model( - model_alias=self.config.model_alias, - ) + @staticmethod + def get_generation_strategy() -> GenerationStrategy: + """Each row's items are deduplicated independently.""" + return GenerationStrategy.CELL_BY_CELL + + @functools.cached_property + def embedder(self) -> ModelFacade: + """Resolve the embedding model once and cache it on the instance.""" + return self.get_model(model_alias=self.config.model_alias) + + def _validate(self) -> None: + """Fail fast at task construction if the alias isn't an embedding model. + + Without this guard, a misconfigured chat-model alias surfaces only on + the first row's embedding call as either an :class:`AttributeError` + from the facade or a 400 from the embeddings endpoint. + + Raises: + DatasetGenerationError: When ``self.config.model_alias`` resolves + to a :class:`ModelConfig` whose inference parameters are not + ``EmbeddingInferenceParams``. + """ + super()._validate() + model_config = self.get_model_config(model_alias=self.config.model_alias) + if model_config.generation_type != GenerationType.EMBEDDING: + raise DatasetGenerationError( + f"EmbeddingDedupColumnGenerator requires an embedding model, " + f"but model alias {self.config.model_alias!r} resolves to a " + f"{model_config.generation_type.value!r} model. 
Configure a " + f"ModelConfig with EmbeddingInferenceParams for this alias." + ) def resolve_items(self, data: dict) -> list[Any]: """Return the list of items to deduplicate from a row dict. diff --git a/plugins/data-designer-retrieval-sdg/tests/test_dedup.py b/plugins/data-designer-retrieval-sdg/tests/test_dedup.py index 4d0fff6..11fb66e 100644 --- a/plugins/data-designer-retrieval-sdg/tests/test_dedup.py +++ b/plugins/data-designer-retrieval-sdg/tests/test_dedup.py @@ -6,6 +6,14 @@ import asyncio from unittest.mock import AsyncMock, MagicMock +import pytest +from data_designer.config.models import ( + ChatCompletionInferenceParams, + EmbeddingInferenceParams, + ModelConfig, +) +from data_designer.engine.dataset_builders.errors import DatasetGenerationError + from data_designer_retrieval_sdg.config import EmbeddingDedupColumnConfig from data_designer_retrieval_sdg.dedup import EmbeddingDedupColumnGenerator @@ -151,3 +159,48 @@ def test_config_round_trip() -> None: assert cfg.items_key == "pairs" assert cfg.text_field == "question" assert cfg.similarity_threshold == 0.9 + + +def test_is_llm_bound_true() -> None: + """The column issues embedding HTTP calls and must route through the + async scheduler's LLM-wait semaphore.""" + gen = _make_generator() + assert gen.is_llm_bound is True + + +def test_validate_accepts_embedding_model() -> None: + """``_validate()`` should succeed when the configured alias resolves to + a ``ModelConfig`` whose inference parameters declare an embedding model.""" + gen = _make_generator() + gen.resource_provider.model_registry.get_model_config.return_value = ModelConfig( + alias="embed", + model="some/embedding-model", + inference_parameters=EmbeddingInferenceParams(), + ) + gen._validate() + + +def test_validate_rejects_chat_model() -> None: + """``_validate()`` should fail fast at task construction when the alias + resolves to a non-embedding model, naming the offending alias.""" + gen = _make_generator() + gen.resource_provider.model_registry.get_model_config.return_value = ModelConfig( + alias="embed", + model="some/chat-model", + inference_parameters=ChatCompletionInferenceParams(), + ) + with pytest.raises(DatasetGenerationError, match="embed"): + gen._validate() + + +def test_embedder_is_cached_across_calls() -> None: + """Repeated access should hit ``model_registry.get_model`` exactly once + so per-row dedup doesn't re-walk the registry.""" + gen = _make_generator() + gen.resource_provider.model_registry.get_model.return_value = MagicMock() + + first = gen.embedder + second = gen.embedder + + assert first is second + gen.resource_provider.model_registry.get_model.assert_called_once_with(model_alias="embed") From c39b6c09f064a81917fb8c8af37fc77d45c31521 Mon Sep 17 00:00:00 2001 From: Steve Han Date: Thu, 30 Apr 2026 14:16:58 -0400 Subject: [PATCH 4/6] fix(data-designer-retrieval-sdg): route warnings through logging instead of print chunking.py and the cli preview path were calling print() for warning/ error messages, bypassing the LoggerConfig configure_logging() sets up in the CLI. As a result, --log-level ERROR users would still see "Warning: Failed to parse multi_doc_manifest" on stdout, and the preview-error path was invisible on the configured log stream. - chunking.py: add module-level logger; the three load_multi_doc_manifest warnings (unreadable manifest, unparseable manifest, wrong shape) now emit via logger.warning(...) with %s-style args. 
- cli.py: add module-level logger; the best-effort preview except path now emits via logger.warning("Preview error: %s", e) so it honors --log-level and the configured stderr OutputConfig. Signed-off-by: Steve Han Made-with: Cursor --- .../src/data_designer_retrieval_sdg/chunking.py | 9 ++++++--- .../src/data_designer_retrieval_sdg/cli.py | 5 ++++- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/chunking.py b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/chunking.py index f8dffa2..fff6ee7 100644 --- a/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/chunking.py +++ b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/chunking.py @@ -13,6 +13,7 @@ import hashlib import json +import logging import math import re from collections import defaultdict, deque @@ -22,6 +23,8 @@ import nltk from nltk.tokenize import sent_tokenize +logger = logging.getLogger(__name__) + def load_multi_doc_manifest(manifest_path: Path | None) -> list[list[str]]: """Load a multi-doc manifest file. @@ -45,7 +48,7 @@ def load_multi_doc_manifest(manifest_path: Path | None) -> list[list[str]]: try: manifest_text = manifest_path.read_text(encoding="utf-8") except OSError as exc: - print(f"Warning: Unable to read multi_doc_manifest at {manifest_path}: {exc}") + logger.warning("Unable to read multi_doc_manifest at %s: %s", manifest_path, exc) return [] data = None @@ -55,7 +58,7 @@ def load_multi_doc_manifest(manifest_path: Path | None) -> list[list[str]]: try: data = yaml.safe_load(manifest_text) except yaml.YAMLError as exc: - print(f"Warning: Failed to parse multi_doc_manifest: {exc}") + logger.warning("Failed to parse multi_doc_manifest: %s", exc) return [] if isinstance(data, dict) and "bundles" in data: @@ -74,7 +77,7 @@ def load_multi_doc_manifest(manifest_path: Path | None) -> list[list[str]]: if clean_docs: bundles.append(clean_docs) else: - print("Warning: multi_doc_manifest must be a list or dict with 'bundles'") + logger.warning("multi_doc_manifest must be a list or dict with 'bundles'") return bundles diff --git a/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/cli.py b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/cli.py index 19f1258..e6ef319 100644 --- a/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/cli.py +++ b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/cli.py @@ -19,6 +19,7 @@ from __future__ import annotations import argparse +import logging import sys from pathlib import Path @@ -33,6 +34,8 @@ from data_designer_retrieval_sdg.seed_reader import DocumentChunkerSeedReader from data_designer_retrieval_sdg.seed_source import DocumentChunkerSeedSource +logger = logging.getLogger(__name__) + def _build_seed_source(args: argparse.Namespace) -> DocumentChunkerSeedSource: """Construct a :class:`DocumentChunkerSeedSource` from CLI arguments.""" @@ -243,7 +246,7 @@ def _run_preview( preview_result = data_designer.preview(config_builder, num_records=1) preview_result.display_sample_record() except Exception as e: # noqa: BLE001 - preview is best-effort UX - print(f"Preview error: {e}") + logger.warning("Preview error: %s", e) def _run_batches( From bf5d0be36b97711125e8e465410057ffceb87725 Mon Sep 17 00:00:00 2001 From: Steve Han <150830061+shan-nvidia@users.noreply.github.com> Date: Mon, 4 May 2026 12:13:03 -0400 Subject: [PATCH 5/6] Update 
plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/config.py Co-authored-by: Nabin Mulepati --- .../src/data_designer_retrieval_sdg/config.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/config.py b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/config.py index adaf186..30ac3c3 100644 --- a/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/config.py +++ b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/config.py @@ -19,7 +19,7 @@ class EmbeddingDedupColumnConfig(SingleColumnConfig): the source column is a wrapper dict (``data[source_column][items_key]``) or a bare list (``items_key=None``). - Args: + Attributes: source_column: Name of the upstream column containing the items to deduplicate. items_key: Key under ``source_column`` that holds the list of items. @@ -32,6 +32,9 @@ class EmbeddingDedupColumnConfig(SingleColumnConfig): column_type: Fixed literal identifying this column type. similarity_threshold: Cosine similarity threshold above which two items are considered duplicates. Defaults to ``0.9``. + Inherited Attributes: + name (required): Unique name of the column to be generated. + drop: If True, generate this column but remove it from the final dataset. """ source_column: str From ea96467108bd834e231fcc06f7709059abc8110c Mon Sep 17 00:00:00 2001 From: Steve Han Date: Tue, 5 May 2026 09:06:52 -0400 Subject: [PATCH 6/6] address review comments Signed-off-by: Steve Han --- .gitignore | 3 ++ .../data_designer_retrieval_sdg/chunking.py | 33 +++++++++---------- .../src/data_designer_retrieval_sdg/dedup.py | 6 ++-- .../tests/test_dedup.py | 4 +-- 4 files changed, 23 insertions(+), 23 deletions(-) diff --git a/.gitignore b/.gitignore index b2f39c4..74905e0 100644 --- a/.gitignore +++ b/.gitignore @@ -35,3 +35,6 @@ htmlcov/ # Distribution *.tar.gz + +# CI artifacts +*artifacts/ diff --git a/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/chunking.py b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/chunking.py index fff6ee7..58596ca 100644 --- a/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/chunking.py +++ b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/chunking.py @@ -21,6 +21,7 @@ from typing import Literal import nltk +import yaml from nltk.tokenize import sent_tokenize logger = logging.getLogger(__name__) @@ -40,8 +41,6 @@ def load_multi_doc_manifest(manifest_path: Path | None) -> list[list[str]]: Returns: List of bundles, each a list of file-path strings. 
""" - import yaml - if not manifest_path: return [] @@ -160,19 +159,19 @@ def build_bundles( return [b for b in bundles if b] -def group_chunks_by_doc(chunks: list[dict]) -> dict[str, list[tuple[int, dict]]]: +def group_chunks_by_doc(chunks: list[dict]) -> dict[str, list[dict]]: """Group chunks by their ``doc_id`` field.""" - grouped: dict[str, list[tuple[int, dict]]] = defaultdict(list) - for idx, chunk in enumerate(chunks): + grouped: dict[str, list[dict]] = defaultdict(list) + for chunk in chunks: doc_id = chunk.get("doc_id", "default") - grouped[doc_id].append((idx, chunk)) + grouped[doc_id].append(chunk) return dict(grouped) -def format_section_chunks(indexed_chunks: list[tuple[int, dict]], section_number: int) -> str: - """Render a list of indexed chunks into a section string.""" +def format_section_chunks(section_chunks: list[dict], section_number: int) -> str: + """Render a list of chunks into a section string.""" section_lines: list[str] = [] - for _, chunk in indexed_chunks: + for chunk in section_chunks: text = chunk.get("text", "").strip() if not text: continue @@ -203,8 +202,7 @@ def chunks_to_sections_sequential(chunks: list[dict], num_sections: int = 1) -> for i in range(num_sections): start_idx = i * section_size end_idx = (i + 1) * section_size if i < num_sections - 1 else total - indexed_chunks = [(j, chunks[j]) for j in range(start_idx, end_idx)] - section_text = format_section_chunks(indexed_chunks, i + 1) + section_text = format_section_chunks(chunks[start_idx:end_idx], i + 1) if section_text: formatted_sections.append(section_text) @@ -222,9 +220,9 @@ def chunks_to_sections_doc_balanced(chunks: list[dict], num_sections: int = 1) - chunk_sizes = {doc_id: max(1, math.ceil(len(entries) / num_sections)) for doc_id, entries in grouped.items()} - sections: list[list[tuple[int, dict]]] = [] + sections: list[list[dict]] = [] for part_idx in range(num_sections): - part_entries: list[tuple[int, dict]] = [] + part_entries: list[dict] = [] for doc_id, entries in grouped.items(): chunk_size = chunk_sizes[doc_id] start = part_idx * chunk_size @@ -235,8 +233,8 @@ def chunks_to_sections_doc_balanced(chunks: list[dict], num_sections: int = 1) - sections.append(part_entries) formatted_sections: list[str] = [] - for i, indexed_chunks in enumerate(sections): - section_text = format_section_chunks(indexed_chunks, i + 1) + for i, section_chunks in enumerate(sections): + section_text = format_section_chunks(section_chunks, i + 1) if section_text: formatted_sections.append(section_text) @@ -254,7 +252,7 @@ def chunks_to_sections_interleaved(chunks: list[dict], num_sections: int = 1) -> doc_iterators = {doc_id: deque(entries) for doc_id, entries in grouped.items()} doc_order = list(grouped.keys()) - interleaved: list[tuple[int, dict]] = [] + interleaved: list[dict] = [] while True: added = False @@ -276,8 +274,7 @@ def chunks_to_sections_interleaved(chunks: list[dict], num_sections: int = 1) -> for i in range(num_sections): start_idx = i * section_size end_idx = (i + 1) * section_size if i < num_sections - 1 else total - indexed_chunks = interleaved[start_idx:end_idx] - section_text = format_section_chunks(indexed_chunks, i + 1) + section_text = format_section_chunks(interleaved[start_idx:end_idx], i + 1) if section_text: formatted_sections.append(section_text) diff --git a/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/dedup.py b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/dedup.py index da99556..eca5dc1 100644 --- 
a/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/dedup.py +++ b/plugins/data-designer-retrieval-sdg/src/data_designer_retrieval_sdg/dedup.py @@ -15,12 +15,12 @@ from typing import Any import numpy as np +from data_designer.config.errors import BuilderConfigurationError from data_designer.config.models import GenerationType from data_designer.engine.column_generators.generators.base import ( ColumnGeneratorWithModelRegistry, GenerationStrategy, ) -from data_designer.engine.dataset_builders.errors import DatasetGenerationError from data_designer.engine.models.facade import ModelFacade from data_designer_retrieval_sdg.config import EmbeddingDedupColumnConfig @@ -65,14 +65,14 @@ def _validate(self) -> None: from the facade or a 400 from the embeddings endpoint. Raises: - DatasetGenerationError: When ``self.config.model_alias`` resolves + BuilderConfigurationError: When ``self.config.model_alias`` resolves to a :class:`ModelConfig` whose inference parameters are not ``EmbeddingInferenceParams``. """ super()._validate() model_config = self.get_model_config(model_alias=self.config.model_alias) if model_config.generation_type != GenerationType.EMBEDDING: - raise DatasetGenerationError( + raise BuilderConfigurationError( f"EmbeddingDedupColumnGenerator requires an embedding model, " f"but model alias {self.config.model_alias!r} resolves to a " f"{model_config.generation_type.value!r} model. Configure a " diff --git a/plugins/data-designer-retrieval-sdg/tests/test_dedup.py b/plugins/data-designer-retrieval-sdg/tests/test_dedup.py index 11fb66e..ef9c329 100644 --- a/plugins/data-designer-retrieval-sdg/tests/test_dedup.py +++ b/plugins/data-designer-retrieval-sdg/tests/test_dedup.py @@ -7,12 +7,12 @@ from unittest.mock import AsyncMock, MagicMock import pytest +from data_designer.config.errors import BuilderConfigurationError from data_designer.config.models import ( ChatCompletionInferenceParams, EmbeddingInferenceParams, ModelConfig, ) -from data_designer.engine.dataset_builders.errors import DatasetGenerationError from data_designer_retrieval_sdg.config import EmbeddingDedupColumnConfig from data_designer_retrieval_sdg.dedup import EmbeddingDedupColumnGenerator @@ -189,7 +189,7 @@ def test_validate_rejects_chat_model() -> None: model="some/chat-model", inference_parameters=ChatCompletionInferenceParams(), ) - with pytest.raises(DatasetGenerationError, match="embed"): + with pytest.raises(BuilderConfigurationError, match="embed"): gen._validate()
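The behavior these dedup tests pin down is a greedy cosine-similarity pass over each row's embeddings: keep an item only if it is not too similar to anything already kept. A standalone sketch of that pass, using a hypothetical greedy_cosine_dedup helper (illustrative only, not the plugin's dedupe_indices; zero-length vectors are guarded with a small epsilon):

    import numpy as np

    def greedy_cosine_dedup(embeddings: list[list[float]], threshold: float = 0.9) -> list[int]:
        """Return indices of items kept after greedy near-duplicate removal."""
        if not embeddings:
            return []
        vectors = np.asarray(embeddings, dtype=float)
        # Normalize rows so the dot product of two rows is their cosine similarity.
        vectors /= np.clip(np.linalg.norm(vectors, axis=1, keepdims=True), 1e-12, None)
        kept: list[int] = []
        for idx, vector in enumerate(vectors):
            # Drop the item when it exceeds the threshold against any survivor.
            if all(float(vector @ vectors[k]) <= threshold for k in kept):
                kept.append(idx)
        return kept

    # Mirrors test_dedupe_indices_identical_vectors: the duplicate of index 0 is dropped.
    assert greedy_cosine_dedup([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]]) == [0, 2]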