Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
273 changes: 269 additions & 4 deletions recce/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,13 +287,21 @@ def version():
@cli.command(cls=TrackCommand)
@add_options(dbt_related_options)
@add_options(recce_dbt_artifact_dir_options)
@add_options(recce_cloud_options)
@add_options(recce_cloud_auth_options)
@click.option(
"--cache-db",
help="Path to the column-level lineage cache database.",
type=click.Path(),
default=None,
show_default=False,
)
@click.option(
"--session-id",
help="Recce Cloud session ID (for --cloud mode).",
type=click.STRING,
envvar="RECCE_SESSION_ID",
)
def init(cache_db, **kwargs):
"""
Pre-compute column-level lineage cache from dbt artifacts.
Expand All @@ -302,12 +310,18 @@ def init(cache_db, **kwargs):
(~/.recce/cll_cache.db by default) so that subsequent `recce server`
sessions start with a warm cache.

With --cloud, downloads artifacts from Recce Cloud, computes CLL, and
uploads the cache and CLL map back to the session's S3 bucket.

Works with one or both environments (target/ and/or target-base/).
"""

import json
import logging
import tempfile
import time

import requests
from rich.console import Console
from rich.progress import Progress

Expand All @@ -319,6 +333,159 @@ def init(cache_db, **kwargs):
console = Console()
console.rule("Recce Init — Building column-level lineage cache", style="orange3")

# Timeouts for HTTP requests (seconds): short for metadata, long for large files
_METADATA_TIMEOUT = 30
_DOWNLOAD_TIMEOUT = 300
_UPLOAD_TIMEOUT = 600

is_cloud = kwargs.get("cloud", False)
session_id = kwargs.get("session_id")
cloud_client = None
cloud_org_id = None
cloud_project_id = None

if is_cloud:
from recce.util.recce_cloud import RecceCloud, RecceCloudException

cloud_token = kwargs.get("cloud_token") or kwargs.get("api_token")
if not cloud_token:
console.print("[[red]Error[/red]] --cloud requires --cloud-token or --api-token (or GITHUB_TOKEN env var).")
exit(1)
if not session_id:
console.print("[[red]Error[/red]] --cloud requires --session-id (or RECCE_SESSION_ID env var).")
exit(1)

cloud_client = RecceCloud(token=cloud_token)
if kwargs.get("state_file_host"):
host = kwargs["state_file_host"]
cloud_client.base_url = f"{host}/api/v1"
cloud_client.base_url_v2 = f"{host}/api/v2"

console.print(f"[bold]Cloud mode[/bold]: session {session_id}")

# Get session info
try:
session_info = cloud_client.get_session(session_id)
except RecceCloudException as e:
console.print(f"[[red]Error[/red]] Failed to get session: {e}")
exit(1)
if session_info.get("status") == "error":
console.print(f"[[red]Error[/red]] Failed to get session: {session_info.get('message', 'Access denied')}")
exit(1)
cloud_org_id = session_info.get("org_id")
cloud_project_id = session_info.get("project_id")
if not cloud_org_id or not cloud_project_id:
console.print(f"[[red]Error[/red]] Session {session_id} missing org_id or project_id.")
exit(1)

# Download artifacts to local target directories
console.print("Downloading artifacts from Cloud...")
try:
download_urls = cloud_client.get_download_urls_by_session_id(cloud_org_id, cloud_project_id, session_id)
except RecceCloudException as e:
console.print(f"[[red]Error[/red]] Failed to get download URLs: {e}")
exit(1)

project_dir_path = Path(kwargs.get("project_dir") or "./")
target_path = project_dir_path / kwargs.get("target_path", "target")
target_base_path = project_dir_path / kwargs.get("target_base_path", "target-base")
target_path.mkdir(parents=True, exist_ok=True)
target_base_path.mkdir(parents=True, exist_ok=True)

# Download current session artifacts
for artifact_key, filename in [("manifest_url", "manifest.json"), ("catalog_url", "catalog.json")]:
url = download_urls.get(artifact_key)
if url:
try:
resp = requests.get(url, timeout=_METADATA_TIMEOUT)
if resp.status_code == 200:
(target_path / filename).write_bytes(resp.content)
console.print(f" Downloaded {filename} to {target_path}")
else:
console.print(
f" [[yellow]Warning[/yellow]] Failed to download {filename}: HTTP {resp.status_code}"
)
except requests.RequestException as e:
console.print(f" [[yellow]Warning[/yellow]] Failed to download {filename}: {e}")

