diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..10917c4 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,34 @@ +name: CI + +on: + push: + branches: [main, master] + pull_request: + branches: [main, master] + +jobs: + test: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.10", "3.11", "3.12"] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e ".[dev]" + + - name: Run linter + run: ruff check src/ tests/ + + - name: Run tests + run: pytest tests/ -v diff --git a/.gitignore b/.gitignore index f1ff277..d9cdffd 100644 --- a/.gitignore +++ b/.gitignore @@ -56,3 +56,7 @@ htmlcov/ # Build output dist/ build/ + +# Agent/workflow tracking +.agent/ +.beads/ diff --git a/src/tpuff/cli.py b/src/tpuff/cli.py index 3b0506a..aebfd5f 100644 --- a/src/tpuff/cli.py +++ b/src/tpuff/cli.py @@ -3,13 +3,13 @@ import click from tpuff import __version__ -from tpuff.commands.list import list_cmd -from tpuff.commands.search import search from tpuff.commands.delete import delete from tpuff.commands.edit import edit -from tpuff.commands.get import get from tpuff.commands.export import export - +from tpuff.commands.get import get +from tpuff.commands.list import list_cmd +from tpuff.commands.schema import schema +from tpuff.commands.search import search # Context settings to enable -h as help alias for all commands CONTEXT_SETTINGS = {"help_option_names": ["-h", "--help"]} @@ -35,6 +35,7 @@ def cli(ctx: click.Context, debug: bool) -> None: cli.add_command(get) cli.add_command(export) cli.add_command(export, name="metrics") # alias +cli.add_command(schema) def main() -> None: diff --git a/src/tpuff/client.py b/src/tpuff/client.py index 5e29753..ac837d6 100644 --- 
a/src/tpuff/client.py +++ b/src/tpuff/client.py @@ -8,7 +8,6 @@ from tpuff.utils.debug import debug_log from tpuff.utils.regions import DEFAULT_REGION - # Global client cache to avoid re-creating clients for the same region _client_cache: dict[str, Turbopuffer] = {} diff --git a/src/tpuff/commands/delete.py b/src/tpuff/commands/delete.py index decdb57..15e029d 100644 --- a/src/tpuff/commands/delete.py +++ b/src/tpuff/commands/delete.py @@ -8,7 +8,6 @@ from tpuff.client import get_turbopuffer_client from tpuff.utils.debug import debug_log - console = Console() @@ -91,7 +90,7 @@ def delete( console.print(f"[dim] - {ns.id}[/dim]") console.print("\n[bold yellow]💀 This is your last chance to back out! 💀[/bold yellow]") - console.print(f"[dim]To confirm, please type: [bold red]yolo[/bold red][/dim]\n") + console.print("[dim]To confirm, please type: [bold red]yolo[/bold red][/dim]\n") answer = prompt_user(">") diff --git a/src/tpuff/commands/edit.py b/src/tpuff/commands/edit.py index e63ddff..2eb3056 100644 --- a/src/tpuff/commands/edit.py +++ b/src/tpuff/commands/edit.py @@ -12,7 +12,6 @@ from tpuff.client import get_namespace from tpuff.utils.debug import debug_log - console = Console() diff --git a/src/tpuff/commands/export.py b/src/tpuff/commands/export.py index 89c30a2..a530cac 100644 --- a/src/tpuff/commands/export.py +++ b/src/tpuff/commands/export.py @@ -6,7 +6,7 @@ import threading import time from datetime import datetime -from http.server import HTTPServer, BaseHTTPRequestHandler +from http.server import BaseHTTPRequestHandler, HTTPServer from typing import Any import click @@ -21,14 +21,13 @@ get_unindexed_bytes, ) from tpuff.utils.metrics import ( - PrometheusMetric, MetricValue, - format_prometheus_metrics, + PrometheusMetric, create_simple_gauge, + format_prometheus_metrics, get_current_timestamp, ) - console = Console() @@ -382,11 +381,11 @@ def shutdown_handler(signum: int, frame: Any) -> None: console.print("[dim] Note: Recall estimation runs queries and 
incurs costs[/dim]") else: console.print("[dim] Recall: disabled (use --include-recall to enable)[/dim]") - console.print(f"\n[dim] Endpoints:[/dim]") + console.print("\n[dim] Endpoints:[/dim]") console.print(f"[dim] http://localhost:{port}/metrics[/dim]") console.print(f"[dim] http://localhost:{port}/health[/dim]") console.print(f"[dim] http://localhost:{port}/[/dim]") - console.print(f"\n[dim] Press Ctrl+C to stop[/dim]\n") + console.print("\n[dim] Press Ctrl+C to stop[/dim]\n") # Start server try: diff --git a/src/tpuff/commands/get.py b/src/tpuff/commands/get.py index 1589c0f..3991576 100644 --- a/src/tpuff/commands/get.py +++ b/src/tpuff/commands/get.py @@ -9,7 +9,6 @@ from tpuff.client import get_namespace from tpuff.utils.debug import debug_log - console = Console() diff --git a/src/tpuff/commands/list.py b/src/tpuff/commands/list.py index 9171449..a8f3a00 100644 --- a/src/tpuff/commands/list.py +++ b/src/tpuff/commands/list.py @@ -8,7 +8,7 @@ from rich.console import Console from rich.table import Table -from tpuff.client import get_namespace, get_turbopuffer_client +from tpuff.client import get_namespace from tpuff.utils.debug import debug_log from tpuff.utils.metadata_fetcher import ( NamespaceWithMetadata, @@ -17,7 +17,6 @@ get_unindexed_bytes, ) - console = Console() diff --git a/src/tpuff/commands/schema.py b/src/tpuff/commands/schema.py new file mode 100644 index 0000000..28e77e9 --- /dev/null +++ b/src/tpuff/commands/schema.py @@ -0,0 +1,790 @@ +"""Schema management commands for tpuff CLI.""" + +import json +import re +import sys +from dataclasses import dataclass, field + +import click +from rich.console import Console +from rich.table import Table + +from tpuff.client import get_namespace, get_turbopuffer_client + +console = Console() + +# Valid simple schema types +VALID_SIMPLE_TYPES = {"string", "uint64", "uuid", "bool"} + +# Regex for vector types: [dims]f32 or [dims]f16 +VECTOR_TYPE_PATTERN = re.compile(r"^\[\d+\]f(16|32)$") + +# Valid keys for 
# Keys accepted inside a complex (object-form) attribute type.
VALID_TYPE_KEYS = {"type", "full_text_search", "regex_index", "filterable"}


@dataclass
class SchemaDiff:
    """Result of comparing two schemas.

    Each mapping is keyed by attribute name; the values are the strings
    shown to the user when the diff is rendered.
    """

    # Attributes whose type is the same in both schemas.
    unchanged: dict[str, str] = field(default_factory=dict)
    # Attributes that only exist in the new schema.
    additions: dict[str, str] = field(default_factory=dict)
    # Attributes whose type differs: attr -> (old_type, new_type).
    conflicts: dict[str, tuple[str, str]] = field(default_factory=dict)

    @property
    def has_conflicts(self) -> bool:
        """Check if there are any type conflicts."""
        return bool(self.conflicts)

    @property
    def has_changes(self) -> bool:
        """Check if there are any additions or conflicts."""
        return bool(self.additions or self.conflicts)


def _structured_form(attr_type: object) -> tuple[bool, object]:
    """Return ``(True, plain_value)`` for object-form types, else ``(False, attr_type)``.

    Objects exposing ``model_dump()`` or ``to_dict()`` are converted through
    that method; plain dicts are passed through unchanged.
    """
    if hasattr(attr_type, "model_dump"):
        return True, attr_type.model_dump()
    if hasattr(attr_type, "to_dict"):
        return True, attr_type.to_dict()
    if isinstance(attr_type, dict):
        return True, attr_type
    return False, attr_type


def _is_bare_type(plain: object) -> bool:
    """Return True for ``{"type": X}`` objects that carry no other options."""
    return isinstance(plain, dict) and len(plain) == 1 and "type" in plain


