Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion src/tpuff/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from tpuff.commands.list import list_cmd
from tpuff.commands.schema import schema
from tpuff.commands.search import search
from tpuff.utils.output import resolve_output_mode

# Context settings to enable -h as help alias for all commands
CONTEXT_SETTINGS = {"help_option_names": ["-h", "--help"]}
Expand All @@ -18,11 +19,19 @@
@click.group(context_settings=CONTEXT_SETTINGS)
@click.version_option(version=__version__, prog_name="tpuff")
@click.option("--debug", is_flag=True, help="Enable debug output")
@click.option(
"-o",
"--output",
type=click.Choice(["human", "plain"]),
default=None,
help="Output format: human (rich tables) or plain (pipe-delimited). Auto-detects TTY if omitted.",
)
@click.pass_context
def cli(ctx: click.Context, debug: bool) -> None:
def cli(ctx: click.Context, debug: bool, output: str | None) -> None:
"""tpuff - CLI tool for Turbopuffer vector database."""
ctx.ensure_object(dict)
ctx.obj["debug"] = debug
ctx.obj["output_mode"] = resolve_output_mode(output)


# Register commands
Expand Down
25 changes: 19 additions & 6 deletions src/tpuff/commands/get.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from tpuff.client import get_namespace
from tpuff.utils.debug import debug_log
from tpuff.utils.output import is_plain, status_print

console = Console()

Expand All @@ -24,8 +25,14 @@ def get(
region: str | None,
) -> None:
"""Get a document by ID from a namespace."""
plain = is_plain(ctx)

try:
console.print(f"\n[bold]Querying document with ID: {id} from namespace: {namespace}[/bold]\n")
status_print(
ctx,
f"\n[bold]Querying document with ID: {id} from namespace: {namespace}[/bold]\n",
console,
)