# Download base session artifacts
try:
base_download_urls = cloud_client.get_base_session_download_urls(
cloud_org_id, cloud_project_id, session_id=session_id
)
except RecceCloudException as e:
console.print(f" [[yellow]Warning[/yellow]] Failed to get base session URLs: {e}")
base_download_urls = {}
for artifact_key, filename in [("manifest_url", "manifest.json"), ("catalog_url", "catalog.json")]:
url = base_download_urls.get(artifact_key)
if url:
try:
resp = requests.get(url, timeout=_METADATA_TIMEOUT)
if resp.status_code == 200:
(target_base_path / filename).write_bytes(resp.content)
console.print(f" Downloaded base {filename} to {target_base_path}")
else:
console.print(
f" [[yellow]Warning[/yellow]] Failed to download base {filename}: HTTP {resp.status_code}"
)
except requests.RequestException as e:
console.print(f" [[yellow]Warning[/yellow]] Failed to download base {filename}: {e}")

# Download existing CLL cache for warm start.
# Try current session first, then fall back to production (base) session.
# Use streaming to avoid loading large cache files entirely into memory.
if cache_db is None:
cache_db = _DEFAULT_DB_PATH
Path(cache_db).parent.mkdir(parents=True, exist_ok=True)

def _stream_download_to_file(url: str, dest: Path) -> int:
"""Stream a URL to a file, returning bytes written. Raises on failure."""
resp = requests.get(url, stream=True, timeout=_DOWNLOAD_TIMEOUT)
if resp.status_code != 200:
return 0
total = 0
with tempfile.NamedTemporaryFile(dir=dest.parent, delete=False, suffix=".tmp") as tmp:
tmp_path = Path(tmp.name)
try:
for chunk in resp.iter_content(chunk_size=8192):
tmp.write(chunk)
total += len(chunk)
tmp.flush()
except Exception:
tmp_path.unlink(missing_ok=True)
raise
if total > 0:
tmp_path.rename(dest)
else:
tmp_path.unlink(missing_ok=True)
return total

cache_downloaded = False
cll_cache_url = download_urls.get("cll_cache_url")
if cll_cache_url:
try:
nbytes = _stream_download_to_file(cll_cache_url, Path(cache_db))
if nbytes > 0:
console.print(f" Downloaded CLL cache from session ({nbytes / 1024 / 1024:.1f} MB)")
cache_downloaded = True
except requests.RequestException as e:
console.print(f" [[yellow]Warning[/yellow]] Failed to download CLL cache: {e}")

if not cache_downloaded:
# Fall back to production (base) session cache
base_cache_url = base_download_urls.get("cll_cache_url")
if base_cache_url:
try:
nbytes = _stream_download_to_file(base_cache_url, Path(cache_db))
if nbytes > 0:
console.print(f" Downloaded CLL cache from base session ({nbytes / 1024 / 1024:.1f} MB)")
cache_downloaded = True
except requests.RequestException as e:
console.print(f" [[yellow]Warning[/yellow]] Failed to download base CLL cache: {e}")

if not cache_downloaded:
console.print(" [dim]No existing CLL cache found — will compute from scratch[/dim]")

if cache_db is None:
cache_db = _DEFAULT_DB_PATH

Expand All @@ -331,9 +498,10 @@ def init(cache_db, **kwargs):
console.print(f"Evicted {evicted} stale cache entries (>7 days unused)")

# Check which artifact directories exist
project_dir_path = Path(kwargs.get("project_dir") or "./")
target_path = project_dir_path / kwargs.get("target_path", "target")
target_base_path = project_dir_path / kwargs.get("target_base_path", "target-base")
if not is_cloud:
project_dir_path = Path(kwargs.get("project_dir") or "./")
target_path = project_dir_path / kwargs.get("target_path", "target")
target_base_path = project_dir_path / kwargs.get("target_base_path", "target-base")

has_target = (target_path / "manifest.json").is_file()
has_base = (target_base_path / "manifest.json").is_file()
Expand Down Expand Up @@ -471,9 +639,106 @@ def init(cache_db, **kwargs):
if fail > 3:
console.print(f" [dim]... and {fail - 3} more skipped (see logs for details)[/dim]")