def normalize_schema_type(attr_type: object) -> str:
    """Normalize a schema type to a comparable string representation.

    Handles simple string types, plain dicts, and objects exposing
    ``model_dump()`` / ``to_dict()``.

    An object whose only key is ``"type"`` normalizes to the bare type string,
    so ``{"type": "string"}`` and ``"string"`` compare equal. This matches how
    ``schema_type_for_display`` renders them; without it the diff would report
    a conflict displayed as ``string -> string``.

    Args:
        attr_type: The attribute type in any of the supported forms

    Returns:
        The bare type string, or a JSON string with sorted keys for object
        types that carry additional options
    """
    structured, plain = _structured_form(attr_type)
    if not structured:
        return str(attr_type)
    if _is_bare_type(plain):
        return str(plain["type"])
    # sort_keys makes the result independent of key order.
    return json.dumps(plain, sort_keys=True)


def schema_type_for_display(attr_type: object) -> str:
    """Convert a schema type to a human-readable display string.

    Args:
        attr_type: The attribute type in any of the supported forms

    Returns:
        The bare type string for simple types (including ``{"type": X}``
        objects with no other keys), otherwise the JSON form of the object
    """
    structured, plain = _structured_form(attr_type)
    if not structured:
        return str(attr_type)
    if _is_bare_type(plain):
        return str(plain["type"])
    return json.dumps(plain)


def display_schema_diff(diff: SchemaDiff, namespace: str) -> None:
    """Display a schema diff with Rich formatting.

    Args:
        diff: The computed schema diff
        namespace: The namespace name (for header)
    """
    console.print(f"\n[bold]Schema changes for namespace: {namespace}[/bold]\n")

    if not diff.has_changes and not diff.unchanged:
        console.print("[dim]No schema attributes[/dim]")
        return

    # One alphabetical listing across all three categories.
    for attr in sorted({*diff.unchanged, *diff.additions, *diff.conflicts}):
        if attr in diff.unchanged:
            console.print(f" {attr}: {diff.unchanged[attr]}")
        elif attr in diff.additions:
            console.print(f"[green]+{attr}: {diff.additions[attr]}[/green] [dim](new)[/dim]")
        else:
            old_type, new_type = diff.conflicts[attr]
            console.print(
                f"[red]!{attr}: {old_type} -> {new_type}[/red] "
                f"[dim](type change not allowed)[/dim]"
            )

    console.print()  # Blank line at end
def compute_schema_diff(
    current_schema: dict[str, object] | None,
    new_schema: dict[str, object],
) -> SchemaDiff:
    """Compute the difference between current and new schemas.

    Only attributes named in ``new_schema`` are classified; attributes that
    exist solely in ``current_schema`` do not appear in the result.

    Args:
        current_schema: The existing schema (None if namespace doesn't exist)
        new_schema: The schema to be applied

    Returns:
        SchemaDiff with unchanged, additions, and conflicts
    """
    diff = SchemaDiff()
    existing = current_schema or {}

    # Normalize once up front so each comparison below is a string equality.
    existing_normalized = {
        name: normalize_schema_type(existing_type)
        for name, existing_type in existing.items()
    }

    for attr, new_type in new_schema.items():
        wanted = normalize_schema_type(new_type)
        shown = schema_type_for_display(new_type)

        if attr not in existing_normalized:
            diff.additions[attr] = shown
        elif existing_normalized[attr] == wanted:
            diff.unchanged[attr] = shown
        else:
            diff.conflicts[attr] = (schema_type_for_display(existing[attr]), shown)

    return diff


@click.group("schema", context_settings={"help_option_names": ["-h", "--help"]})
def schema() -> None:
    """Manage namespace schemas."""


@schema.command("get", context_settings={"help_option_names": ["-h", "--help"]})
@click.option("-n", "--namespace", required=True, help="Namespace to get schema from")
@click.option("-r", "--region", help="Override the region (e.g., aws-us-east-1, gcp-us-central1)")
@click.option("--raw", is_flag=True, help="Output raw JSON without formatting (for piping)")
@click.pass_context
def schema_get(
    ctx: click.Context,
    namespace: str,
    region: str | None,
    raw: bool,
) -> None:
    """Display the schema for a namespace."""
    # The docstring above doubles as the command's --help text; keep it short.
    try:
        metadata = get_namespace(namespace, region).metadata()
        schema_data = getattr(metadata, "schema", {})

        if not schema_data:
            if raw:
                print("{}")
            else:
                console.print(f"[yellow]No schema found for namespace: {namespace}[/yellow]")
            return

        # Convert every attribute type into something json.dumps can emit.
        schema_dict = {}
        for attr_name, attr_type in schema_data.items():
            if hasattr(attr_type, "model_dump"):
                schema_dict[attr_name] = attr_type.model_dump()
            elif hasattr(attr_type, "to_dict"):
                schema_dict[attr_name] = attr_type.to_dict()
            elif isinstance(attr_type, dict):
                # Keep plain dicts as JSON objects; str() would embed a Python
                # repr string in the output instead.
                schema_dict[attr_name] = attr_type
            else:
                schema_dict[attr_name] = str(attr_type)

        if raw:
            print(json.dumps(schema_dict))
        else:
            console.print(f"\n[bold]Schema for namespace: {namespace}[/bold]\n")
            console.print(json.dumps(schema_dict, indent=2))

    except Exception as e:
        # In raw mode stdout is reserved for the schema, so errors go to stderr.
        if raw:
            print(json.dumps({"error": str(e)}), file=sys.stderr)
        else:
            console.print(f"[red]Error: {e}[/red]")
        sys.exit(1)


def _is_known_type(type_name: str) -> bool:
    """Return True if *type_name* is a valid simple type or a vector type."""
    # fullmatch (not match): a "$" anchor alone would still accept a trailing
    # newline such as "[3]f32\n".
    return type_name in VALID_SIMPLE_TYPES or VECTOR_TYPE_PATTERN.fullmatch(type_name) is not None


def validate_schema_type(attr_name: str, attr_type: object) -> list[str]:
    """Validate a single schema attribute type.

    Args:
        attr_name: The attribute name (for error messages)
        attr_type: The attribute type to validate

    Returns:
        List of error messages (empty if valid)
    """
    valid_types = ", ".join(sorted(VALID_SIMPLE_TYPES))

    if isinstance(attr_type, str):
        if _is_known_type(attr_type):
            return []
        return [
            f"Attribute '{attr_name}': invalid type '{attr_type}'. "
            f"Valid types: {valid_types}, or vector format [dims]f32/f16"
        ]

    if not isinstance(attr_type, dict):
        return [
            f"Attribute '{attr_name}': type must be a string or object, got {type(attr_type).__name__}"
        ]

    # Complex type object: collect every problem rather than stopping at the first.
    errors = []

    if "type" not in attr_type:
        errors.append(f"Attribute '{attr_name}': complex type object must have a 'type' key")
    else:
        base_type = attr_type["type"]
        if not isinstance(base_type, str):
            errors.append(f"Attribute '{attr_name}': 'type' must be a string")
        elif not _is_known_type(base_type):
            errors.append(
                f"Attribute '{attr_name}': invalid base type '{base_type}'. "
                f"Valid types: {valid_types}, or vector format [dims]f32/f16"
            )

    unknown_keys = set(attr_type.keys()) - VALID_TYPE_KEYS
    if unknown_keys:
        errors.append(
            f"Attribute '{attr_name}': unknown keys {sorted(unknown_keys)}. "
            f"Valid keys: {', '.join(sorted(VALID_TYPE_KEYS))}"
        )

    # Each optional flag, when present, must be a real boolean.
    for option in ("full_text_search", "regex_index", "filterable"):
        if option in attr_type and not isinstance(attr_type[option], bool):
            errors.append(f"Attribute '{attr_name}': '{option}' must be a boolean")

    return errors


def validate_schema(schema_data: dict) -> list[str]:
    """Validate a complete schema dictionary.

    Args:
        schema_data: The schema dictionary to validate

    Returns:
        List of error messages (empty if valid)
    """
    errors = []

    for attr_name, attr_type in schema_data.items():
        if not isinstance(attr_name, str):
            errors.append(f"Attribute name must be a string, got {type(attr_name).__name__}")
        elif not attr_name:
            errors.append("Attribute name cannot be empty")
        else:
            errors.extend(validate_schema_type(attr_name, attr_type))

    return errors