# Get namespace reference
ns = get_namespace(namespace, region)
Expand Down Expand Up @@ -64,14 +71,20 @@ def get(
else:
doc_dict = {"id": getattr(doc, "id", "N/A")}

# Display document
console.print("[cyan]Document:[/cyan]")
console.print(json.dumps(doc_dict, indent=2, default=str))
if plain:
# Plain mode: raw JSON only
click.echo(json.dumps(doc_dict, default=str))
else:
# Display document
console.print("[cyan]Document:[/cyan]")
console.print(json.dumps(doc_dict, indent=2, default=str))

# Show performance info
if hasattr(result, "performance") and result.performance:
console.print(
f"\n[dim]Query took {result.performance.query_execution_ms:.2f}ms[/dim]"
status_print(
ctx,
f"\n[dim]Query took {result.performance.query_execution_ms:.2f}ms[/dim]",
console,
)

except Exception as e:
Expand Down
195 changes: 121 additions & 74 deletions src/tpuff/commands/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
get_index_status,
get_unindexed_bytes,
)
from tpuff.utils.output import is_plain, print_table_plain, status_print

console = Console()

Expand All @@ -33,10 +34,10 @@ def format_bytes(bytes_count: int) -> str:
return f"{bytes_count:.2f} {sizes[i]}"


def format_updated_at(timestamp: str | datetime | None) -> str:
def format_updated_at(timestamp: str | datetime | None, plain: bool = False) -> str:
"""Format timestamp smartly: time if today, date otherwise."""
if timestamp is None:
return "[dim]N/A[/dim]"
return "N/A" if plain else "[dim]N/A[/dim]"

try:
# Handle datetime objects directly
Expand All @@ -58,17 +59,22 @@ def format_updated_at(timestamp: str | datetime | None) -> str:
else:
return date.strftime("%b %-d, %Y")
except Exception:
return str(timestamp) if timestamp else "[dim]N/A[/dim]"
if timestamp:
return str(timestamp)
return "N/A" if plain else "[dim]N/A[/dim]"


def format_recall(recall_data) -> str:
def format_recall(recall_data, plain: bool = False) -> str:
"""Format recall as a color-coded percentage."""
if not recall_data:
return "[dim]N/A[/dim]"
return "N/A" if plain else "[dim]N/A[/dim]"

percentage = recall_data.avg_recall * 100
display_value = f"{percentage:.1f}%"

if plain:
return display_value

if recall_data.avg_recall > 0.95:
return f"[green]{display_value}[/green]"
elif recall_data.avg_recall > 0.8:
Expand Down Expand Up @@ -104,12 +110,17 @@ def extract_vector_info(schema: dict) -> dict | None:


def display_namespace_documents(
namespace: str, top_k: int, region: str | None = None
ctx: click.Context, namespace: str, top_k: int, region: str | None = None
) -> None:
"""List documents in a specific namespace."""
plain = is_plain(ctx)
ns = get_namespace(namespace, region)

console.print(f"\n[bold]Querying namespace: {namespace} (top {top_k} results)[/bold]\n")
status_print(
ctx,
f"\n[bold]Querying namespace: {namespace} (top {top_k} results)[/bold]\n",
console,
)

# Get namespace metadata to extract schema
metadata = ns.metadata()
Expand All @@ -124,8 +135,10 @@ def display_namespace_documents(
console.print("[red]Error: No vector attribute found in namespace schema[/red]")
sys.exit(1)

console.print(
f"[dim]Using {vector_info['dimensions']}-dimensional zero vector for query[/dim]\n"
status_print(
ctx,
f"[dim]Using {vector_info['dimensions']}-dimensional zero vector for query[/dim]\n",
console,
)

# Create zero vector
Expand Down Expand Up @@ -153,14 +166,10 @@ def display_namespace_documents(
console.print("No documents found in namespace")
return

console.print(f"[bold]Found {len(rows)} document(s):[/bold]\n")
status_print(ctx, f"[bold]Found {len(rows)} document(s):[/bold]\n", console)

# Create table for results
table = Table(show_header=True, header_style="cyan")
table.add_column("ID")
table.add_column("Contents")

# Add rows to table
# Collect row data
table_rows = []
for row in rows:
# Get the row as a dict using model_dump() or fallback
if hasattr(row, "model_dump"):
Expand All @@ -179,30 +188,36 @@ def display_namespace_documents(
if key not in exclude_keys and not key.startswith("_"):
contents[key] = value

# Stringify and truncate contents
contents_str = json.dumps(contents, default=str)
max_length = 80
display_contents = (
contents_str[:max_length] + "..." if len(contents_str) > max_length else contents_str
)
table_rows.append([str(row_id), contents_str])

table.add_row(str(row_id), display_contents)

console.print(table)
if plain:
print_table_plain(["ID", "Contents"], table_rows)
else:
table = Table(show_header=True, header_style="cyan")
table.add_column("ID")
table.add_column("Contents")
for r in table_rows:
table.add_row(*r)
console.print(table)

# Show performance info if available
if hasattr(result, "performance") and result.performance:
console.print(
f"\n[dim]Query took {result.performance.query_execution_ms:.2f}ms[/dim]"
status_print(
ctx,
f"\n[dim]Query took {result.performance.query_execution_ms:.2f}ms[/dim]",
console,
)


def display_namespaces(
ctx: click.Context,
all_regions: bool = False,
region: str | None = None,
include_recall: bool = False,
) -> None:
"""List all namespaces."""
plain = is_plain(ctx)
namespaces_with_metadata = fetch_namespaces_with_metadata(
all_regions=all_regions,
region=region,
Expand All @@ -213,7 +228,11 @@ def display_namespaces(
console.print("No namespaces found")
return

console.print(f"\n[bold]Found {len(namespaces_with_metadata)} namespace(s):[/bold]\n")
status_print(
ctx,
f"\n[bold]Found {len(namespaces_with_metadata)} namespace(s):[/bold]\n",
console,
)

# Sort by updated_at in descending order (most recent first)
def sort_key(item: NamespaceWithMetadata):
Expand All @@ -229,62 +248,90 @@ def sort_key(item: NamespaceWithMetadata):

namespaces_with_metadata.sort(key=sort_key, reverse=True)

# Create table with conditional region and recall columns
table = Table(show_header=True, header_style="cyan")
table.add_column("Namespace")
# Build headers
headers = ["Namespace"]
if all_regions:
table.add_column("Region")
table.add_column("Rows")
table.add_column("Logical Bytes")
table.add_column("Index Status")
table.add_column("Unindexed Bytes")
headers.append("Region")
headers.extend(["Rows", "Logical Bytes", "Index Status", "Unindexed Bytes"])
if include_recall:
table.add_column("Recall")
table.add_column("Updated")
headers.append("Recall")
headers.append("Updated")

# Add rows to table
# Collect row data
table_rows = []
for item in namespaces_with_metadata:
if item.metadata:
index_status = get_index_status(item.metadata)
index_status_display = (
"[green]up-to-date[/green]"
if index_status == "up-to-date"
else "[red]updating[/red]"
)

unindexed = get_unindexed_bytes(item.metadata)
unindexed_display = (
f"[red]{format_bytes(unindexed)}[/red]"
if unindexed > 0
else format_bytes(0)
)

row = [f"[bold]{item.namespace_id}[/bold]"]
if all_regions and item.region:
row.append(f"[dim]{item.region}[/dim]")
row.extend([
f"{item.metadata.approx_row_count:,}",
format_bytes(item.metadata.approx_logical_bytes),
index_status_display,
unindexed_display,
])
if include_recall:
row.append(format_recall(item.recall))
row.append(format_updated_at(item.metadata.updated_at))

table.add_row(*row)

if plain:
row = [item.namespace_id]
if all_regions:
row.append(item.region or "")
row.extend([
f"{item.metadata.approx_row_count:,}",
format_bytes(item.metadata.approx_logical_bytes),
index_status,
format_bytes(unindexed),
])
if include_recall:
row.append(format_recall(item.recall, plain=True))
row.append(format_updated_at(item.metadata.updated_at, plain=True))
else:
index_status_display = (
"[green]up-to-date[/green]"
if index_status == "up-to-date"
else "[red]updating[/red]"
)
unindexed_display = (
f"[red]{format_bytes(unindexed)}[/red]"
if unindexed > 0
else format_bytes(0)
)

row = [f"[bold]{item.namespace_id}[/bold]"]
if all_regions and item.region:
row.append(f"[dim]{item.region}[/dim]")
row.extend([
f"{item.metadata.approx_row_count:,}",
format_bytes(item.metadata.approx_logical_bytes),
index_status_display,
unindexed_display,
])
if include_recall:
row.append(format_recall(item.recall))
row.append(format_updated_at(item.metadata.updated_at))

table_rows.append(row)
else:
row = [f"[bold]{item.namespace_id}[/bold]"]
if all_regions and item.region:
row.append(f"[dim]{item.region}[/dim]")
row.extend(["[dim]N/A[/dim]"] * 4)
if include_recall:
if plain:
row = [item.namespace_id]
if all_regions:
row.append(item.region or "")
row.extend(["N/A"] * 4)
if include_recall:
row.append("N/A")
row.append("N/A")
else:
row = [f"[bold]{item.namespace_id}[/bold]"]
if all_regions and item.region:
row.append(f"[dim]{item.region}[/dim]")
row.extend(["[dim]N/A[/dim]"] * 4)
if include_recall:
row.append("[dim]N/A[/dim]")
row.append("[dim]N/A[/dim]")
row.append("[dim]N/A[/dim]")

table.add_row(*row)
table_rows.append(row)

console.print(table)
if plain:
print_table_plain(headers, table_rows)
else:
table = Table(show_header=True, header_style="cyan")
for h in headers:
table.add_column(h)
for r in table_rows:
table.add_row(*r)
console.print(table)


@click.command("list", context_settings={"help_option_names": ["-h", "--help"]})
Expand Down Expand Up @@ -327,10 +374,10 @@ def list_cmd(
)
sys.exit(1)

display_namespace_documents(namespace, top_k, region)
display_namespace_documents(ctx, namespace, top_k, region)
else:
# List all namespaces
display_namespaces(all_regions, region, include_recall)
display_namespaces(ctx, all_regions, region, include_recall)
except Exception as e:
console.print(f"[red]Error: {e}[/red]")
sys.exit(1)
Loading
Loading