# Build and save the full CLL map as JSON.
# The per-node SQLite cache is warm from the loop above, so this is fast.
console.print("\n[bold]Building full CLL map...[/bold]")
t_map_start = time.perf_counter()
cll_map_path = Path(cache_db).parent / "cll_map.json"
try:
full_cll_map = dbt_adapter.build_full_cll_map()
cll_map_data = full_cll_map.model_dump(mode="json")
# Write to temp file first to avoid corrupted JSON on partial write
tmp_fd, tmp_name = tempfile.mkstemp(dir=cll_map_path.parent, suffix=".tmp")
try:
with os.fdopen(tmp_fd, "w") as f:
json.dump(cll_map_data, f)
Path(tmp_name).rename(cll_map_path)
except Exception:
Path(tmp_name).unlink(missing_ok=True)
raise
map_elapsed = time.perf_counter() - t_map_start
map_size_mb = cll_map_path.stat().st_size / 1024 / 1024
console.print(
f" CLL map saved to [bold]{cll_map_path}[/bold] "
f"({len(full_cll_map.nodes)} nodes, {len(full_cll_map.columns)} columns, "
f"{map_size_mb:.1f} MB, {map_elapsed:.1f}s)"
)
except Exception as e:
logger.warning("[recce init] Failed to build CLL map: %s", e)
console.print(f" [[yellow]Warning[/yellow]] Failed to build CLL map: {e}")

stats = cache.stats
console.print(f"\nCache saved to [bold]{cache_db}[/bold] ({stats['entries']} entries)")
console.print("Run [bold]recce server --enable-cll-cache[/bold] to use the cached lineage.")

# Upload results to Cloud if in cloud mode
if is_cloud and cloud_client:
console.print("\n[bold]Uploading results to Cloud...[/bold]")
upload_failures = []
try:
upload_urls = cloud_client.get_upload_urls_by_session_id(cloud_org_id, cloud_project_id, session_id)

# Upload CLL map
cll_map_upload_url = upload_urls.get("cll_map_url")
if cll_map_upload_url and cll_map_path.is_file():
try:
with open(cll_map_path, "rb") as f:
resp = requests.put(
cll_map_upload_url,
data=f,
headers={"Content-Type": "application/json"},
timeout=_UPLOAD_TIMEOUT,
)
if resp.status_code in (200, 204):
console.print(f" Uploaded cll_map.json ({cll_map_path.stat().st_size / 1024 / 1024:.1f} MB)")
else:
upload_failures.append("cll_map.json")
console.print(
f" [[yellow]Warning[/yellow]] Failed to upload cll_map.json: HTTP {resp.status_code}"
)
except requests.RequestException as e:
upload_failures.append("cll_map.json")
console.print(f" [[yellow]Warning[/yellow]] Failed to upload cll_map.json: {e}")
elif not cll_map_upload_url:
console.print(
" [[yellow]Warning[/yellow]] No cll_map_url in upload URLs (Cloud server may need update)"
)

# Upload CLL cache
cll_cache_upload_url = upload_urls.get("cll_cache_url")
if cll_cache_upload_url and Path(cache_db).is_file():
try:
with open(cache_db, "rb") as f:
resp = requests.put(
cll_cache_upload_url,
data=f,
headers={"Content-Type": "application/octet-stream"},
timeout=_UPLOAD_TIMEOUT,
)
if resp.status_code in (200, 204):
console.print(f" Uploaded cll_cache.db ({Path(cache_db).stat().st_size / 1024 / 1024:.1f} MB)")
else:
upload_failures.append("cll_cache.db")
console.print(
f" [[yellow]Warning[/yellow]] Failed to upload cll_cache.db: HTTP {resp.status_code}"
)
except requests.RequestException as e:
upload_failures.append("cll_cache.db")
console.print(f" [[yellow]Warning[/yellow]] Failed to upload cll_cache.db: {e}")
elif not cll_cache_upload_url:
logger.debug("No cll_cache_url in upload URLs — cache upload not supported yet")

if upload_failures:
console.print(
f"[bold yellow]Cloud upload completed with warnings[/bold yellow] "
f"(failed: {', '.join(upload_failures)})"
)
else:
console.print("[bold green]Cloud upload complete.[/bold green]")
except Exception as e:
logger.warning("[recce init] Cloud upload failed: %s", e)
console.print(f" [[yellow]Warning[/yellow]] Cloud upload failed: {e}")
else:
console.print("Run [bold]recce server --enable-cll-cache[/bold] to use the cached lineage.")


@cli.command(cls=TrackCommand)
Expand Down
Loading
Loading