+ + Args: + file_path: Path to the JSON schema file + + Returns: + The parsed schema dictionary + + Raises: + click.ClickException: If file cannot be read, parsed, or is invalid + """ + try: + with open(file_path) as f: + schema_data = json.load(f) + except FileNotFoundError: + raise click.ClickException(f"Schema file not found: {file_path}") + except json.JSONDecodeError as e: + raise click.ClickException(f"Invalid JSON in schema file: {e}") + + if not isinstance(schema_data, dict): + raise click.ClickException("Schema file must contain a JSON object") + + # Validate schema structure and types + errors = validate_schema(schema_data) + if errors: + error_msg = "Invalid schema:\n " + "\n ".join(errors) + raise click.ClickException(error_msg) + + return schema_data + + +def get_current_schema(ns) -> dict[str, object] | None: + """Get the current schema from a namespace. + + Args: + ns: The turbopuffer namespace object + + Returns: + The schema dict, or None if namespace doesn't exist/has no schema + """ + try: + metadata = ns.metadata() + schema_data = metadata.schema if hasattr(metadata, "schema") else None + if not schema_data: + return None + + # Convert to plain dict for comparison + result = {} + for attr_name, attr_type in schema_data.items(): + if hasattr(attr_type, "model_dump"): + result[attr_name] = attr_type.model_dump() + elif hasattr(attr_type, "to_dict"): + result[attr_name] = attr_type.to_dict() + else: + result[attr_name] = str(attr_type) + return result + except Exception: + # Namespace doesn't exist or other error + return None + + +def list_namespaces_by_prefix(prefix: str, region: str | None) -> list[str]: + """List namespaces matching a prefix. 
+ + Args: + prefix: The prefix to match against namespace names + region: Optional region override + + Returns: + List of namespace IDs matching the prefix + """ + client = get_turbopuffer_client(region) + namespaces = list(client.namespaces()) + return sorted([ns.id for ns in namespaces if ns.id.startswith(prefix)]) + + +def list_all_namespaces(region: str | None) -> list[str]: + """List all namespaces. + + Args: + region: Optional region override + + Returns: + List of all namespace IDs + """ + client = get_turbopuffer_client(region) + namespaces = list(client.namespaces()) + return sorted([ns.id for ns in namespaces]) + + +@dataclass +class BatchApplyResult: + """Result of applying schema to a single namespace in a batch operation.""" + + namespace: str + success: bool + additions: int = 0 + conflicts: int = 0 + error: str | None = None + + +def display_batch_summary(results: list[BatchApplyResult], dry_run: bool = False) -> None: + """Display a summary table of batch apply results. + + Args: + results: List of BatchApplyResult objects + dry_run: Whether this was a dry run + """ + table = Table(show_header=True, header_style="cyan") + table.add_column("Namespace") + table.add_column("Changes") + table.add_column("Status") + + for result in results: + if result.conflicts > 0: + changes = f"+{result.additions} attributes [red]({result.conflicts} conflict(s))[/red]" + status = "[red]blocked[/red]" + elif result.error: + changes = "[dim]N/A[/dim]" + status = f"[red]error: {result.error}[/red]" + elif result.additions == 0: + changes = "[dim]no changes[/dim]" + status = "[green]up-to-date[/green]" if not dry_run else "[dim]would skip[/dim]" + else: + changes = f"+{result.additions} attribute(s)" + if dry_run: + status = "[yellow]would apply[/yellow]" + elif result.success: + status = "[green]applied[/green]" + else: + status = "[red]failed[/red]" + + table.add_row(f"[bold]{result.namespace}[/bold]", changes, status) + + console.print(table) + + +def 
def apply_schema_to_single_namespace(
    namespace: str,
    new_schema: dict[str, object],
    region: str | None,
    dry_run: bool,
    yes: bool,
) -> None:
    """Apply schema to a single namespace with interactive diff display.

    The diff is always printed first. The process exits with status 1 when the
    diff contains type conflicts or when the write fails.

    Args:
        namespace: Target namespace name
        new_schema: Schema to apply
        region: Optional region override
        dry_run: If True, only show diff without applying
        yes: If True, skip confirmation prompt
    """
    ns = get_namespace(namespace, region)
    diff = compute_schema_diff(get_current_schema(ns), new_schema)
    display_schema_diff(diff, namespace)

    # Conflicts are fatal, even in dry-run mode.
    if diff.has_conflicts:
        console.print("[red]Error: Cannot apply schema with type conflicts.[/red]")
        console.print("[red]Changing an existing attribute's type is not allowed.[/red]")
        sys.exit(1)

    if not diff.has_changes:
        console.print("[green]Schema is already up to date, no changes needed.[/green]")
        return

    if dry_run:
        console.print("[dim]Dry run mode - no changes applied[/dim]")
        return

    # Ask for confirmation unless --yes was given.
    if not yes and not click.confirm("Apply these schema changes?", default=False):
        console.print("[yellow]Aborted[/yellow]")
        return

    try:
        console.print(f"[dim]Applying schema to {namespace}...[/dim]")
        # NOTE(review): this upserts a row with id "__schema_placeholder__"
        # alongside the schema — confirm leaving that row in the namespace is
        # intended.
        ns.write(
            upsert_rows=[{"id": "__schema_placeholder__"}],
            schema=new_schema,
        )
        console.print(f"[green]Successfully applied schema to {namespace}[/green]")
    except Exception as e:
        console.print(f"[red]Error applying schema: {e}[/red]")
        sys.exit(1)
def apply_schema_to_multiple_namespaces(
    namespaces: list[str],
    new_schema: dict[str, object],
    region: str | None,
    dry_run: bool,
    yes: bool,
    continue_on_error: bool = False,
) -> None:
    """Apply schema to multiple namespaces with batch summary display.

    Runs in two phases: first every namespace is analysed and a preview table
    is printed, then (after confirmation) the schema is written to each
    namespace that has additions, no conflicts and no analysis error. Exits
    with status 1 if conflicts block the run or if any write fails.

    Args:
        namespaces: List of target namespace names
        new_schema: Schema to apply
        region: Optional region override
        dry_run: If True, only show what would change
        yes: If True, skip confirmation prompt
        continue_on_error: If True, skip namespaces with conflicts instead of aborting
    """
    # Phase 1: work out what would change in each namespace.
    results: list[BatchApplyResult] = []
    has_any_conflicts = False
    has_any_changes = False

    console.print(f"\n[bold]Analyzing schema for {len(namespaces)} namespace(s)...[/bold]\n")

    for ns_name in namespaces:
        try:
            ns = get_namespace(ns_name, region)
            diff = compute_schema_diff(get_current_schema(ns), new_schema)
        except Exception as e:
            results.append(BatchApplyResult(namespace=ns_name, success=False, error=str(e)))
            continue

        has_any_conflicts = has_any_conflicts or diff.has_conflicts
        has_any_changes = has_any_changes or diff.has_changes
        results.append(
            BatchApplyResult(
                namespace=ns_name,
                success=False,  # Flipped to True once the write succeeds
                additions=len(diff.additions),
                conflicts=len(diff.conflicts),
            )
        )

    # The preview table is always rendered with dry-run wording.
    console.print(f"[bold]Schema changes for {len(namespaces)} namespace(s):[/bold]\n")
    display_batch_summary(results, dry_run=True)

    if has_any_conflicts:
        if not continue_on_error:
            console.print("\n[red]Error: Some namespaces have type conflicts.[/red]")
            console.print("[red]Changing an existing attribute's type is not allowed.[/red]")
            console.print("[dim]Fix conflicts or use --continue-on-error to skip them.[/dim]")
            sys.exit(1)
        console.print("\n[yellow]Warning: Some namespaces have type conflicts and will be skipped.[/yellow]")
        console.print("[dim]Use --dry-run to see which namespaces have conflicts.[/dim]")

    if not has_any_changes:
        console.print("\n[green]All namespaces are already up to date, no changes needed.[/green]")
        return

    if dry_run:
        console.print("\n[dim]Dry run mode - no changes applied[/dim]")
        return

    # Only namespaces with pure additions and a clean analysis get written.
    to_update = [
        r for r in results
        if r.additions > 0 and r.conflicts == 0 and r.error is None
    ]

    if not to_update:
        console.print("\n[green]No namespaces need updates.[/green]")
        return

    if not yes and not click.confirm(f"\nApply schema to {len(to_update)} namespace(s)?", default=False):
        console.print("[yellow]Aborted[/yellow]")
        return

    # Phase 2: write the schema to each selected namespace.
    console.print(f"\n[dim]Applying schema to {len(to_update)} namespace(s)...[/dim]\n")

    success_count = 0
    fail_count = 0

    for result in to_update:
        try:
            target_ns = get_namespace(result.namespace, region)
            target_ns.write(
                upsert_rows=[{"id": "__schema_placeholder__"}],
                schema=new_schema,
            )
        except Exception as e:
            result.success = False
            result.error = str(e)
            fail_count += 1
        else:
            result.success = True
            success_count += 1

    console.print("[bold]Results:[/bold]\n")
    display_batch_summary(results, dry_run=False)

    if fail_count:
        console.print(f"\n[yellow]Applied schema to {success_count} namespace(s), {fail_count} failed[/yellow]")
        sys.exit(1)
    console.print(f"\n[green]Successfully applied schema to {success_count} namespace(s)[/green]")
