diff --git a/Cargo.lock b/Cargo.lock index 14da56f..5e532d3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -523,6 +523,7 @@ dependencies = [ "rayon", "serde", "serde_json", + "tempfile", "tokio", ] diff --git a/Cargo.toml b/Cargo.toml index 72aa921..804e8a2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,6 +47,7 @@ extension-module = ["pyo3/extension-module"] bytes = "1" criterion = "0.8" proptest = "1.9" +tempfile = "3" [[bench]] name = "generators" diff --git a/README.md b/README.md index ac3f5ea..caeb661 100644 --- a/README.md +++ b/README.md @@ -673,6 +673,73 @@ sql = records_sql(1000, { For large batches, multiple INSERT statements are generated with up to 1000 rows each. Column names are double-quoted and string values use single-quote escaping. +### Streaming File Writer + +For datasets that exceed available memory, `records_to_file()` generates records +in bounded-memory chunks and writes each chunk to disk before generating the next. +Memory usage is proportional to `chunk_size`, not total `n`. + +```python +from forgery import Faker + +fake = Faker() +fake.seed(42) + +# Generate 100 million records — memory stays at ~500-800 MB +fake.records_to_file( + 100_000_000, + {"id": "uuid", "name": "name", "amount": ("float", 0.01, 9999.99)}, + "transactions.parquet", + chunk_size=1_000_000, # records per chunk (default: 1M, max: 10M) +) +``` + +**Supported formats:** CSV (`.csv`), NDJSON (`.ndjson`/`.jsonl`), SQL (`.sql`), +Parquet (`.parquet`). Format is auto-detected from the file extension, or set +explicitly with `format="csv"`. + +SQL format requires a `table` parameter: + +```python +from forgery import records_to_file, seed + +seed(42) +records_to_file( + 50_000_000, + {"name": "name", "email": "email"}, + "users.sql", + table="users", + chunk_size=500_000, +) +``` + +**Progress callback** — track progress with an optional callback: + +```python +from forgery import records_to_file, seed + +seed(42) +records_to_file( + 10_000_000, + {"name": "name", "email": "email"}, + "users.csv", + on_progress=lambda written, total: print(f"\r{written/total:.0%}", end=""), +) +``` + +**Memory estimation** — plan chunk sizes based on available RAM: + +```python +from forgery import Faker + +schema = {"id": "uuid", "name": "name", "amount": ("float", 0.01, 9999.99)} +est = Faker.estimate_memory(1_000_000, schema) +print(f"~{est / 1024**2:.0f} MB per 1M records") +``` + +All streaming formats use row-major generation, so the same seed produces +identical data across CSV, NDJSON, SQL, and Parquet output. 
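+A quick way to check the row-major guarantee end to end (a minimal, illustrative
+sketch using only the API shown above; nothing here is new library surface):
+
+```python
+from pathlib import Path
+
+from forgery import records_to_file, seed
+
+schema = {"id": "uuid", "name": "name"}
+
+# Same seed, different chunk sizes: the output files are byte-identical,
+# because chunking only bounds memory and never changes the RNG stream.
+seed(42)
+records_to_file(1_000, schema, "a.ndjson", chunk_size=100)
+seed(42)
+records_to_file(1_000, schema, "b.ndjson", chunk_size=37)
+assert Path("a.ndjson").read_text() == Path("b.ndjson").read_text()
+```
+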
+ ### Schema Field Types | Type | Syntax | Example | diff --git a/python/forgery/__init__.py b/python/forgery/__init__.py index 7cfb396..2085a79 100644 --- a/python/forgery/__init__.py +++ b/python/forgery/__init__.py @@ -57,7 +57,7 @@ >>> german_fake.names(10) # German names """ -from collections.abc import Coroutine +from collections.abc import Callable, Coroutine from typing import TYPE_CHECKING, Any from forgery._forgery import Faker @@ -129,6 +129,7 @@ "ean13s", "email", "emails", + "estimate_memory", "fake", "file_extension", "file_extensions", @@ -225,6 +226,7 @@ "records_ndjson", "records_parquet", "records_sql", + "records_to_file", "records_tuples", "records_tuples_async", "remove_provider", @@ -1491,6 +1493,80 @@ def records_sql(n: int, schema: Schema, table: str) -> str: return fake.records_sql(n, schema, table) +# === Streaming File Writer === + + +def records_to_file( + n: int, + schema: Schema, + path: str, + *, + format: str | None = None, + chunk_size: int | None = None, + table: str | None = None, + on_progress: Callable[[int, int], None] | None = None, +) -> int: + """Generate records and stream them to a file in chunks. + + Memory stays bounded by chunk_size regardless of total n, enabling + generation of datasets far larger than available RAM. + + Supported formats: csv, ndjson, sql, parquet (auto-detected from + file extension, or specified explicitly). + + Args: + n: Total number of records to generate. + schema: Dictionary mapping field names to type specifications. + path: Output file path. + format: Output format ("csv", "ndjson", "sql", "parquet") or + None to auto-detect from the file extension. + chunk_size: Records per chunk (default: 1,000,000). Max: 10,000,000. + table: Table name (required for SQL format, ignored otherwise). + on_progress: Optional callback(records_written, total) called + after each chunk is written. + + Returns: + The total number of records written. + + Example: + >>> from forgery import records_to_file, seed + >>> seed(42) + >>> records_to_file(100, {"name": "name", "email": "email"}, "/tmp/test.csv") + 100 + """ + return fake.records_to_file( + n, + schema, + path, + format=format, + chunk_size=chunk_size, + table=table, + on_progress=on_progress, + ) + + +def estimate_memory(n: int, schema: Schema) -> int: + """Estimate memory usage in bytes for generating n records. + + Provides a rough estimate based on average field sizes. Useful for + deciding chunk_size for records_to_file(). + + Args: + n: Number of records. + schema: Dictionary mapping field names to type specifications. + + Returns: + Estimated memory in bytes. + + Example: + >>> from forgery import estimate_memory + >>> est = estimate_memory(1_000_000, {"name": "name", "email": "email"}) + >>> est > 0 + True + """ + return Faker.estimate_memory(n, schema) + + # === Async Records Generation === diff --git a/python/forgery/__init__.pyi b/python/forgery/__init__.pyi index 67b5175..6014273 100644 --- a/python/forgery/__init__.pyi +++ b/python/forgery/__init__.pyi @@ -1,6 +1,6 @@ """Type stubs for the forgery package.""" -from collections.abc import Coroutine +from collections.abc import Callable, Coroutine from typing import Any from forgery._forgery import CreditCardFull as CreditCardFull @@ -752,6 +752,52 @@ def records_sql(n: int, schema: Schema, table: str) -> str: """ ... 
+# Streaming file writer + +def records_to_file( + n: int, + schema: Schema, + path: str, + *, + format: str | None = None, + chunk_size: int | None = None, + table: str | None = None, + on_progress: Callable[[int, int], None] | None = None, +) -> int: + """Generate records and stream them to a file in chunks. + + Memory stays bounded by chunk_size regardless of total n. + + Args: + n: Total number of records to generate. + schema: Dictionary mapping field names to type specifications. + path: Output file path. + format: Output format ("csv", "ndjson", "sql", "parquet") or None. + chunk_size: Records per chunk (default: 1,000,000). + table: Table name (required for SQL format). + on_progress: Optional callback(records_written, total). + + Returns: + The total number of records written. + + Raises: + ValueError: If schema is invalid or format is unsupported. + OSError: If file cannot be created. + """ + ... + +def estimate_memory(n: int, schema: Schema) -> int: + """Estimate memory usage in bytes for generating n records. + + Args: + n: Number of records. + schema: Dictionary mapping field names to type specifications. + + Returns: + Estimated memory in bytes. + """ + ... + # Async Records generation def records_async( diff --git a/python/forgery/_forgery.pyi b/python/forgery/_forgery.pyi index 28c62cb..06dec33 100644 --- a/python/forgery/_forgery.pyi +++ b/python/forgery/_forgery.pyi @@ -1,7 +1,7 @@ """Type stubs for the Rust extension module.""" import builtins -from collections.abc import Coroutine +from collections.abc import Callable, Coroutine from typing import Any, TypedDict class CreditCardFull(TypedDict): @@ -1095,6 +1095,75 @@ class Faker: """ ... + # Serialized output formats + def records_csv(self, n: int, schema: Schema) -> str: + """Generate records as a CSV string with header row.""" + ... + + def records_json(self, n: int, schema: Schema) -> str: + """Generate records as a JSON array string.""" + ... + + def records_ndjson(self, n: int, schema: Schema) -> str: + """Generate records as newline-delimited JSON.""" + ... + + def records_parquet(self, n: int, schema: Schema) -> bytes: + """Generate records as Parquet file bytes.""" + ... + + def records_sql(self, n: int, schema: Schema, table: str) -> str: + """Generate records as SQL INSERT statements.""" + ... + + # Streaming file writer + def records_to_file( + self, + n: int, + schema: Schema, + path: str, + format: str | None = None, + chunk_size: int | None = None, + table: str | None = None, + on_progress: Callable[[int, int], None] | None = None, + ) -> int: + """Generate records and stream them to a file in chunks. + + Memory stays bounded by chunk_size regardless of total n. + Supports: csv, ndjson, sql, parquet (auto-detected from file extension). + + Args: + n: Total number of records to generate. + schema: Dictionary mapping field names to type specifications. + path: Output file path. + format: Output format or None for auto-detect from extension. + chunk_size: Records per chunk (default: 1,000,000). + table: Table name (required for SQL format, ignored otherwise). + on_progress: Optional callback(records_written, total) after each chunk. + + Returns: + The total number of records written. + + Raises: + ValueError: If schema is invalid, format is unsupported, or SQL + format is used without a table name. + OSError: If the file cannot be created or written. + """ + ... + + @staticmethod + def estimate_memory(n: int, schema: Schema) -> int: + """Estimate memory usage in bytes for generating n records. 
+ + Args: + n: Number of records. + schema: Schema dictionary. + + Returns: + Estimated memory in bytes. + """ + ... + # Async records generators def records_async( self, n: int, schema: Schema, chunk_size: int | None = None ) -> ... diff --git a/src/lib.rs b/src/lib.rs index 18cae74..12cc100 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2153,6 +2153,53 @@ impl Faker { )?) } + // === Streaming File Writer === + + /// Generate records and write them to a file in chunks. + /// + /// Memory usage is bounded by `chunk_size` regardless of total `n`. + /// The `MAX_BATCH_SIZE` limit applies to `chunk_size`, not to `n`. + /// + /// # Errors + /// + /// Returns an error if the schema is invalid, the file cannot be created, + /// or a write operation fails. + #[allow(clippy::too_many_arguments)] + pub fn records_to_file( + &mut self, + n: usize, + schema: &BTreeMap<String, FieldSpec>, + path: &std::path::Path, + format: providers::file_writer::OutputFormat, + chunk_size: Option<usize>, + table: Option<&str>, + progress_callback: Option<&dyn Fn(usize, usize) -> Result<(), String>>, + ) -> Result<usize, providers::file_writer::FileWriteError> { + let cs = chunk_size.unwrap_or(providers::file_writer::DEFAULT_FILE_CHUNK_SIZE); + providers::file_writer::records_to_file( + &mut self.rng, + self.locale, + n, + schema, + &self.custom_providers, + path, + format, + cs, + table, + progress_callback, + ) + } + + /// Estimate memory usage in bytes for generating `n` records with a schema. + /// + /// Useful for deciding `chunk_size` for `records_to_file()`. + pub fn estimate_memory( + n: usize, + schema: &BTreeMap<String, FieldSpec>, + ) -> usize { + providers::file_writer::estimate_memory(n, schema) + } + // === Custom Providers === /// Register a custom provider with uniform random selection. @@ -4302,6 +4349,139 @@ impl Faker { .map_err(|e| PyValueError::new_err(e.to_string())) } + // ============================================================================ + // Streaming File Writer + // ============================================================================ + + /// Generate records and stream them to a file in chunks. + /// + /// Memory stays bounded by chunk_size regardless of total n. + /// Supports: csv, ndjson, sql, parquet (auto-detected from file extension).
+ /// + /// Example: + /// ```python + /// from forgery import Faker + /// f = Faker() + /// f.seed(42) + /// f.records_to_file( + /// 100_000_000, + /// {"name": "name", "email": "email"}, + /// "users.csv", + /// chunk_size=1_000_000, + /// ) + /// ``` + #[allow(clippy::too_many_arguments)] + #[pyo3(name = "records_to_file", signature = (n, schema, path, format=None, chunk_size=None, table=None, on_progress=None))] + fn py_records_to_file( + &mut self, + py: Python<'_>, + n: usize, + schema: &Bound<'_, PyDict>, + path: &str, + format: Option<&str>, + chunk_size: Option<usize>, + table: Option<&str>, + on_progress: Option<&Bound<'_, PyAny>>, + ) -> PyResult<usize> { + let custom_names = self.custom_provider_names(); + let rust_schema = parse_py_schema_with_custom(schema, &custom_names)?; + + // Determine format + let output_format = match format { + Some(f) => providers::file_writer::OutputFormat::from_name(f) + .map_err(|e| PyValueError::new_err(e.to_string()))?, + None => providers::file_writer::OutputFormat::from_extension(path) + .map_err(|e| PyValueError::new_err(e.to_string()))?, + }; + + // Validate chunk_size (not n — the whole point is n can exceed MAX_BATCH_SIZE) + let cs = chunk_size.unwrap_or(providers::file_writer::DEFAULT_FILE_CHUNK_SIZE); + validate_batch_size(cs).map_err(|e| PyValueError::new_err(e.to_string()))?; + + let file_path = std::path::PathBuf::from(path); + + // Build progress callback that calls the Python callable. + // GIL is held throughout, so we can call the Python callback directly. + // + // To preserve the original Python exception type and traceback, we stash + // the PyErr in a RefCell when the callback fails. The Rust side sees a + // generic error string to abort the loop, and we re-raise the original + // PyErr after records_to_file returns. + use std::cell::RefCell; + let cb_ref = on_progress.map(|cb| cb.clone().unbind()); + let stashed_pyerr: RefCell<Option<PyErr>> = RefCell::new(None); + let progress_closure = |written: usize, total: usize| -> Result<(), String> { + cb_ref + .as_ref() + .expect("progress_closure called without callback") + .call1(py, (written, total)) + .map(|_| ()) + .map_err(|e| { + let msg = e.to_string(); + *stashed_pyerr.borrow_mut() = Some(e); + msg + }) + }; + let progress_arg: Option<&dyn Fn(usize, usize) -> Result<(), String>> = if cb_ref.is_some() + { + Some(&progress_closure) + } else { + None + }; + + let result = providers::file_writer::records_to_file( + &mut self.rng, + self.locale, + n, + &rust_schema, + &self.custom_providers, + &file_path, + output_format, + cs, + table, + progress_arg, + ); + + // If the callback stashed a PyErr, re-raise it with the original type + if let Some(pyerr) = stashed_pyerr.into_inner() { + return Err(pyerr); + } + + result.map_err(|e| match e { + providers::file_writer::FileWriteError::Io(ref io_err) => { + let kind = io_err.kind(); + match kind { + std::io::ErrorKind::NotFound => { + pyo3::exceptions::PyFileNotFoundError::new_err(e.to_string()) + } + std::io::ErrorKind::PermissionDenied => { + pyo3::exceptions::PyPermissionError::new_err(e.to_string()) + } + _ => pyo3::exceptions::PyOSError::new_err(e.to_string()), + } + } + _ => PyValueError::new_err(e.to_string()), + }) + } + + /// Estimate memory usage in bytes for generating n records with a schema. + /// + /// Useful for deciding chunk_size for records_to_file().
+ /// + /// Example: + /// ```python + /// from forgery import Faker + /// est = Faker.estimate_memory(1_000_000, {"name": "name", "email": "email"}) + /// print(f"~{est / 1024 / 1024:.0f} MB") + /// ``` + #[pyo3(name = "estimate_memory")] + #[staticmethod] + fn py_estimate_memory(n: usize, schema: &Bound<'_, PyDict>) -> PyResult<usize> { + let custom_names = HashSet::new(); + let rust_schema = parse_py_schema_with_custom(schema, &custom_names)?; + Ok(providers::file_writer::estimate_memory(n, &rust_schema)) + } + // ============================================================================ // Async Methods // ============================================================================ diff --git a/src/providers/file_writer.rs b/src/providers/file_writer.rs new file mode 100644 index 0000000..0135c43 --- /dev/null +++ b/src/providers/file_writer.rs @@ -0,0 +1,1208 @@ +//! Streaming file writer for bounded-memory record generation. +//! +//! Generates records in chunks and writes each chunk to a file before +//! generating the next. Memory stays bounded by `chunk_size` regardless +//! of the total number of records. +//! +//! Supports CSV, NDJSON, SQL, and Parquet output formats. + +use std::collections::{BTreeMap, HashMap}; +use std::fmt; +use std::fs::File; +use std::io::{BufWriter, Write}; +use std::path::Path; +use std::sync::Arc; + +use arrow_schema::{Field, Schema}; + +use crate::locale::Locale; +use crate::providers::custom::CustomProvider; +use crate::providers::records::{ + field_spec_to_arrow_type, generate_records_with_custom, records_to_record_batch, FieldSpec, + SchemaError, Value, +}; +use crate::providers::serialize::{ + record_to_json_object, sql_quote_identifier, value_to_sql, SQL_BATCH_SIZE, +}; +use crate::rng::ForgeryRng; + +/// Default chunk size for streaming file generation (1 million records). +pub const DEFAULT_FILE_CHUNK_SIZE: usize = 1_000_000; + +// ============================================================================ +// Output format +// ============================================================================ + +/// Supported streaming output formats. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum OutputFormat { + /// CSV with RFC 4180 quoting. + Csv, + /// Newline-delimited JSON (one JSON object per line). + Ndjson, + /// ANSI SQL INSERT statements. + Sql, + /// Apache Parquet (row-major data written as row groups). + Parquet, +} + +impl OutputFormat { + /// Detect format from a file extension. + /// + /// # Errors + /// + /// Returns an error for unsupported extensions (including `.json`). + pub fn from_extension(path: &str) -> Result<Self, FileWriteError> { + let ext = Path::new(path) + .extension() + .and_then(|e| e.to_str()) + .unwrap_or("") + .to_lowercase(); + + match ext.as_str() { + "csv" => Ok(OutputFormat::Csv), + "ndjson" | "jsonl" => Ok(OutputFormat::Ndjson), + "sql" => Ok(OutputFormat::Sql), + "parquet" => Ok(OutputFormat::Parquet), + "json" => Err(FileWriteError::Config( + "JSON array format is not streaming-friendly; use .ndjson instead".to_string(), + )), + _ => Err(FileWriteError::Config(format!( + "cannot auto-detect format from extension '.{}'; \ + use format='csv'|'ndjson'|'sql'|'parquet'", + ext + ))), + } + } + + /// Parse a format name string. + /// + /// # Errors + /// + /// Returns an error for unsupported format names.
+ pub fn from_name(name: &str) -> Result<Self, FileWriteError> { + match name.to_lowercase().as_str() { + "csv" => Ok(OutputFormat::Csv), + "ndjson" | "jsonl" => Ok(OutputFormat::Ndjson), + "sql" => Ok(OutputFormat::Sql), + "parquet" => Ok(OutputFormat::Parquet), + "json" => Err(FileWriteError::Config( + "JSON array format is not streaming-friendly; use 'ndjson' instead".to_string(), + )), + _ => Err(FileWriteError::Config(format!( + "unsupported format '{}'; use 'csv', 'ndjson', 'sql', or 'parquet'", + name + ))), + } + } +} + +// ============================================================================ +// Error type +// ============================================================================ + +/// Error type for file writing operations. +#[derive(Debug)] +pub enum FileWriteError { + /// Schema validation error. + Schema(SchemaError), + /// File I/O error. + Io(std::io::Error), + /// Format-specific error (e.g., Parquet serialization). + Format(String), + /// Invalid configuration (e.g., missing table name for SQL). + Config(String), +} + +impl fmt::Display for FileWriteError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + FileWriteError::Schema(e) => write!(f, "{}", e), + FileWriteError::Io(e) => write!(f, "I/O error: {}", e), + FileWriteError::Format(msg) => write!(f, "format error: {}", msg), + FileWriteError::Config(msg) => write!(f, "configuration error: {}", msg), + } + } +} + +impl std::error::Error for FileWriteError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + FileWriteError::Schema(e) => Some(e), + FileWriteError::Io(e) => Some(e), + _ => None, + } + } +} + +impl From<SchemaError> for FileWriteError { + fn from(err: SchemaError) -> Self { + FileWriteError::Schema(err) + } +} + +impl From<std::io::Error> for FileWriteError { + fn from(err: std::io::Error) -> Self { + FileWriteError::Io(err) + } +} + +// ============================================================================ +// Chunk writers +// ============================================================================ + +/// Trait for format-specific chunk writers. +trait ChunkWriter { + /// Write the file header (e.g., CSV header row). Called once before any chunks. + fn write_header(&mut self) -> Result<(), FileWriteError>; + + /// Write a chunk of records. + fn write_chunk(&mut self, records: &[BTreeMap<String, Value>]) -> Result<(), FileWriteError>; + + /// Finalize the file. Called once after all chunks. + fn finish(&mut self) -> Result<(), FileWriteError>; +} + +// --- CSV --- + +/// Writes CSV chunks with RFC 4180 quoting via the `csv` crate. +struct CsvChunkWriter { + wtr: csv::Writer<BufWriter<File>>, + keys: Vec<String>, +} + +impl CsvChunkWriter { + fn new(file: File, keys: Vec<String>) -> Self { + Self { + wtr: csv::Writer::from_writer(BufWriter::new(file)), + keys, + } + } +} + +impl ChunkWriter for CsvChunkWriter { + fn write_header(&mut self) -> Result<(), FileWriteError> { + self.wtr + .write_record(&self.keys) + .map_err(|e| FileWriteError::Format(e.to_string())) + } + + fn write_chunk(&mut self, records: &[BTreeMap<String, Value>]) -> Result<(), FileWriteError> { + for record in records { + let row: Vec<String> = self.keys.iter().map(|k| record[k].as_string()).collect(); + self.wtr + .write_record(&row) + .map_err(|e| FileWriteError::Format(e.to_string()))?; + } + Ok(()) + } + + fn finish(&mut self) -> Result<(), FileWriteError> { + self.wtr + .flush() + .map_err(|e| FileWriteError::Format(e.to_string())) + } +} + +// --- NDJSON --- + +/// Writes one JSON object per line.
+struct NdjsonChunkWriter { + writer: BufWriter<File>, +} + +impl NdjsonChunkWriter { + fn new(file: File) -> Self { + Self { + writer: BufWriter::new(file), + } + } +} + +impl ChunkWriter for NdjsonChunkWriter { + fn write_header(&mut self) -> Result<(), FileWriteError> { + Ok(()) // no header + } + + fn write_chunk(&mut self, records: &[BTreeMap<String, Value>]) -> Result<(), FileWriteError> { + for record in records { + let json = serde_json::to_string(&record_to_json_object(record)) + .map_err(|e| FileWriteError::Format(e.to_string()))?; + self.writer.write_all(json.as_bytes())?; + self.writer.write_all(b"\n")?; + } + Ok(()) + } + + fn finish(&mut self) -> Result<(), FileWriteError> { + self.writer.flush()?; + Ok(()) + } +} + +// --- SQL --- + +/// Writes ANSI SQL INSERT statements, batched at `SQL_BATCH_SIZE` rows. +struct SqlChunkWriter { + writer: BufWriter<File>, + quoted_table: String, + columns: String, + keys: Vec<String>, + /// Whether we've written at least one INSERT statement. + has_content: bool, +} + +impl SqlChunkWriter { + fn new(file: File, table: &str, keys: Vec<String>) -> Self { + let columns = keys + .iter() + .map(|k| sql_quote_identifier(k)) + .collect::<Vec<_>>() + .join(", "); + Self { + writer: BufWriter::new(file), + quoted_table: sql_quote_identifier(table), + columns, + keys, + has_content: false, + } + } +} + +impl ChunkWriter for SqlChunkWriter { + fn write_header(&mut self) -> Result<(), FileWriteError> { + Ok(()) // no header + } + + fn write_chunk(&mut self, records: &[BTreeMap<String, Value>]) -> Result<(), FileWriteError> { + if records.is_empty() { + return Ok(()); + } + + for sql_batch in records.chunks(SQL_BATCH_SIZE) { + if self.has_content { + self.writer.write_all(b"\n")?; + } + self.has_content = true; + + writeln!( + self.writer, + "INSERT INTO {} ({}) VALUES", + self.quoted_table, self.columns + )?; + + for (i, record) in sql_batch.iter().enumerate() { + let values = self + .keys + .iter() + .map(|k| value_to_sql(&record[k])) + .collect::<Vec<_>>() + .join(", "); + + if i < sql_batch.len() - 1 { + writeln!(self.writer, "({values}),")?; + } else { + writeln!(self.writer, "({values});")?; + } + } + } + Ok(()) + } + + fn finish(&mut self) -> Result<(), FileWriteError> { + self.writer.flush()?; + Ok(()) + } +} + +// --- Parquet --- + +/// Writes Parquet row groups from row-major records converted to Arrow. +struct ParquetChunkWriter { + /// Wrapped in Option so `finish()` can take ownership via `.take()` to call `close()`.
+ writer: Option<parquet::arrow::ArrowWriter<BufWriter<File>>>, + schema: BTreeMap<String, FieldSpec>, +} + +impl ParquetChunkWriter { + fn new(file: File, schema: &BTreeMap<String, FieldSpec>) -> Result<Self, FileWriteError> { + // Build Arrow schema + let arrow_fields: Vec<Field> = schema + .iter() + .map(|(name, spec)| Field::new(name, field_spec_to_arrow_type(spec), false)) + .collect(); + let arrow_schema = Arc::new(Schema::new(arrow_fields)); + + let props = parquet::file::properties::WriterProperties::builder().build(); + let writer = + parquet::arrow::ArrowWriter::try_new(BufWriter::new(file), arrow_schema, Some(props)) + .map_err(|e| FileWriteError::Format(e.to_string()))?; + + Ok(Self { + writer: Some(writer), + schema: schema.clone(), + }) + } +} + +impl ChunkWriter for ParquetChunkWriter { + fn write_header(&mut self) -> Result<(), FileWriteError> { + Ok(()) // schema set at construction + } + + fn write_chunk(&mut self, records: &[BTreeMap<String, Value>]) -> Result<(), FileWriteError> { + if records.is_empty() { + return Ok(()); + } + let batch = records_to_record_batch(records, &self.schema)?; + self.writer + .as_mut() + .expect("write_chunk called after finish") + .write(&batch) + .map_err(|e| FileWriteError::Format(e.to_string()))?; + Ok(()) + } + + fn finish(&mut self) -> Result<(), FileWriteError> { + // ArrowWriter::close() takes self by value and writes the Parquet footer. + // We take it out of the Option to transfer ownership. + if let Some(writer) = self.writer.take() { + writer + .close() + .map_err(|e| FileWriteError::Format(e.to_string()))?; + } + Ok(()) + } +} + +// ============================================================================ +// Main orchestration +// ============================================================================ + +/// Write generated records to a file in chunks, keeping memory bounded. +/// +/// Generates `n` records total in chunks of `chunk_size`, writing each chunk +/// to disk before generating the next. Memory usage is proportional to +/// `chunk_size`, not `n`. +/// +/// # Arguments +/// +/// * `rng` - The random number generator +/// * `locale` - The locale for locale-aware generation +/// * `n` - Total number of records to generate +/// * `schema` - The schema specification +/// * `custom_providers` - Custom provider definitions +/// * `path` - The output file path +/// * `format` - The output format +/// * `chunk_size` - Number of records per chunk +/// * `table` - Table name for SQL format (required for SQL, ignored otherwise) +/// * `progress_callback` - Optional callback invoked with (records_written, total). +/// Return `Ok(())` to continue, or `Err(message)` to abort the write. +/// +/// # Returns +/// +/// The total number of records written. +/// +/// # Errors +/// +/// Returns an error if the schema is invalid, the file cannot be created, +/// a write operation fails, or the progress callback returns an error.
+#[allow(clippy::too_many_arguments)] +pub fn records_to_file( + rng: &mut ForgeryRng, + locale: Locale, + n: usize, + schema: &BTreeMap<String, FieldSpec>, + custom_providers: &HashMap<String, CustomProvider>, + path: &Path, + format: OutputFormat, + chunk_size: usize, + table: Option<&str>, + progress_callback: Option<&dyn Fn(usize, usize) -> Result<(), String>>, +) -> Result<usize, FileWriteError> { + // Validate chunk_size > 0 to prevent infinite loop + if chunk_size == 0 { + return Err(FileWriteError::Config( + "chunk_size must be greater than 0".to_string(), + )); + } + + // Validate SQL requires table name + if format == OutputFormat::Sql { + match table { + None => { + return Err(FileWriteError::Config( + "table name is required for SQL format".to_string(), + )) + } + Some("") => { + return Err(FileWriteError::Config( + "table name must not be empty".to_string(), + )) + } + _ => {} + } + } + + // Validate schema upfront (before creating file) so we fail fast + // and don't truncate/create an empty file on invalid input + if schema.is_empty() && format == OutputFormat::Sql { + return Err(FileWriteError::Schema(SchemaError { + message: "schema must have at least one field for SQL output".to_string(), + })); + } + crate::providers::records::validate_schema_with_custom(schema, custom_providers)?; + + // Get sorted field names (BTreeMap order) + let keys: Vec<String> = schema.keys().cloned().collect(); + + // Create the file + let file = File::create(path)?; + + // Create format-specific writer + let mut writer: Box<dyn ChunkWriter> = match format { + OutputFormat::Csv => Box::new(CsvChunkWriter::new(file, keys)), + OutputFormat::Ndjson => Box::new(NdjsonChunkWriter::new(file)), + OutputFormat::Sql => Box::new(SqlChunkWriter::new(file, table.unwrap_or("data"), keys)), + OutputFormat::Parquet => Box::new(ParquetChunkWriter::new(file, schema)?), + }; + + // Write header + writer.write_header()?; + + // Chunk loop + let mut remaining = n; + let mut records_written: usize = 0; + + while remaining > 0 { + let this_chunk = remaining.min(chunk_size); + + let records = + generate_records_with_custom(rng, locale, this_chunk, schema, custom_providers)?; + + writer.write_chunk(&records)?; + + records_written += this_chunk; + remaining -= this_chunk; + + // Drop records before callback to free memory + drop(records); + + if let Some(cb) = progress_callback { + cb(records_written, n).map_err(FileWriteError::Config)?; + } + } + + writer.finish()?; + + Ok(records_written) +} + +// ============================================================================ +// Memory estimation +// ============================================================================ + +/// Estimate memory usage in bytes for generating `n` records with the given schema. +/// +/// Provides a rough estimate based on average field sizes. Actual usage may +/// vary by +/- 30% depending on generated values (string lengths, etc.). +/// +/// This is useful for deciding `chunk_size` for `records_to_file()`.
+pub fn estimate_memory(n: usize, schema: &BTreeMap<String, FieldSpec>) -> usize { + // Per-record overhead: BTreeMap struct + let btree_base: usize = 24; + + let mut per_record: usize = btree_base; + + for (name, spec) in schema { + // BTreeMap node overhead per entry + let node_overhead: usize = 48; + // Key: String header (24 bytes) + heap allocation + let key_size: usize = 24 + name.len(); + // Value: enum discriminant + data + let value_size: usize = estimate_value_size(spec); + + per_record += node_overhead + key_size + value_size; + } + + // Vec overhead (ptr + len + cap) + let vec_overhead: usize = 24; + + vec_overhead + n.saturating_mul(per_record) +} + +/// Estimate the in-memory size of a single `Value` for the given field spec. +fn estimate_value_size(spec: &FieldSpec) -> usize { + // Value enum is ~32 bytes (largest variant + discriminant) + let enum_size: usize = 32; + + // Additional heap allocation for String variants + let heap_size: usize = match spec { + FieldSpec::Name => 15, + FieldSpec::FirstName | FieldSpec::LastName => 10, + FieldSpec::Email | FieldSpec::SafeEmail | FieldSpec::FreeEmail => 25, + FieldSpec::Uuid => 36, + FieldSpec::Int | FieldSpec::IntRange { .. } => 0, + FieldSpec::Float | FieldSpec::FloatRange { .. } => 0, + FieldSpec::Boolean => 0, + FieldSpec::Phone => 14, + FieldSpec::Address | FieldSpec::StreetAddress => 40, + FieldSpec::City | FieldSpec::State | FieldSpec::Country => 12, + FieldSpec::ZipCode => 5, + FieldSpec::Url | FieldSpec::DomainName => 30, + FieldSpec::Ipv4 => 15, + FieldSpec::Ipv6 => 39, + FieldSpec::MacAddress => 17, + FieldSpec::CreditCard | FieldSpec::Iban => 25, + FieldSpec::Sentence => 80, + FieldSpec::Paragraph => 500, + FieldSpec::Text { + min_chars, + max_chars, + } => (min_chars + max_chars) / 2, + FieldSpec::Md5 => 32, + FieldSpec::Sha256 => 64, + FieldSpec::Date | FieldSpec::DateTime | FieldSpec::DateRange { ..
} => 20, + FieldSpec::Company | FieldSpec::Job | FieldSpec::CatchPhrase => 20, + FieldSpec::Color | FieldSpec::HexColor => 10, + FieldSpec::RgbColor => 0, + FieldSpec::Latitude | FieldSpec::Longitude => 0, + FieldSpec::Coordinate => 0, + FieldSpec::Ssn => 11, + FieldSpec::FileName | FieldSpec::FilePath => 20, + FieldSpec::FileExtension | FieldSpec::MimeType => 10, + FieldSpec::LicensePlate => 8, + FieldSpec::VehicleMake | FieldSpec::VehicleModel => 12, + FieldSpec::VehicleYear => 0, + FieldSpec::Vin => 17, + FieldSpec::Ean13 | FieldSpec::Ean8 | FieldSpec::UpcA | FieldSpec::UpcE => 13, + FieldSpec::Isbn10 | FieldSpec::Isbn13 => 17, + FieldSpec::ProductName | FieldSpec::ProductCategory | FieldSpec::Department => 15, + FieldSpec::ProductMaterial => 10, + FieldSpec::UrlPath | FieldSpec::UrlSlug | FieldSpec::QueryString => 20, + FieldSpec::Choice(opts) => { + if opts.is_empty() { + 0 + } else { + opts.iter().map(|s| s.len()).sum::<usize>() / opts.len() + } + } + FieldSpec::Custom(_) | FieldSpec::Simple(_) => 20, + }; + + enum_size + heap_size +} + +// ============================================================================ +// Tests +// ============================================================================ + +#[cfg(test)] +mod tests { + use super::*; + use crate::rng::ForgeryRng; + + fn test_rng() -> ForgeryRng { + ForgeryRng::from_seed(42) + } + + fn simple_schema() -> BTreeMap<String, FieldSpec> { + let mut schema = BTreeMap::new(); + schema.insert("age".to_string(), FieldSpec::IntRange { min: 18, max: 65 }); + schema.insert("name".to_string(), FieldSpec::Name); + schema + } + + // === OutputFormat tests === + + #[test] + fn test_format_from_extension() { + assert_eq!( + OutputFormat::from_extension("data.csv").unwrap(), + OutputFormat::Csv + ); + assert_eq!( + OutputFormat::from_extension("data.ndjson").unwrap(), + OutputFormat::Ndjson + ); + assert_eq!( + OutputFormat::from_extension("data.jsonl").unwrap(), + OutputFormat::Ndjson + ); + assert_eq!( + OutputFormat::from_extension("data.sql").unwrap(), + OutputFormat::Sql + ); + assert_eq!( + OutputFormat::from_extension("data.parquet").unwrap(), + OutputFormat::Parquet + ); + } + + #[test] + fn test_format_from_extension_json_error() { + let err = OutputFormat::from_extension("data.json").unwrap_err(); + assert!(err.to_string().contains("ndjson")); + } + + #[test] + fn test_format_from_extension_unknown_error() { + let err = OutputFormat::from_extension("data.xyz").unwrap_err(); + assert!(err.to_string().contains("auto-detect")); + } + + #[test] + fn test_format_from_name() { + assert_eq!(OutputFormat::from_name("csv").unwrap(), OutputFormat::Csv); + assert_eq!( + OutputFormat::from_name("NDJSON").unwrap(), + OutputFormat::Ndjson + ); + assert_eq!(OutputFormat::from_name("SQL").unwrap(), OutputFormat::Sql); + assert_eq!( + OutputFormat::from_name("parquet").unwrap(), + OutputFormat::Parquet + ); + } + + #[test] + fn test_format_from_name_json_error() { + let err = OutputFormat::from_name("json").unwrap_err(); + assert!(err.to_string().contains("ndjson")); + } + + // === CSV file writer tests === + + #[test] + fn test_csv_file_basic() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("test.csv"); + let mut rng = test_rng(); + let schema = simple_schema(); + + let count = records_to_file( + &mut rng, + Locale::EnUS, + 10, + &schema, + &HashMap::new(), + &path, + OutputFormat::Csv, + 100, + None, + None, + ) + .unwrap(); + + assert_eq!(count, 10); + let content = std::fs::read_to_string(&path).unwrap(); + let lines: Vec<&str> =
content.lines().collect(); + assert_eq!(lines[0], "age,name"); // header + assert_eq!(lines.len(), 11); // 1 header + 10 data + } + + // === NDJSON file writer tests === + + #[test] + fn test_ndjson_file_basic() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("test.ndjson"); + let mut rng = test_rng(); + let schema = simple_schema(); + + let count = records_to_file( + &mut rng, + Locale::EnUS, + 10, + &schema, + &HashMap::new(), + &path, + OutputFormat::Ndjson, + 100, + None, + None, + ) + .unwrap(); + + assert_eq!(count, 10); + let content = std::fs::read_to_string(&path).unwrap(); + let lines: Vec<&str> = content.lines().collect(); + assert_eq!(lines.len(), 10); + + // Each line must be valid JSON + for line in &lines { + let parsed: serde_json::Value = serde_json::from_str(line).unwrap(); + assert!(parsed.is_object()); + } + } + + // === SQL file writer tests === + + #[test] + fn test_sql_file_basic() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("test.sql"); + let mut rng = test_rng(); + let schema = simple_schema(); + + let count = records_to_file( + &mut rng, + Locale::EnUS, + 10, + &schema, + &HashMap::new(), + &path, + OutputFormat::Sql, + 100, + Some("users"), + None, + ) + .unwrap(); + + assert_eq!(count, 10); + let content = std::fs::read_to_string(&path).unwrap(); + assert!(content.contains("INSERT INTO \"users\"")); + assert!(content.contains("\"age\", \"name\"")); + } + + #[test] + fn test_sql_requires_table() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("test.sql"); + let mut rng = test_rng(); + let schema = simple_schema(); + + let err = records_to_file( + &mut rng, + Locale::EnUS, + 10, + &schema, + &HashMap::new(), + &path, + OutputFormat::Sql, + 100, + None, + None, + ) + .unwrap_err(); + + assert!(err.to_string().contains("table name")); + } + + #[test] + fn test_chunk_size_zero_errors() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("test.csv"); + let mut rng = test_rng(); + let schema = simple_schema(); + + let err = records_to_file( + &mut rng, + Locale::EnUS, + 10, + &schema, + &HashMap::new(), + &path, + OutputFormat::Csv, + 0, + None, + None, + ) + .unwrap_err(); + + assert!(err.to_string().contains("chunk_size")); + } + + #[test] + fn test_sql_batching_across_chunks() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("test.sql"); + let mut rng = test_rng(); + let schema = simple_schema(); + + // 2500 records in chunks of 1000 -> chunks of [1000, 1000, 500] + // Each chunk generates INSERT statements with SQL_BATCH_SIZE (1000) rows + records_to_file( + &mut rng, + Locale::EnUS, + 2500, + &schema, + &HashMap::new(), + &path, + OutputFormat::Sql, + 1000, + Some("users"), + None, + ) + .unwrap(); + + let content = std::fs::read_to_string(&path).unwrap(); + // 3 INSERT statements: 1000 + 1000 + 500 + assert_eq!(content.matches("INSERT INTO").count(), 3); + } + + // === Parquet file writer tests === + + #[test] + fn test_parquet_file_basic() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("test.parquet"); + let mut rng = test_rng(); + let schema = simple_schema(); + + let count = records_to_file( + &mut rng, + Locale::EnUS, + 10, + &schema, + &HashMap::new(), + &path, + OutputFormat::Parquet, + 100, + None, + None, + ) + .unwrap(); + + assert_eq!(count, 10); + let bytes = std::fs::read(&path).unwrap(); + assert_eq!(&bytes[..4], b"PAR1"); + assert_eq!(&bytes[bytes.len() - 4..], b"PAR1"); + } + + // === Chunking tests 
=== + + #[test] + fn test_multiple_chunks_csv() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("test.csv"); + let mut rng = test_rng(); + let schema = simple_schema(); + + // 50 records in chunks of 10 + let count = records_to_file( + &mut rng, + Locale::EnUS, + 50, + &schema, + &HashMap::new(), + &path, + OutputFormat::Csv, + 10, + None, + None, + ) + .unwrap(); + + assert_eq!(count, 50); + let content = std::fs::read_to_string(&path).unwrap(); + let lines: Vec<&str> = content.lines().collect(); + assert_eq!(lines.len(), 51); // 1 header + 50 data + } + + #[test] + fn test_deterministic_csv() { + let dir = tempfile::tempdir().unwrap(); + let schema = simple_schema(); + + let path1 = dir.path().join("test1.csv"); + let mut rng1 = test_rng(); + records_to_file( + &mut rng1, + Locale::EnUS, + 100, + &schema, + &HashMap::new(), + &path1, + OutputFormat::Csv, + 30, + None, + None, + ) + .unwrap(); + + let path2 = dir.path().join("test2.csv"); + let mut rng2 = test_rng(); + records_to_file( + &mut rng2, + Locale::EnUS, + 100, + &schema, + &HashMap::new(), + &path2, + OutputFormat::Csv, + 30, + None, + None, + ) + .unwrap(); + + let content1 = std::fs::read_to_string(&path1).unwrap(); + let content2 = std::fs::read_to_string(&path2).unwrap(); + assert_eq!(content1, content2); + } + + #[test] + fn test_chunk_size_does_not_affect_output() { + // Same seed with different chunk sizes should produce identical data + let dir = tempfile::tempdir().unwrap(); + let schema = simple_schema(); + + let path1 = dir.path().join("chunk50.ndjson"); + let mut rng1 = test_rng(); + records_to_file( + &mut rng1, + Locale::EnUS, + 100, + &schema, + &HashMap::new(), + &path1, + OutputFormat::Ndjson, + 50, + None, + None, + ) + .unwrap(); + + let path2 = dir.path().join("chunk10.ndjson"); + let mut rng2 = test_rng(); + records_to_file( + &mut rng2, + Locale::EnUS, + 100, + &schema, + &HashMap::new(), + &path2, + OutputFormat::Ndjson, + 10, + None, + None, + ) + .unwrap(); + + let content1 = std::fs::read_to_string(&path1).unwrap(); + let content2 = std::fs::read_to_string(&path2).unwrap(); + assert_eq!(content1, content2); + } + + // === Progress callback tests === + + #[test] + fn test_progress_callback() { + use std::cell::RefCell; + use std::rc::Rc; + + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("test.csv"); + let mut rng = test_rng(); + let schema = simple_schema(); + + let calls: Rc<RefCell<Vec<(usize, usize)>>> = Rc::new(RefCell::new(Vec::new())); + let calls_clone = Rc::clone(&calls); + + records_to_file( + &mut rng, + Locale::EnUS, + 50, + &schema, + &HashMap::new(), + &path, + OutputFormat::Csv, + 20, + None, + Some(&move |written, total| { + calls_clone.borrow_mut().push((written, total)); + Ok(()) + }), + ) + .unwrap(); + + let calls = calls.borrow(); + // 50 records / 20 chunk_size = 3 chunks (20 + 20 + 10) + assert_eq!(calls.len(), 3); + assert_eq!(calls[0], (20, 50)); + assert_eq!(calls[1], (40, 50)); + assert_eq!(calls[2], (50, 50)); + } + + #[test] + fn test_progress_callback_error_aborts() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("abort.csv"); + let mut rng = test_rng(); + let schema = simple_schema(); + + let err = records_to_file( + &mut rng, + Locale::EnUS, + 50, + &schema, + &HashMap::new(), + &path, + OutputFormat::Csv, + 20, + None, + Some(&|written, _total| { + if written >= 20 { + Err("user cancelled".to_string()) + } else { + Ok(()) + } + }), + ) + .unwrap_err(); + + assert!(err.to_string().contains("user cancelled")); + } + + #[test] +
fn test_invalid_schema_does_not_create_file() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("should_not_exist.csv"); + let mut rng = test_rng(); + + let mut bad_schema = BTreeMap::new(); + bad_schema.insert( + "x".to_string(), + FieldSpec::IntRange { min: 10, max: 1 }, // invalid: min > max + ); + + let result = records_to_file( + &mut rng, + Locale::EnUS, + 10, + &bad_schema, + &HashMap::new(), + &path, + OutputFormat::Csv, + 100, + None, + None, + ); + + assert!(result.is_err()); + // File must not have been created + assert!(!path.exists()); + } + + #[test] + fn test_invalid_schema_rejected_for_zero_records() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("zero.csv"); + let mut rng = test_rng(); + + let mut bad_schema = BTreeMap::new(); + bad_schema.insert("x".to_string(), FieldSpec::IntRange { min: 10, max: 1 }); + + // Even with n=0, invalid schema should be rejected + let result = records_to_file( + &mut rng, + Locale::EnUS, + 0, + &bad_schema, + &HashMap::new(), + &path, + OutputFormat::Csv, + 100, + None, + None, + ); + + assert!(result.is_err()); + } + + // === Empty records tests === + + #[test] + fn test_empty_csv() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("empty.csv"); + let mut rng = test_rng(); + let schema = simple_schema(); + + let count = records_to_file( + &mut rng, + Locale::EnUS, + 0, + &schema, + &HashMap::new(), + &path, + OutputFormat::Csv, + 100, + None, + None, + ) + .unwrap(); + + assert_eq!(count, 0); + let content = std::fs::read_to_string(&path).unwrap(); + let lines: Vec<&str> = content.lines().collect(); + assert_eq!(lines.len(), 1); // header only + assert_eq!(lines[0], "age,name"); + } + + #[test] + fn test_empty_ndjson() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("empty.ndjson"); + let mut rng = test_rng(); + let schema = simple_schema(); + + records_to_file( + &mut rng, + Locale::EnUS, + 0, + &schema, + &HashMap::new(), + &path, + OutputFormat::Ndjson, + 100, + None, + None, + ) + .unwrap(); + + let content = std::fs::read_to_string(&path).unwrap(); + assert_eq!(content, ""); + } + + // === Memory estimation tests === + + #[test] + fn test_estimate_memory_basic() { + let schema = simple_schema(); + let estimate = estimate_memory(1000, &schema); + assert!(estimate > 0); + } + + #[test] + fn test_estimate_memory_scales_linearly() { + let schema = simple_schema(); + let est_1k = estimate_memory(1_000, &schema); + let est_2k = estimate_memory(2_000, &schema); + // Should be approximately 2x (within overhead tolerance) + assert!((est_2k as f64 / est_1k as f64 - 2.0).abs() < 0.01); + } + + #[test] + fn test_estimate_memory_more_fields_more_memory() { + let small_schema = simple_schema(); + + let mut large_schema = simple_schema(); + large_schema.insert("email".to_string(), FieldSpec::Email); + large_schema.insert("phone".to_string(), FieldSpec::Phone); + large_schema.insert("address".to_string(), FieldSpec::Address); + + let small_est = estimate_memory(1000, &small_schema); + let large_est = estimate_memory(1000, &large_schema); + assert!(large_est > small_est); + } + + #[test] + fn test_estimate_memory_zero_records() { + let schema = simple_schema(); + let estimate = estimate_memory(0, &schema); + // Just Vec overhead + assert_eq!(estimate, 24); + } +} diff --git a/src/providers/mod.rs b/src/providers/mod.rs index 6508f98..80163e5 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -14,6 +14,7 @@ pub mod currency; pub mod custom; pub 
mod datetime; pub mod file; +pub mod file_writer; pub mod finance; pub mod geo; pub mod html; diff --git a/src/providers/records.rs b/src/providers/records.rs index 7c4d8e3..d2d8a43 100644 --- a/src/providers/records.rs +++ b/src/providers/records.rs @@ -1133,6 +1133,209 @@ fn generate_arrow_column( } } +/// Convert pre-generated row-major records to an Arrow RecordBatch. +/// +/// This transposes rows of `BTreeMap<String, Value>` into columnar Arrow arrays. +/// Used by the streaming file writer to write Parquet chunks from row-major +/// records, preserving RNG order consistency with CSV/NDJSON/SQL formats. +/// +/// # Errors +/// +/// Returns an error if the records contain values that don't match the schema +/// field types. +pub fn records_to_record_batch( + records: &[BTreeMap<String, Value>], + schema: &BTreeMap<String, FieldSpec>, +) -> Result<RecordBatch, SchemaError> { + let n = records.len(); + + // Build Arrow schema + let mut arrow_fields: Vec<Field> = Vec::with_capacity(schema.len()); + let field_names: Vec<&String> = schema.keys().collect(); + + for (name, spec) in schema.iter() { + let arrow_type = field_spec_to_arrow_type(spec); + arrow_fields.push(Field::new(name, arrow_type, false)); + } + + let arrow_schema = Arc::new(Schema::new(arrow_fields)); + + // Build columns by iterating over fields, then over records + let mut columns: Vec<ArrayRef> = Vec::with_capacity(schema.len()); + + for name in &field_names { + let spec = &schema[*name]; + let column = build_column_from_records(records, name, spec, n)?; + columns.push(column); + } + + RecordBatch::try_new(arrow_schema, columns).map_err(|e| SchemaError { + message: format!("Failed to create RecordBatch: {}", e), + }) +} + +/// Build a single Arrow column from pre-generated row-major records. +fn build_column_from_records( + records: &[BTreeMap<String, Value>], + field_name: &str, + spec: &FieldSpec, + n: usize, +) -> Result<ArrayRef, SchemaError> { + match spec { + FieldSpec::Int | FieldSpec::IntRange { .. } | FieldSpec::VehicleYear => { + extract_i64_column(records, field_name) + } + FieldSpec::Float + | FieldSpec::FloatRange { .. } + | FieldSpec::Latitude + | FieldSpec::Longitude => extract_f64_column(records, field_name), + FieldSpec::Boolean => extract_bool_column(records, field_name), + FieldSpec::Coordinate => extract_coordinate_column(records, field_name, n), + FieldSpec::RgbColor => extract_rgb_column(records, field_name, n), + _ => { + let values: Vec<String> = records.iter().map(|r| r[field_name].as_string()).collect(); + Ok(Arc::new(StringArray::from(values))) + } + } +} + +/// Extract an Int64 column from row-major records. +fn extract_i64_column( + records: &[BTreeMap<String, Value>], + field_name: &str, +) -> Result<ArrayRef, SchemaError> { + let values: Vec<i64> = records + .iter() + .map(|r| match &r[field_name] { + Value::Int(i) => Ok(*i), + other => Err(SchemaError { + message: format!("expected Int for field '{}', got {:?}", field_name, other), + }), + }) + .collect::<Result<Vec<_>, _>>()?; + Ok(Arc::new(Int64Array::from(values))) +} + +/// Extract a Float64 column from row-major records. +fn extract_f64_column( + records: &[BTreeMap<String, Value>], + field_name: &str, +) -> Result<ArrayRef, SchemaError> { + let values: Vec<f64> = records + .iter() + .map(|r| match &r[field_name] { + Value::Float(f) => Ok(*f), + other => Err(SchemaError { + message: format!("expected Float for field '{}', got {:?}", field_name, other), + }), + }) + .collect::<Result<Vec<_>, _>>()?; + Ok(Arc::new(Float64Array::from(values))) +} + +/// Extract a Boolean column from row-major records.
+fn extract_bool_column( + records: &[BTreeMap<String, Value>], + field_name: &str, +) -> Result<ArrayRef, SchemaError> { + let values: Vec<bool> = records + .iter() + .map(|r| match &r[field_name] { + Value::Bool(b) => Ok(*b), + other => Err(SchemaError { + message: format!("expected Bool for field '{}', got {:?}", field_name, other), + }), + }) + .collect::<Result<Vec<_>, _>>()?; + Ok(Arc::new(BooleanArray::from(values))) +} + +/// Extract a Coordinate struct column (lat, lng) from row-major records. +fn extract_coordinate_column( + records: &[BTreeMap<String, Value>], + field_name: &str, + n: usize, +) -> Result<ArrayRef, SchemaError> { + let mut lat_values: Vec<f64> = Vec::with_capacity(n); + let mut lng_values: Vec<f64> = Vec::with_capacity(n); + + for record in records { + match &record[field_name] { + Value::Tuple2F64(lat, lng) => { + lat_values.push(*lat); + lng_values.push(*lng); + } + other => { + return Err(SchemaError { + message: format!( + "expected Tuple2F64 for field '{}', got {:?}", + field_name, other + ), + }) + } + } + } + + let lat_array = Arc::new(Float64Array::from(lat_values)) as ArrayRef; + let lng_array = Arc::new(Float64Array::from(lng_values)) as ArrayRef; + + let struct_fields: Vec<Field> = vec![ + Field::new("lat", DataType::Float64, false), + Field::new("lng", DataType::Float64, false), + ]; + + Ok(Arc::new(StructArray::new( + struct_fields.into(), + vec![lat_array, lng_array], + None::<NullBuffer>, + ))) +} + +/// Extract an RGB struct column (r, g, b) from row-major records. +fn extract_rgb_column( + records: &[BTreeMap<String, Value>], + field_name: &str, + n: usize, +) -> Result<ArrayRef, SchemaError> { + let mut r_values: Vec<u8> = Vec::with_capacity(n); + let mut g_values: Vec<u8> = Vec::with_capacity(n); + let mut b_values: Vec<u8> = Vec::with_capacity(n); + + for record in records { + match &record[field_name] { + Value::Tuple3U8(r, g, b) => { + r_values.push(*r); + g_values.push(*g); + b_values.push(*b); + } + other => { + return Err(SchemaError { + message: format!( + "expected Tuple3U8 for field '{}', got {:?}", + field_name, other + ), + }) + } + } + } + + let r_array = Arc::new(UInt8Array::from(r_values)) as ArrayRef; + let g_array = Arc::new(UInt8Array::from(g_values)) as ArrayRef; + let b_array = Arc::new(UInt8Array::from(b_values)) as ArrayRef; + + let struct_fields: Vec<Field> = vec![ + Field::new("r", DataType::UInt8, false), + Field::new("g", DataType::UInt8, false), + Field::new("b", DataType::UInt8, false), + ]; + + Ok(Arc::new(StructArray::new( + struct_fields.into(), + vec![r_array, g_array, b_array], + None::<NullBuffer>, + ))) +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/providers/serialize.rs b/src/providers/serialize.rs index 31cef1e..355f985 100644 --- a/src/providers/serialize.rs +++ b/src/providers/serialize.rs @@ -13,7 +13,7 @@ use crate::providers::records::{ use crate::rng::ForgeryRng; /// Maximum number of rows per SQL INSERT statement. -const SQL_BATCH_SIZE: usize = 1000; +pub(crate) const SQL_BATCH_SIZE: usize = 1000; // ============================================================================ // Value conversion helpers @@ -23,7 +23,7 @@ const SQL_BATCH_SIZE: usize = 1000; /// /// Integers and floats become JSON numbers, booleans become JSON booleans, /// tuples become JSON arrays, and strings become JSON strings. -fn value_to_json(v: &Value) -> serde_json::Value { +pub(crate) fn value_to_json(v: &Value) -> serde_json::Value { match v { Value::String(s) => serde_json::Value::String(s.clone()), Value::Int(i) => serde_json::json!(*i), @@ -35,7 +35,7 @@ const SQL_BATCH_SIZE: usize = 1000; } /// Convert a record (BTreeMap<String, Value>) to a `serde_json::Value::Object`.
-fn record_to_json_object(record: &BTreeMap<String, Value>) -> serde_json::Value { +pub(crate) fn record_to_json_object(record: &BTreeMap<String, Value>) -> serde_json::Value { let obj: serde_json::Map<String, serde_json::Value> = record .iter() .map(|(k, v)| (k.clone(), value_to_json(v))) @@ -48,18 +48,18 @@ fn record_to_json_object(record: &BTreeMap<String, Value>) -> serde_json::Value /// Doubles single quotes per ANSI SQL. Backslashes are left as-is /// because ANSI SQL and PostgreSQL (with `standard_conforming_strings = on`, /// the default since 9.1) treat them as literal characters. -fn sql_escape(s: &str) -> String { +pub(crate) fn sql_escape(s: &str) -> String { s.replace('\'', "''") } /// Escape a SQL identifier by doubling any embedded double-quotes, /// per ANSI SQL (e.g., `my"col` becomes `"my""col"`). -fn sql_quote_identifier(name: &str) -> String { +pub(crate) fn sql_quote_identifier(name: &str) -> String { format!("\"{}\"", name.replace('"', "\"\"")) } /// Format a `Value` as a SQL literal. -fn value_to_sql(v: &Value) -> String { +pub(crate) fn value_to_sql(v: &Value) -> String { match v { Value::String(s) => format!("'{}'", sql_escape(s)), Value::Int(i) => i.to_string(), diff --git a/tests/test_convenience_new.py b/tests/test_convenience_new.py index 7105f90..d3b6ea5 100644 --- a/tests/test_convenience_new.py +++ b/tests/test_convenience_new.py @@ -237,3 +237,18 @@ def test_profiles(self) -> None: forgery.seed(42) ps = forgery.profiles(5) assert len(ps) == 5 + + +class TestFileWriterConvenience: + """Tests for streaming file writer convenience functions.""" + + def test_records_to_file(self, tmp_path) -> None: + path = str(tmp_path / "test.csv") + forgery.seed(42) + count = forgery.records_to_file(10, {"name": "name"}, path) + assert count == 10 + + def test_estimate_memory(self) -> None: + est = forgery.estimate_memory(1000, {"name": "name", "age": ("int", 18, 65)}) + assert isinstance(est, int) + assert est > 0 diff --git a/tests/test_file_writer.py b/tests/test_file_writer.py new file mode 100644 index 0000000..e9263f7 --- /dev/null +++ b/tests/test_file_writer.py @@ -0,0 +1,413 @@ +"""Tests for streaming file writer (records_to_file) and memory estimation.""" + +import csv +import json +from pathlib import Path + +import pytest + +from forgery import ( + Faker, + estimate_memory, + records_to_file, + seed, +) + +# Check if pyarrow is available for parquet roundtrip tests +try: + import pyarrow.parquet as pq + + HAS_PYARROW = True +except ImportError: + HAS_PYARROW = False + +SIMPLE_SCHEMA = {"age": ("int", 18, 65), "name": "name"} + +MULTI_TYPE_SCHEMA = { + "active": "boolean", + "age": ("int", 18, 65), + "name": "name", + "salary": ("float", 30000.0, 150000.0), +} + + +class TestRecordsToFileCSV: + """Tests for records_to_file() with CSV format.""" + + def test_basic(self, tmp_path: Path) -> None: + path = str(tmp_path / "test.csv") + seed(42) + count = records_to_file(100, SIMPLE_SCHEMA, path) + assert count == 100 + + with Path(path).open() as f: + reader = csv.DictReader(f) + rows = list(reader) + assert len(rows) == 100 + assert set(rows[0].keys()) == {"age", "name"} + + def test_auto_detect_extension(self, tmp_path: Path) -> None: + path = str(tmp_path / "data.csv") + seed(42) + count = records_to_file(10, SIMPLE_SCHEMA, path) + assert count == 10 + + def test_explicit_format_overrides_extension(self, tmp_path: Path) -> None: + path = str(tmp_path / "data.txt") + seed(42) + count = records_to_file(10, SIMPLE_SCHEMA, path, format="csv") + assert count == 10 + + with Path(path).open() as f: + reader = csv.reader(f) + rows =
list(reader) + assert rows[0] == ["age", "name"] + assert len(rows) == 11 + + def test_chunked_matches_single(self, tmp_path: Path) -> None: + """Different chunk sizes produce identical output.""" + path1 = tmp_path / "chunk50.csv" + seed(42) + records_to_file(100, SIMPLE_SCHEMA, str(path1), chunk_size=50) + + path2 = tmp_path / "chunk10.csv" + seed(42) + records_to_file(100, SIMPLE_SCHEMA, str(path2), chunk_size=10) + + assert path1.read_text() == path2.read_text() + + def test_empty(self, tmp_path: Path) -> None: + path = str(tmp_path / "empty.csv") + seed(42) + count = records_to_file(0, SIMPLE_SCHEMA, path) + assert count == 0 + + with Path(path).open() as f: + reader = csv.reader(f) + rows = list(reader) + assert len(rows) == 1 # header only + assert rows[0] == ["age", "name"] + + def test_deterministic(self, tmp_path: Path) -> None: + path1 = tmp_path / "det1.csv" + seed(42) + records_to_file(100, SIMPLE_SCHEMA, str(path1)) + + path2 = tmp_path / "det2.csv" + seed(42) + records_to_file(100, SIMPLE_SCHEMA, str(path2)) + + assert path1.read_text() == path2.read_text() + + def test_instance_method(self, tmp_path: Path) -> None: + path = str(tmp_path / "test.csv") + f = Faker() + f.seed(42) + count = f.records_to_file(50, SIMPLE_SCHEMA, path) + assert count == 50 + + +class TestRecordsToFileNDJSON: + """Tests for records_to_file() with NDJSON format.""" + + def test_basic(self, tmp_path: Path) -> None: + path = str(tmp_path / "test.ndjson") + seed(42) + count = records_to_file(100, SIMPLE_SCHEMA, path) + assert count == 100 + + with Path(path).open() as f: + lines = f.read().strip().split("\n") + assert len(lines) == 100 + + for line in lines: + obj = json.loads(line) + assert "name" in obj + assert "age" in obj + + def test_auto_detect_jsonl(self, tmp_path: Path) -> None: + path = str(tmp_path / "data.jsonl") + seed(42) + count = records_to_file(10, SIMPLE_SCHEMA, path) + assert count == 10 + + def test_chunked(self, tmp_path: Path) -> None: + path = str(tmp_path / "chunked.ndjson") + seed(42) + records_to_file(50, SIMPLE_SCHEMA, path, chunk_size=15) + + with Path(path).open() as f: + lines = f.read().strip().split("\n") + assert len(lines) == 50 + + def test_empty(self, tmp_path: Path) -> None: + path = tmp_path / "empty.ndjson" + seed(42) + records_to_file(0, SIMPLE_SCHEMA, str(path)) + + assert path.read_text() == "" + + def test_types_preserved(self, tmp_path: Path) -> None: + path = str(tmp_path / "types.ndjson") + seed(42) + records_to_file(10, MULTI_TYPE_SCHEMA, path) + + with Path(path).open() as f: + obj = json.loads(f.readline()) + assert isinstance(obj["age"], int) + assert isinstance(obj["salary"], float) + assert isinstance(obj["active"], bool) + assert isinstance(obj["name"], str) + + +class TestRecordsToFileSQL: + """Tests for records_to_file() with SQL format.""" + + def test_basic(self, tmp_path: Path) -> None: + path = tmp_path / "test.sql" + seed(42) + count = records_to_file(10, SIMPLE_SCHEMA, str(path), table="users") + assert count == 10 + + content = path.read_text() + assert 'INSERT INTO "users"' in content + assert '"age", "name"' in content + + def test_requires_table(self, tmp_path: Path) -> None: + path = str(tmp_path / "test.sql") + with pytest.raises(ValueError, match="table name"): + records_to_file(10, SIMPLE_SCHEMA, path) + + def test_sql_batching(self, tmp_path: Path) -> None: + path = tmp_path / "batch.sql" + seed(42) + records_to_file(2500, SIMPLE_SCHEMA, str(path), table="users", chunk_size=1000) + + content = path.read_text() + # 2500 records -> 3 
+        assert content.count("INSERT INTO") == 3
+
+
+class TestRecordsToFileParquet:
+    """Tests for records_to_file() with Parquet format."""
+
+    def test_basic(self, tmp_path: Path) -> None:
+        path = tmp_path / "test.parquet"
+        seed(42)
+        count = records_to_file(100, SIMPLE_SCHEMA, str(path))
+        assert count == 100
+
+        # Check magic bytes
+        data = path.read_bytes()
+        assert data[:4] == b"PAR1"
+        assert data[-4:] == b"PAR1"
+
+    @pytest.mark.skipif(not HAS_PYARROW, reason="pyarrow not installed")
+    def test_roundtrip(self, tmp_path: Path) -> None:
+        path = str(tmp_path / "roundtrip.parquet")
+        seed(42)
+        records_to_file(100, SIMPLE_SCHEMA, path, chunk_size=30)
+
+        table = pq.read_table(path)
+        assert table.num_rows == 100
+        assert set(table.column_names) == {"age", "name"}
+
+    @pytest.mark.skipif(not HAS_PYARROW, reason="pyarrow not installed")
+    def test_chunked_roundtrip(self, tmp_path: Path) -> None:
+        """Multiple chunks still produce a valid file with all rows."""
+        path = str(tmp_path / "multi_chunk.parquet")
+        seed(42)
+        records_to_file(100, SIMPLE_SCHEMA, path, chunk_size=30)
+
+        table = pq.read_table(path)
+        assert table.num_rows == 100
+        assert set(table.column_names) == {"age", "name"}
+
+
+class TestRecordsToFileFormat:
+    """Tests for format detection and errors."""
+
+    def test_json_not_supported(self, tmp_path: Path) -> None:
+        path = str(tmp_path / "data.json")
+        with pytest.raises(ValueError, match="ndjson"):
+            records_to_file(10, SIMPLE_SCHEMA, path)
+
+    def test_json_format_name_not_supported(self, tmp_path: Path) -> None:
+        path = str(tmp_path / "data.ndjson")
+        with pytest.raises(ValueError, match="ndjson"):
+            records_to_file(10, SIMPLE_SCHEMA, path, format="json")
+
+    def test_unknown_extension(self, tmp_path: Path) -> None:
+        path = str(tmp_path / "data.xyz")
+        with pytest.raises(ValueError, match="auto-detect"):
+            records_to_file(10, SIMPLE_SCHEMA, path)
+
+    def test_unsupported_format_name(self, tmp_path: Path) -> None:
+        path = str(tmp_path / "data.csv")
+        with pytest.raises(ValueError, match="unsupported"):
+            records_to_file(10, SIMPLE_SCHEMA, path, format="xml")
+
+    def test_invalid_schema(self, tmp_path: Path) -> None:
+        path = str(tmp_path / "data.csv")
+        with pytest.raises(ValueError):
+            records_to_file(10, {"x": "nonexistent_type"}, path)
+
+    def test_chunk_size_zero(self, tmp_path: Path) -> None:
+        path = str(tmp_path / "data.csv")
+        with pytest.raises(ValueError, match="chunk_size"):
+            records_to_file(10, SIMPLE_SCHEMA, path, chunk_size=0)
+
+
+class TestRecordsToFileProgress:
+    """Tests for the progress callback."""
+
+    def test_callback_fires(self, tmp_path: Path) -> None:
+        path = str(tmp_path / "prog.csv")
+        seed(42)
+        calls: list[tuple[int, int]] = []
+
+        def on_progress(written: int, total: int) -> None:
+            calls.append((written, total))
+
+        records_to_file(50, SIMPLE_SCHEMA, path, chunk_size=20, on_progress=on_progress)
+
+        # 50 / 20 = 3 chunks (20 + 20 + 10)
+        assert len(calls) == 3
+        assert calls[0] == (20, 50)
+        assert calls[1] == (40, 50)
+        assert calls[2] == (50, 50)
+
+    def test_callback_single_chunk(self, tmp_path: Path) -> None:
+        path = str(tmp_path / "prog.csv")
+        seed(42)
+        calls: list[tuple[int, int]] = []
+
+        records_to_file(
+            10, SIMPLE_SCHEMA, path, chunk_size=100, on_progress=lambda w, t: calls.append((w, t))
+        )
+
+        assert len(calls) == 1
+        assert calls[0] == (10, 10)
+
+    def test_callback_exception_aborts(self, tmp_path: Path) -> None:
+        path = str(tmp_path / "abort.csv")
+        seed(42)
+
+        def on_progress(written: int, _total: int) -> None:
+            if written >= 20:
+                raise RuntimeError("user cancelled")
+
+        with pytest.raises(RuntimeError, match="user cancelled"):
+            records_to_file(50, SIMPLE_SCHEMA, path, chunk_size=20, on_progress=on_progress)
+
+    def test_callback_exception_preserves_type(self, tmp_path: Path) -> None:
+        """Verify the original exception type is preserved, not wrapped in ValueError."""
+        path = str(tmp_path / "type.csv")
+        seed(42)
+
+        with pytest.raises(KeyError):
+            records_to_file(
+                10,
+                SIMPLE_SCHEMA,
+                path,
+                chunk_size=5,
+                on_progress=lambda _w, _t: (_ for _ in ()).throw(KeyError("bad key")),
+            )
+
+
+class TestRecordsToFileIOErrors:
+    """Tests for proper OSError mapping."""
+
+    def test_nonexistent_directory_raises_oserror(self, tmp_path: Path) -> None:
+        path = str(tmp_path / "no" / "such" / "dir" / "test.csv")
+        with pytest.raises(OSError):
+            records_to_file(10, SIMPLE_SCHEMA, path)
+
+    def test_invalid_schema_does_not_create_file(self, tmp_path: Path) -> None:
+        path = tmp_path / "should_not_exist.csv"
+        with pytest.raises(ValueError):
+            records_to_file(10, {"x": ("int", 10, 1)}, str(path))
+        assert not path.exists()
+
+    def test_invalid_schema_rejected_for_zero_records(self, tmp_path: Path) -> None:
+        path = str(tmp_path / "zero.csv")
+        with pytest.raises(ValueError):
+            records_to_file(0, {"x": ("int", 10, 1)}, path)
+
+
+class TestRecordsToFileCustomProvider:
+    """Tests for custom provider support in file writer."""
+
+    def test_custom_provider(self, tmp_path: Path) -> None:
+        path = str(tmp_path / "custom.ndjson")
+        f = Faker()
+        f.seed(42)
+        f.add_provider("fav_animal", ["cat", "dog", "fish"])
+        count = f.records_to_file(
+            20,
+            {"name": "name", "pet": "fav_animal"},
+            path,
+        )
+        assert count == 20
+
+        with Path(path).open() as fh:
+            for line in fh:
+                obj = json.loads(line)
+                assert obj["pet"] in {"cat", "dog", "fish"}
+
+
+class TestEstimateMemory:
+    """Tests for estimate_memory()."""
+
+    def test_basic(self) -> None:
+        est = estimate_memory(1000, SIMPLE_SCHEMA)
+        assert isinstance(est, int)
+        assert est > 0
+
+    def test_zero_records(self) -> None:
+        est = estimate_memory(0, SIMPLE_SCHEMA)
+        assert est == 24  # just Vec overhead
+
+    def test_scales_linearly(self) -> None:
+        est_1k = estimate_memory(1_000, SIMPLE_SCHEMA)
+        est_2k = estimate_memory(2_000, SIMPLE_SCHEMA)
+        ratio = est_2k / est_1k
+        assert 1.95 < ratio < 2.05
+
+    def test_more_fields_more_memory(self) -> None:
+        small_est = estimate_memory(1000, SIMPLE_SCHEMA)
+        large_est = estimate_memory(1000, MULTI_TYPE_SCHEMA)
+        assert large_est > small_est
+
+    def test_static_method(self) -> None:
+        est = Faker.estimate_memory(1000, SIMPLE_SCHEMA)
+        assert est > 0
+
+    def test_convenience_function(self) -> None:
+        est = estimate_memory(1000, SIMPLE_SCHEMA)
+        assert est > 0
+
+
+class TestRecordsToFileLargeChunked:
+    """Test with larger record counts to verify chunking works end-to-end."""
+
+    def test_10k_records_csv(self, tmp_path: Path) -> None:
+        path = str(tmp_path / "large.csv")
+        seed(42)
+        count = records_to_file(10_000, SIMPLE_SCHEMA, path, chunk_size=1_000)
+        assert count == 10_000
+
+        with Path(path).open() as f:
+            reader = csv.reader(f)
+            rows = list(reader)
+        assert len(rows) == 10_001  # header + 10k data
+
+    def test_10k_records_ndjson(self, tmp_path: Path) -> None:
+        path = str(tmp_path / "large.ndjson")
+        seed(42)
+        count = records_to_file(10_000, SIMPLE_SCHEMA, path, chunk_size=1_000)
+        assert count == 10_000
+
+        with Path(path).open() as f:
+            line_count = sum(1 for _ in f)
+        assert line_count == 10_000
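Taken together, these tests exercise the documented workflow end to end: estimate memory, pick a chunk size, stream to disk with progress reporting. A compact sketch of that workflow under stated assumptions — the 256 MB budget and output path are illustrative, while the 10M `chunk_size` cap comes from the documented API:

```python
from forgery import estimate_memory, records_to_file, seed

schema = {"age": ("int", 18, 65), "name": "name"}

# Illustrative RAM budget; size chunks so their estimated footprint fits it.
budget = 256 * 1024**2
per_record = estimate_memory(1_000_000, schema) / 1_000_000
chunk = min(int(budget / per_record), 10_000_000)  # chunk_size is capped at 10M

seed(42)
written = records_to_file(
    5_000_000,
    schema,
    "people.csv",  # format auto-detected from the .csv extension
    chunk_size=chunk,
    on_progress=lambda w, t: print(f"{w:,}/{t:,}"),  # fires once per chunk
)
assert written == 5_000_000
```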