@schema.command("apply", context_settings={"help_option_names": ["-h", "--help"]})
@click.option("-n", "--namespace", help="Target namespace to apply schema to")
@click.option("--prefix", help="Apply to all namespaces matching this prefix")
@click.option("--all", "apply_all", is_flag=True, help="Apply to all namespaces")
@click.option("-f", "--file", "schema_file", required=True, help="JSON file containing schema definition")
@click.option("-r", "--region", help="Override the region (e.g., aws-us-east-1, gcp-us-central1)")
@click.option("--dry-run", is_flag=True, help="Show diff only, don't apply changes")
@click.option("-y", "--yes", is_flag=True, help="Skip confirmation prompt")
@click.option("--continue-on-error", is_flag=True, help="Continue applying to other namespaces when conflicts occur (batch mode only)")
@click.pass_context
def schema_apply(
    ctx: click.Context,
    namespace: str | None,
    prefix: str | None,
    apply_all: bool,
    schema_file: str,
    region: str | None,
    dry_run: bool,
    yes: bool,
    continue_on_error: bool,
) -> None:
    """Apply a schema from a JSON file to namespace(s).

    Shows a diff of schema changes before applying. Type changes to existing
    attributes are not allowed and will be flagged as conflicts.

    Use -n/--namespace for a single namespace, --prefix to apply to all
    namespaces matching a prefix, or --all to apply to all namespaces.
    """
    usage_hint = (
        "[dim]Use -n/--namespace for a single namespace, --prefix for prefix match, "
        "or --all for all namespaces[/dim]"
    )

    # Exactly one targeting mode must be selected. An empty string counts as
    # "not given" for --namespace and --prefix.
    selected_modes = sum(map(bool, (namespace, prefix, apply_all)))

    if selected_modes > 1:
        console.print("[red]Error: Cannot use more than one of --namespace, --prefix, and --all[/red]")
        console.print(usage_hint)
        sys.exit(1)

    if selected_modes == 0:
        console.print("[red]Error: Must specify one of --namespace, --prefix, or --all[/red]")
        console.print(usage_hint)
        sys.exit(1)

    new_schema = load_schema_file(schema_file)

    if not new_schema:
        console.print("[yellow]Schema file is empty, nothing to apply[/yellow]")
        return

    # Single namespace mode
    if namespace:
        apply_schema_to_single_namespace(namespace, new_schema, region, dry_run, yes)
        return

    # Batch modes differ only in how targets are found and in their messages.
    if prefix:
        namespaces = list_namespaces_by_prefix(prefix, region)
        none_found = f"[yellow]No namespaces found matching prefix: {prefix}[/yellow]"
        found = f"[dim]Found {len(namespaces)} namespace(s) matching prefix '{prefix}'[/dim]"
    else:
        namespaces = list_all_namespaces(region)
        none_found = "[yellow]No namespaces found[/yellow]"
        found = f"[dim]Found {len(namespaces)} namespace(s)[/dim]"

    if not namespaces:
        console.print(none_found)
        return

    console.print(found)
    apply_schema_to_multiple_namespaces(namespaces, new_schema, region, dry_run, yes, continue_on_error)
namespaces: + console.print(f"[yellow]No namespaces found matching prefix: {prefix}[/yellow]") + return + + console.print(f"[dim]Found {len(namespaces)} namespace(s) matching prefix '{prefix}'[/dim]") + apply_schema_to_multiple_namespaces(namespaces, new_schema, region, dry_run, yes, continue_on_error) + else: + # All namespaces mode + namespaces = list_all_namespaces(region) + + if not namespaces: + console.print("[yellow]No namespaces found[/yellow]") + return + + console.print(f"[dim]Found {len(namespaces)} namespace(s)[/dim]") + apply_schema_to_multiple_namespaces(namespaces, new_schema, region, dry_run, yes, continue_on_error) + + +def get_namespace_row_count(ns) -> int | None: + """Get the row count for a namespace. + + Args: + ns: The turbopuffer namespace object + + Returns: + The number of rows, or None if namespace doesn't exist + """ + try: + metadata = ns.metadata() + return metadata.approx_count if hasattr(metadata, "approx_count") else 0 + except Exception: + return None + + +def display_schema_for_copy(schema_dict: dict[str, object], source: str, target: str) -> None: + """Display the schema that will be copied. 
def display_schema_for_copy(schema_dict: dict[str, object], source: str, target: str) -> None:
    """Display the schema that will be copied.

    Args:
        schema_dict: The schema dictionary
        source: Source namespace name
        target: Target namespace name
    """
    console.print(f"\n[bold]Copying schema from:[/bold] {source}")
    console.print(f"[bold]Creating namespace: [/bold] {target}")
    console.print("\n[bold]Schema:[/bold]")

    if schema_dict:
        # Alphabetical order keeps the listing stable between runs.
        for attr, attr_type in sorted(schema_dict.items()):
            console.print(f" {attr}: {schema_type_for_display(attr_type)}")
    else:
        console.print("[dim] (no schema attributes)[/dim]")

    console.print("\n[dim]Note: A placeholder row will be created to initialize the namespace.[/dim]\n")


@schema.command("copy", context_settings={"help_option_names": ["-h", "--help"]})
@click.option("-n", "--namespace", required=True, help="Source namespace to copy schema from")
@click.option("--to", "target", required=True, help="Target namespace name (must not exist or be empty)")
@click.option("-r", "--region", help="Override the region (e.g., aws-us-east-1, gcp-us-central1)")
@click.option("-y", "--yes", is_flag=True, help="Skip confirmation prompt")
@click.pass_context
def schema_copy(
    ctx: click.Context,
    namespace: str,
    target: str,
    region: str | None,
    yes: bool,
) -> None:
    """Copy schema from a source namespace to a new target namespace.

    The target namespace must be empty or non-existent. A placeholder row
    will be created to initialize the namespace with the schema.
    """
    # The source must yield a schema; None covers both "missing" and "empty".
    source_schema = get_current_schema(get_namespace(namespace, region))
    if source_schema is None:
        console.print(f"[red]Error: Source namespace '{namespace}' has no schema or does not exist[/red]")
        sys.exit(1)

    # Refuse to write into a target that already holds rows. A count of None
    # (metadata unavailable) is treated the same as an empty target.
    target_ns = get_namespace(target, region)
    target_row_count = get_namespace_row_count(target_ns)
    if target_row_count is not None and target_row_count > 0:
        console.print(f"[red]Error: Target namespace '{target}' already has {target_row_count} row(s)[/red]")
        console.print("[red]Target namespace must be empty or non-existent[/red]")
        sys.exit(1)

    display_schema_for_copy(source_schema, namespace, target)

    # Ask for confirmation unless --yes was given.
    if not yes and not click.confirm("Copy schema to target namespace?", default=False):
        console.print("[yellow]Aborted[/yellow]")
        return

    try:
        console.print(f"[dim]Creating namespace {target} with schema...[/dim]")
        # The write carries the schema plus one placeholder row.
        target_ns.write(
            upsert_rows=[{"id": "__schema_placeholder__"}],
            schema=source_schema,
        )
        console.print(f"[green]Successfully created namespace '{target}' with schema from '{namespace}'[/green]")
    except Exception as e:
        console.print(f"[red]Error creating namespace: {e}[/red]")
        sys.exit(1)
class TestNormalizeSchemaType:
    """Checks the string form normalize_schema_type produces for each input shape."""

    def test_simple_string_type(self):
        # Plain type names are expected to come back unchanged.
        for plain_type in ("string", "uint64"):
            assert normalize_schema_type(plain_type) == plain_type

    def test_dict_type(self):
        spec = {"type": "string", "full_text_search": True}
        # Dict specs are expected as JSON text with keys in sorted order.
        expected = '{"full_text_search": true, "type": "string"}'
        assert normalize_schema_type(spec) == expected

    def test_pydantic_model(self):
        # Stand-in for an object that exposes model_dump().
        model = MagicMock()
        model.model_dump.return_value = {"type": "string", "filterable": False}
        expected = '{"filterable": false, "type": "string"}'
        assert normalize_schema_type(model) == expected
class TestValidateSchemaType:
    """Checks the error list validate_schema_type returns for good and bad type specs."""

    def test_valid_simple_types(self):
        for candidate in ("string", "uint64", "uuid", "bool"):
            problems = validate_schema_type("test_attr", candidate)
            assert problems == [], f"Expected no errors for {candidate}"

    def test_valid_vector_types(self):
        # Two vector spellings with different dimensions and element types.
        assert validate_schema_type("vec", "[1536]f32") == []
        assert validate_schema_type("vec", "[768]f16") == []

    def test_invalid_simple_type(self):
        problems = validate_schema_type("test_attr", "invalid_type")
        assert len(problems) == 1
        message = problems[0]
        assert "invalid type" in message.lower()

    def test_valid_complex_type(self):
        spec = {"type": "string", "full_text_search": True}
        assert validate_schema_type("content", spec) == []

    def test_complex_type_missing_type_key(self):
        # A dict spec without "type" must be reported exactly once.
        problems = validate_schema_type("content", {"full_text_search": True})
        assert len(problems) == 1
        message = problems[0]
        assert "'type' key" in message

    def test_complex_type_unknown_keys(self):
        spec = {"type": "string", "unknown_key": True}
        problems = validate_schema_type("content", spec)
        assert len(problems) == 1
        message = problems[0]
        assert "unknown keys" in message.lower()
class TestDisplaySchemaDiff:
    """Tests for display_schema_diff function.

    Output is captured through a Rich console that is forced into terminal
    mode, then ANSI styling codes are stripped so assertions can match the
    literal text the user would read.
    """

    @staticmethod
    def _render(diff):
        """Run display_schema_diff for "test-ns" and return its output as plain text.

        Args:
            diff: The SchemaDiff to display.

        Returns:
            The captured console output with ANSI SGR escape sequences removed.
        """
        import re  # function-scope import, like the ClickException imports in this file

        capture = Console(file=StringIO(), force_terminal=True)
        # The command module prints through its module-level console, so swap
        # that object for the capturing one while the function runs.
        with patch("tpuff.commands.schema.console", capture):
            display_schema_diff(diff, "test-ns")
        # Remove colour/style sequences (ESC [ ... m) so that markers which
        # Rich splits across several styled segments become contiguous again.
        return re.sub(r"\x1b\[[0-9;]*m", "", capture.file.getvalue())

    def test_no_schema(self):
        output = self._render(SchemaDiff())
        assert "No schema attributes" in output

    def test_displays_additions_in_green(self):
        output = self._render(SchemaDiff(additions={"new_field": "string"}))
        assert "new_field" in output
        # Previously this asserted `"new" in output`, which the "new_field"
        # check above already guarantees. With styling stripped, the actual
        # "(new)" marker can be asserted directly.
        # NOTE(review): the test name mentions green, but the style used by
        # display_schema_diff is not visible here, so only text is checked.
        assert "(new)" in output

    def test_displays_conflicts_in_red(self):
        output = self._render(SchemaDiff(conflicts={"conflict_field": ("string", "uint64")}))
        assert "conflict_field" in output
        assert "type change not allowed" in output
+ @pytest.fixture + def runner(self): + return CliRunner() + + @pytest.fixture + def valid_schema_file(self, tmp_path): + schema_file = tmp_path / "schema.json" + schema_file.write_text('{"content": "string", "new_field": "uint64"}') + return str(schema_file) + + def test_dry_run_new_namespace(self, runner, valid_schema_file): + """Test dry-run on a namespace that doesn't exist (all additions).""" + mock_ns = MagicMock() + mock_ns.metadata.side_effect = Exception("Namespace not found") + + with patch("tpuff.commands.schema.get_namespace", return_value=mock_ns): + result = runner.invoke( + schema, ["apply", "-n", "test-ns", "-f", valid_schema_file, "--dry-run"] + ) + + assert result.exit_code == 0 + assert "test-ns" in result.output + assert "content" in result.output + assert "new_field" in result.output + assert "(new)" in result.output + assert "Dry run mode" in result.output + + def test_dry_run_existing_namespace(self, runner, valid_schema_file): + """Test dry-run on an existing namespace (mixed changes).""" + mock_metadata = MagicMock() + mock_metadata.schema = {"content": "string"} # existing field + mock_ns = MagicMock() + mock_ns.metadata.return_value = mock_metadata + + with patch("tpuff.commands.schema.get_namespace", return_value=mock_ns): + result = runner.invoke( + schema, ["apply", "-n", "test-ns", "-f", valid_schema_file, "--dry-run"] + ) + + assert result.exit_code == 0 + # content should be unchanged, new_field should be new + assert "content" in result.output + assert "new_field" in result.output + assert "(new)" in result.output + assert "Dry run mode" in result.output + + def test_dry_run_no_changes(self, runner, tmp_path): + """Test dry-run when schema is already up to date.""" + schema_file = tmp_path / "schema.json" + schema_file.write_text('{"content": "string"}') + + mock_metadata = MagicMock() + mock_metadata.schema = {"content": "string"} + mock_ns = MagicMock() + mock_ns.metadata.return_value = mock_metadata + + with 
patch("tpuff.commands.schema.get_namespace", return_value=mock_ns): + result = runner.invoke( + schema, ["apply", "-n", "test-ns", "-f", str(schema_file), "--dry-run"] + ) + + assert result.exit_code == 0 + assert "already up to date" in result.output + + def test_type_conflict_exits_with_error(self, runner, tmp_path): + """Test that type conflicts cause the command to exit with error.""" + schema_file = tmp_path / "schema.json" + schema_file.write_text('{"content": "uint64"}') # trying to change string to uint64 + + mock_metadata = MagicMock() + mock_metadata.schema = {"content": "string"} # existing type is string + mock_ns = MagicMock() + mock_ns.metadata.return_value = mock_metadata + + with patch("tpuff.commands.schema.get_namespace", return_value=mock_ns): + result = runner.invoke( + schema, ["apply", "-n", "test-ns", "-f", str(schema_file), "--dry-run"] + ) + + assert result.exit_code == 1 + assert "type change not allowed" in result.output.lower() + assert "conflict" in result.output.lower() + + def test_invalid_schema_file_exits_with_error(self, runner, tmp_path): + """Test that invalid schema files cause proper error messages.""" + schema_file = tmp_path / "schema.json" + schema_file.write_text('{"content": "invalid_type"}') + + result = runner.invoke( + schema, ["apply", "-n", "test-ns", "-f", str(schema_file), "--dry-run"] + ) + + assert result.exit_code != 0 + assert "invalid" in result.output.lower() + + def test_apply_with_yes_flag(self, runner, valid_schema_file): + """Test that --yes flag skips confirmation prompt.""" + mock_metadata = MagicMock() + mock_metadata.schema = {} + mock_ns = MagicMock() + mock_ns.metadata.return_value = mock_metadata + + with patch("tpuff.commands.schema.get_namespace", return_value=mock_ns): + result = runner.invoke( + schema, ["apply", "-n", "test-ns", "-f", valid_schema_file, "--yes"] + ) + + assert result.exit_code == 0 + assert "Successfully applied" in result.output + mock_ns.write.assert_called_once() + + def 
class TestSchemaCopyCommand:
    """Exercises the `schema copy` CLI command against mocked namespaces."""

    @pytest.fixture
    def runner(self):
        return CliRunner()

    @staticmethod
    def _existing(schema_fields, **metadata_attrs):
        """Build a namespace mock whose metadata() reports the given schema."""
        metadata = MagicMock()
        metadata.schema = schema_fields
        for attr_name, attr_value in metadata_attrs.items():
            setattr(metadata, attr_name, attr_value)
        namespace = MagicMock()
        namespace.metadata.return_value = metadata
        return namespace

    @staticmethod
    def _missing():
        """Build a namespace mock whose metadata() call raises."""
        namespace = MagicMock()
        namespace.metadata.side_effect = Exception("Namespace not found")
        return namespace

    @staticmethod
    def _copy(runner, source, target, *flags, answer=None):
        """Invoke `copy -n source-ns --to target-ns` with get_namespace patched.

        "source-ns" resolves to `source`; every other name resolves to `target`.
        """

        def lookup(name, region=None):
            return source if name == "source-ns" else target

        argv = ["copy", "-n", "source-ns", "--to", "target-ns", *flags]
        with patch("tpuff.commands.schema.get_namespace", side_effect=lookup):
            return runner.invoke(schema, argv, input=answer)

    def test_copy_source_not_found(self, runner):
        """Test error when source namespace doesn't exist."""
        # The same failing mock answers for both namespace names.
        unreachable = self._missing()
        result = self._copy(runner, unreachable, unreachable)

        assert result.exit_code == 1
        assert "no schema or does not exist" in result.output.lower()

    def test_copy_target_has_data(self, runner):
        """Test error when target namespace already has data."""
        source = self._existing({"content": "string"})
        populated_target = self._existing({"content": "string"}, approx_count=100)

        result = self._copy(runner, source, populated_target)

        assert result.exit_code == 1
        assert "already has" in result.output.lower()
        assert "100" in result.output

    def test_copy_success_with_yes(self, runner):
        """Test successful copy with --yes flag."""
        source = self._existing({"content": "string", "timestamp": "uint64"})
        target = self._missing()  # metadata raises: target does not exist yet

        result = self._copy(runner, source, target, "--yes")

        assert result.exit_code == 0
        assert "Successfully created" in result.output
        target.write.assert_called_once()
        sent = target.write.call_args[1]
        assert sent["schema"] == {"content": "string", "timestamp": "uint64"}
        assert sent["upsert_rows"] == [{"id": "__schema_placeholder__"}]

    def test_copy_aborts_on_no(self, runner):
        """Test that copy aborts when user says no."""
        source = self._existing({"content": "string"})
        target = self._missing()

        result = self._copy(runner, source, target, answer="n\n")

        assert result.exit_code == 0
        assert "Aborted" in result.output
        target.write.assert_not_called()

    def test_copy_displays_schema(self, runner):
        """Test that copy displays the schema being copied."""
        source = self._existing({"content": "string", "vector": "[1536]f32"})
        target = self._missing()

        result = self._copy(runner, source, target, answer="n\n")

        # Both namespace names and every attribute must appear in the preview.
        for expected in ("source-ns", "target-ns", "content", "vector"):
            assert expected in result.output
        assert "placeholder row" in result.output.lower()

    def test_copy_target_does_not_exist(self, runner):
        """Test copy when target namespace doesn't exist (metadata fails)."""
        source = self._existing({"content": "string"})
        target = self._missing()

        result = self._copy(runner, source, target, "--yes")

        assert result.exit_code == 0
        assert "Successfully created" in result.output
        target.write.assert_called_once()
--namespace cannot be used together.""" + result = runner.invoke( + schema, + ["apply", "-n", "test-ns", "--prefix", "prod", "-f", valid_schema_file], + ) + + assert result.exit_code == 1 + assert "Cannot use more than one" in result.output + + def test_requires_namespace_or_prefix_or_all(self, runner, valid_schema_file): + """Test that --namespace, --prefix, or --all is required.""" + result = runner.invoke( + schema, ["apply", "-f", valid_schema_file] + ) + + assert result.exit_code == 1 + assert "Must specify one of" in result.output + + def test_prefix_no_matching_namespaces(self, runner, valid_schema_file): + """Test that prefix with no matches shows appropriate message.""" + mock_client = MagicMock() + mock_client.namespaces.return_value = [] + + with patch("tpuff.commands.schema.get_turbopuffer_client", return_value=mock_client): + result = runner.invoke( + schema, ["apply", "--prefix", "nonexistent", "-f", valid_schema_file] + ) + + assert result.exit_code == 0 + assert "No namespaces found matching prefix" in result.output + + def test_prefix_dry_run_multiple_namespaces(self, runner, valid_schema_file): + """Test dry run with prefix showing summary of multiple namespaces.""" + # Mock namespaces + mock_ns1 = MagicMock() + mock_ns1.id = "prod-users" + mock_ns2 = MagicMock() + mock_ns2.id = "prod-orders" + mock_ns3 = MagicMock() + mock_ns3.id = "test-ns" # Won't match prefix + + mock_client = MagicMock() + mock_client.namespaces.return_value = [mock_ns1, mock_ns2, mock_ns3] + + # Mock metadata for each namespace + def mock_get_ns(name, region=None): + mock = MagicMock() + mock_metadata = MagicMock() + mock_metadata.schema = {"content": "string"} # Existing schema + mock.metadata.return_value = mock_metadata + return mock + + with patch("tpuff.commands.schema.get_turbopuffer_client", return_value=mock_client): + with patch("tpuff.commands.schema.get_namespace", side_effect=mock_get_ns): + result = runner.invoke( + schema, ["apply", "--prefix", "prod", "-f", 
valid_schema_file, "--dry-run"] + ) + + assert result.exit_code == 0 + assert "Found 2 namespace(s)" in result.output + assert "prod-users" in result.output + assert "prod-orders" in result.output + assert "test-ns" not in result.output + assert "Dry run mode" in result.output + + def test_prefix_apply_with_yes(self, runner, valid_schema_file): + """Test batch apply with --yes flag.""" + mock_ns1 = MagicMock() + mock_ns1.id = "prod-users" + mock_ns2 = MagicMock() + mock_ns2.id = "prod-orders" + + mock_client = MagicMock() + mock_client.namespaces.return_value = [mock_ns1, mock_ns2] + + namespace_mocks = {} + + def mock_get_ns(name, region=None): + if name not in namespace_mocks: + mock = MagicMock() + mock_metadata = MagicMock() + mock_metadata.schema = {} # Empty schema (new namespaces) + mock.metadata.return_value = mock_metadata + namespace_mocks[name] = mock + return namespace_mocks[name] + + with patch("tpuff.commands.schema.get_turbopuffer_client", return_value=mock_client): + with patch("tpuff.commands.schema.get_namespace", side_effect=mock_get_ns): + result = runner.invoke( + schema, ["apply", "--prefix", "prod", "-f", valid_schema_file, "--yes"] + ) + + assert result.exit_code == 0 + assert "Successfully applied schema to 2 namespace(s)" in result.output + # Verify both namespaces had write called + for mock in namespace_mocks.values(): + mock.write.assert_called_once() + + def test_prefix_conflicts_block_all(self, runner, tmp_path): + """Test that conflicts in any namespace prevent all applies.""" + schema_file = tmp_path / "schema.json" + schema_file.write_text('{"content": "uint64"}') # Change type + + mock_ns1 = MagicMock() + mock_ns1.id = "prod-users" + + mock_client = MagicMock() + mock_client.namespaces.return_value = [mock_ns1] + + def mock_get_ns(name, region=None): + mock = MagicMock() + mock_metadata = MagicMock() + mock_metadata.schema = {"content": "string"} # Existing type differs + mock.metadata.return_value = mock_metadata + return mock + 
class TestSchemaApplyAllCommand:
    """Exercises `schema apply` when the --all option selects the namespaces."""

    @pytest.fixture
    def runner(self):
        return CliRunner()

    @pytest.fixture
    def valid_schema_file(self, tmp_path):
        """Write a two-attribute schema into tmp_path and return its path as a string."""
        path = tmp_path / "schema.json"
        path.write_text('{"content": "string", "category": "string"}')
        return str(path)

    @staticmethod
    def _client_listing(*namespace_ids):
        """Build a client mock whose namespaces() returns entries with the given ids."""
        entries = []
        for namespace_id in namespace_ids:
            entry = MagicMock()
            entry.id = namespace_id
            entries.append(entry)
        client = MagicMock()
        client.namespaces.return_value = entries
        return client

    @staticmethod
    def _namespace_with_schema(current_schema):
        """Build a namespace mock whose metadata() reports `current_schema`."""
        metadata = MagicMock()
        metadata.schema = current_schema
        namespace = MagicMock()
        namespace.metadata.return_value = metadata
        return namespace

    def test_all_cannot_be_used_with_namespace(self, runner, valid_schema_file):
        """Test that --all and --namespace cannot be used together."""
        argv = ["apply", "-n", "test-ns", "--all", "-f", valid_schema_file]
        result = runner.invoke(schema, argv)

        assert result.exit_code == 1
        assert "Cannot use more than one" in result.output

    def test_all_cannot_be_used_with_prefix(self, runner, valid_schema_file):
        """Test that --all and --prefix cannot be used together."""
        argv = ["apply", "--prefix", "prod", "--all", "-f", valid_schema_file]
        result = runner.invoke(schema, argv)

        assert result.exit_code == 1
        assert "Cannot use more than one" in result.output

    def test_all_no_namespaces(self, runner, valid_schema_file):
        """Test that --all with no namespaces shows appropriate message."""
        empty_client = self._client_listing()

        with patch("tpuff.commands.schema.get_turbopuffer_client", return_value=empty_client):
            result = runner.invoke(schema, ["apply", "--all", "-f", valid_schema_file])

        assert result.exit_code == 0
        assert "No namespaces found" in result.output

    def test_all_dry_run_multiple_namespaces(self, runner, valid_schema_file):
        """Test dry run with --all showing summary of all namespaces."""
        listed_ids = ("prod-users", "test-ns", "dev-orders")
        client = self._client_listing(*listed_ids)

        def lookup(name, region=None):
            # Each lookup gets a fresh mock reporting one existing attribute.
            return self._namespace_with_schema({"content": "string"})

        with (
            patch("tpuff.commands.schema.get_turbopuffer_client", return_value=client),
            patch("tpuff.commands.schema.get_namespace", side_effect=lookup),
        ):
            result = runner.invoke(
                schema, ["apply", "--all", "-f", valid_schema_file, "--dry-run"]
            )

        assert result.exit_code == 0
        assert "Found 3 namespace(s)" in result.output
        for namespace_id in listed_ids:
            assert namespace_id in result.output
        assert "Dry run mode" in result.output

    def test_all_apply_with_yes(self, runner, valid_schema_file):
        """Test batch apply with --all and --yes flag."""
        client = self._client_listing("prod-users", "test-ns")

        # One mock per namespace name, created on first lookup and reused
        # afterwards so write() calls accumulate on the same object.
        namespace_mocks = {}

        def lookup(name, region=None):
            try:
                return namespace_mocks[name]
            except KeyError:
                namespace_mocks[name] = self._namespace_with_schema({})  # empty schema
                return namespace_mocks[name]

        with (
            patch("tpuff.commands.schema.get_turbopuffer_client", return_value=client),
            patch("tpuff.commands.schema.get_namespace", side_effect=lookup),
        ):
            result = runner.invoke(
                schema, ["apply", "--all", "-f", valid_schema_file, "--yes"]
            )

        assert result.exit_code == 0
        assert "Successfully applied schema to 2 namespace(s)" in result.output
        namespace_mocks["prod-users"].write.assert_called_once()
        namespace_mocks["test-ns"].write.assert_called_once()
schema_file.write_text('{"content": "uint64"}') # Type change + + mock_ns1 = MagicMock() + mock_ns1.id = "prod-users" + + mock_client = MagicMock() + mock_client.namespaces.return_value = [mock_ns1] + + def mock_get_ns(name, region=None): + mock = MagicMock() + mock_metadata = MagicMock() + mock_metadata.schema = {"content": "string"} # Existing type differs + mock.metadata.return_value = mock_metadata + return mock + + with patch("tpuff.commands.schema.get_turbopuffer_client", return_value=mock_client): + with patch("tpuff.commands.schema.get_namespace", side_effect=mock_get_ns): + result = runner.invoke( + schema, ["apply", "--prefix", "prod", "-f", str(schema_file), "--yes"] + ) + + assert result.exit_code == 1 + assert "type conflict" in result.output.lower() + assert "continue-on-error" in result.output.lower() + + def test_conflicts_with_continue_on_error_skips_conflicted(self, runner, tmp_path): + """Test that --continue-on-error skips namespaces with conflicts.""" + schema_file = tmp_path / "schema.json" + schema_file.write_text('{"content": "uint64", "new_field": "string"}') + + mock_ns1 = MagicMock() + mock_ns1.id = "prod-users" # Will have conflict + mock_ns2 = MagicMock() + mock_ns2.id = "prod-orders" # Will be applied + + mock_client = MagicMock() + mock_client.namespaces.return_value = [mock_ns1, mock_ns2] + + namespace_mocks = {} + + def mock_get_ns(name, region=None): + if name not in namespace_mocks: + mock = MagicMock() + mock_metadata = MagicMock() + if name == "prod-users": + # This namespace has a type conflict (content: string -> uint64) + mock_metadata.schema = {"content": "string"} + else: + # This namespace has no conflicts (new namespace) + mock_metadata.schema = {} + mock.metadata.return_value = mock_metadata + namespace_mocks[name] = mock + return namespace_mocks[name] + + with patch("tpuff.commands.schema.get_turbopuffer_client", return_value=mock_client): + with patch("tpuff.commands.schema.get_namespace", side_effect=mock_get_ns): + 
result = runner.invoke( + schema, ["apply", "--prefix", "prod", "-f", str(schema_file), "--yes", "--continue-on-error"] + ) + + assert result.exit_code == 0 + assert "Warning" in result.output # Warning about conflicts + assert "Successfully applied schema to 1 namespace(s)" in result.output + # Only prod-orders should have write called + namespace_mocks["prod-orders"].write.assert_called_once() + namespace_mocks["prod-users"].write.assert_not_called() + + def test_continue_on_error_with_all_flag(self, runner, tmp_path): + """Test --continue-on-error works with --all flag.""" + schema_file = tmp_path / "schema.json" + schema_file.write_text('{"content": "uint64"}') + + mock_ns1 = MagicMock() + mock_ns1.id = "ns-with-conflict" + mock_ns2 = MagicMock() + mock_ns2.id = "ns-new" + + mock_client = MagicMock() + mock_client.namespaces.return_value = [mock_ns1, mock_ns2] + + namespace_mocks = {} + + def mock_get_ns(name, region=None): + if name not in namespace_mocks: + mock = MagicMock() + mock_metadata = MagicMock() + if name == "ns-with-conflict": + mock_metadata.schema = {"content": "string"} # Conflict + else: + mock_metadata.schema = {} # No conflict + mock.metadata.return_value = mock_metadata + namespace_mocks[name] = mock + return namespace_mocks[name] + + with patch("tpuff.commands.schema.get_turbopuffer_client", return_value=mock_client): + with patch("tpuff.commands.schema.get_namespace", side_effect=mock_get_ns): + result = runner.invoke( + schema, ["apply", "--all", "-f", str(schema_file), "--yes", "--continue-on-error"] + ) + + assert result.exit_code == 0 + assert "Warning" in result.output + namespace_mocks["ns-new"].write.assert_called_once() + namespace_mocks["ns-with-conflict"].write.assert_not_called() + + def test_continue_on_error_all_have_conflicts(self, runner, tmp_path): + """Test --continue-on-error when all namespaces have conflicts.""" + schema_file = tmp_path / "schema.json" + schema_file.write_text('{"content": "uint64"}') + + mock_ns1 = 
MagicMock() + mock_ns1.id = "ns1" + mock_ns2 = MagicMock() + mock_ns2.id = "ns2" + + mock_client = MagicMock() + mock_client.namespaces.return_value = [mock_ns1, mock_ns2] + + def mock_get_ns(name, region=None): + mock = MagicMock() + mock_metadata = MagicMock() + mock_metadata.schema = {"content": "string"} # All have conflicts + mock.metadata.return_value = mock_metadata + return mock + + with patch("tpuff.commands.schema.get_turbopuffer_client", return_value=mock_client): + with patch("tpuff.commands.schema.get_namespace", side_effect=mock_get_ns): + result = runner.invoke( + schema, ["apply", "--prefix", "ns", "-f", str(schema_file), "--yes", "--continue-on-error"] + ) + + # Should exit successfully but with no updates + assert result.exit_code == 0 + assert "Warning" in result.output + assert "No namespaces need updates" in result.output + + def test_continue_on_error_ignored_for_single_namespace(self, runner, tmp_path): + """Test that --continue-on-error has no effect for single namespace mode.""" + schema_file = tmp_path / "schema.json" + schema_file.write_text('{"content": "uint64"}') + + mock_ns = MagicMock() + mock_metadata = MagicMock() + mock_metadata.schema = {"content": "string"} # Type conflict + mock_ns.metadata.return_value = mock_metadata + + with patch("tpuff.commands.schema.get_namespace", return_value=mock_ns): + result = runner.invoke( + schema, ["apply", "-n", "test-ns", "-f", str(schema_file), "--continue-on-error"] + ) + + # Single namespace mode should still fail on conflicts + assert result.exit_code == 1 + assert "type change not allowed" in result.output.lower() + + +class TestSchemaApplyBatchErrorHandling: + """Tests for error handling during batch schema apply operations.""" + + @pytest.fixture + def runner(self): + return CliRunner() + + @pytest.fixture + def valid_schema_file(self, tmp_path): + """Create a valid schema file for testing.""" + schema_file = tmp_path / "schema.json" + schema_file.write_text('{"content": "string", 
"category": "string"}') + return str(schema_file) + + def test_write_error_during_apply_reports_failure(self, runner, valid_schema_file): + """Test that write errors during apply are reported as failures.""" + mock_ns1 = MagicMock() + mock_ns1.id = "prod-users" + mock_ns2 = MagicMock() + mock_ns2.id = "prod-orders" + + mock_client = MagicMock() + mock_client.namespaces.return_value = [mock_ns1, mock_ns2] + + namespace_mocks = {} + + def mock_get_ns(name, region=None): + if name not in namespace_mocks: + mock = MagicMock() + mock_metadata = MagicMock() + mock_metadata.schema = {} # Empty schema (new namespace) + mock.metadata.return_value = mock_metadata + # prod-users write will fail + if name == "prod-users": + mock.write.side_effect = Exception("API error: connection failed") + namespace_mocks[name] = mock + return namespace_mocks[name] + + with patch("tpuff.commands.schema.get_turbopuffer_client", return_value=mock_client): + with patch("tpuff.commands.schema.get_namespace", side_effect=mock_get_ns): + result = runner.invoke( + schema, ["apply", "--prefix", "prod", "-f", valid_schema_file, "--yes"] + ) + + # Should exit with failure code + assert result.exit_code == 1 + assert "1 failed" in result.output.lower() + # prod-orders should still succeed + namespace_mocks["prod-orders"].write.assert_called_once() + + def test_get_namespace_error_during_analysis_shows_error(self, runner, valid_schema_file): + """Test that get_namespace errors during analysis are shown in summary.""" + mock_ns1 = MagicMock() + mock_ns1.id = "prod-users" + mock_ns2 = MagicMock() + mock_ns2.id = "prod-orders" + + mock_client = MagicMock() + mock_client.namespaces.return_value = [mock_ns1, mock_ns2] + + namespace_mocks = {} + + def mock_get_ns(name, region=None): + # prod-users get_namespace will fail entirely + if name == "prod-users": + raise Exception("Network timeout") + if name not in namespace_mocks: + mock = MagicMock() + mock_metadata = MagicMock() + mock_metadata.schema = {} + 
mock.metadata.return_value = mock_metadata + namespace_mocks[name] = mock + return namespace_mocks[name] + + with patch("tpuff.commands.schema.get_turbopuffer_client", return_value=mock_client): + with patch("tpuff.commands.schema.get_namespace", side_effect=mock_get_ns): + result = runner.invoke( + schema, ["apply", "--prefix", "prod", "-f", valid_schema_file, "--yes"] + ) + + assert result.exit_code == 0 + assert "error" in result.output.lower() + assert "Network timeout" in result.output + # prod-orders should still be applied + namespace_mocks["prod-orders"].write.assert_called_once() + + def test_all_namespaces_fail_to_write(self, runner, valid_schema_file): + """Test behavior when all namespaces fail during write.""" + mock_ns1 = MagicMock() + mock_ns1.id = "ns1" + mock_ns2 = MagicMock() + mock_ns2.id = "ns2" + + mock_client = MagicMock() + mock_client.namespaces.return_value = [mock_ns1, mock_ns2] + + namespace_mocks = {} + + def mock_get_ns(name, region=None): + if name not in namespace_mocks: + mock = MagicMock() + mock_metadata = MagicMock() + mock_metadata.schema = {} # Empty schema + mock.metadata.return_value = mock_metadata + mock.write.side_effect = Exception("API unavailable") + namespace_mocks[name] = mock + return namespace_mocks[name] + + with patch("tpuff.commands.schema.get_turbopuffer_client", return_value=mock_client): + with patch("tpuff.commands.schema.get_namespace", side_effect=mock_get_ns): + result = runner.invoke( + schema, ["apply", "--all", "-f", valid_schema_file, "--yes"] + ) + + assert result.exit_code == 1 + assert "2 failed" in result.output + + def test_partial_success_reports_correctly(self, runner, valid_schema_file): + """Test that partial success is reported with correct counts.""" + mock_ns1 = MagicMock() + mock_ns1.id = "ns-success-1" + mock_ns2 = MagicMock() + mock_ns2.id = "ns-fail" + mock_ns3 = MagicMock() + mock_ns3.id = "ns-success-2" + + mock_client = MagicMock() + mock_client.namespaces.return_value = [mock_ns1, 
mock_ns2, mock_ns3] + + namespace_mocks = {} + + def mock_get_ns(name, region=None): + if name not in namespace_mocks: + mock = MagicMock() + mock_metadata = MagicMock() + mock_metadata.schema = {} + mock.metadata.return_value = mock_metadata + # Only ns-fail will fail + if name == "ns-fail": + mock.write.side_effect = Exception("Write failed") + namespace_mocks[name] = mock + return namespace_mocks[name] + + with patch("tpuff.commands.schema.get_turbopuffer_client", return_value=mock_client): + with patch("tpuff.commands.schema.get_namespace", side_effect=mock_get_ns): + result = runner.invoke( + schema, ["apply", "--all", "-f", valid_schema_file, "--yes"] + ) + + assert result.exit_code == 1 + assert "2 namespace(s)" in result.output # 2 successful + assert "1 failed" in result.output + + def test_error_in_dry_run_still_shows_status(self, runner, valid_schema_file): + """Test that errors during analysis appear in dry run output.""" + mock_ns1 = MagicMock() + mock_ns1.id = "ns-ok" + mock_ns2 = MagicMock() + mock_ns2.id = "ns-error" + + mock_client = MagicMock() + mock_client.namespaces.return_value = [mock_ns1, mock_ns2] + + def mock_get_ns(name, region=None): + # ns-error get_namespace fails entirely + if name == "ns-error": + raise Exception("Access denied") + mock = MagicMock() + mock_metadata = MagicMock() + mock_metadata.schema = {} + mock.metadata.return_value = mock_metadata + return mock + + with patch("tpuff.commands.schema.get_turbopuffer_client", return_value=mock_client): + with patch("tpuff.commands.schema.get_namespace", side_effect=mock_get_ns): + result = runner.invoke( + schema, ["apply", "--all", "-f", valid_schema_file, "--dry-run"] + ) + + assert result.exit_code == 0 + assert "Access denied" in result.output + assert "Dry run mode" in result.output + + def test_continue_on_error_with_mixed_errors(self, runner, tmp_path): + """Test --continue-on-error with both conflicts and write errors.""" + schema_file = tmp_path / "schema.json" + 
schema_file.write_text('{"content": "uint64", "new_field": "string"}') + + mock_ns1 = MagicMock() + mock_ns1.id = "ns-conflict" # Type conflict + mock_ns2 = MagicMock() + mock_ns2.id = "ns-write-error" # Write will fail + mock_ns3 = MagicMock() + mock_ns3.id = "ns-success" # Will succeed + + mock_client = MagicMock() + mock_client.namespaces.return_value = [mock_ns1, mock_ns2, mock_ns3] + + namespace_mocks = {} + + def mock_get_ns(name, region=None): + if name not in namespace_mocks: + mock = MagicMock() + mock_metadata = MagicMock() + if name == "ns-conflict": + # Has conflicting type + mock_metadata.schema = {"content": "string"} + else: + # No existing schema + mock_metadata.schema = {} + mock.metadata.return_value = mock_metadata + if name == "ns-write-error": + mock.write.side_effect = Exception("Write failed") + namespace_mocks[name] = mock + return namespace_mocks[name] + + with patch("tpuff.commands.schema.get_turbopuffer_client", return_value=mock_client): + with patch("tpuff.commands.schema.get_namespace", side_effect=mock_get_ns): + result = runner.invoke( + schema, + ["apply", "--all", "-f", str(schema_file), "--yes", "--continue-on-error"] + ) + + # Should fail because of write error + assert result.exit_code == 1 + assert "Warning" in result.output # Warning about conflicts + assert "1 failed" in result.output # Write failure + # ns-conflict should NOT have write called (skipped due to conflict) + namespace_mocks["ns-conflict"].write.assert_not_called() + # ns-success should succeed + namespace_mocks["ns-success"].write.assert_called_once()