Skip to content

Commit e1e1f00

Browse files
authored
jarzul/use dce progress api for progress bar (#88)
* Use the DCE Progress API to display the progress bar, instead of reading logs * Create a singleton holding a common rich_console to use in both the logging configuration and when creating a rich Progress. This allows all logs to go through rich, making regular logs behave correctly out of the box with the rich Progress bar * Update the CLI Progress logic to separate the logic constructing state from the events received and the actual rendering using rich's Progress bar * Test the logic handling the progress state from received events * Use rich_console.is_terminal instead of sys.stderr.isatty()
1 parent 1b5f6d4 commit e1e1f00

6 files changed

Lines changed: 460 additions & 328 deletions

File tree

src/databao_cli/features/build.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66

77
def build_impl(project_layout: ProjectLayout, domain: str, should_index: bool) -> list[BuildDatasourceResult]:
88
dce_project_dir = project_layout.domains_dir / domain
9-
manager = DatabaoContextDomainManager(domain_dir=dce_project_dir)
109

11-
datasources = manager.get_configured_datasource_list()
12-
with cli_progress(total=len(datasources), label="Building datasources"):
13-
results: list[BuildDatasourceResult] = manager.build_context(datasource_ids=None, should_index=should_index)
10+
with cli_progress() as progress:
11+
results: list[BuildDatasourceResult] = DatabaoContextDomainManager(domain_dir=dce_project_dir).build_context(
12+
datasource_ids=None, should_index=should_index, progress=progress
13+
)
14+
1415
return results

src/databao_cli/features/index.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,7 @@ def index_impl(
1313

1414
manager = DatabaoContextDomainManager(domain_dir=dce_project_dir)
1515

16-
total = len(datasource_ids) if datasource_ids is not None else len(manager.get_configured_datasource_list())
17-
18-
with cli_progress(total=total, label="Indexing datasources"):
19-
results: list[IndexDatasourceResult] = manager.index_built_contexts(datasource_ids=datasource_ids)
16+
with cli_progress() as progress:
17+
results: list[IndexDatasourceResult] = manager.index_built_contexts(datasource_ids=datasource_ids, progress=progress)
2018

2119
return results
Lines changed: 157 additions & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -1,107 +1,165 @@
11
from __future__ import annotations
22

3-
import logging
4-
import re
5-
import sys
63
from collections.abc import Iterator
74
from contextlib import contextmanager
8-
from typing import Any
9-
10-
# Log patterns emitted by databao_context_engine.build_sources.build_runner
11-
_BUILD_START_RE = re.compile(r'^Found datasource of type ".*" with name (.+)$')
12-
_INDEX_START_RE = re.compile(r"^Indexing datasource (.+)$")
13-
_ENRICH_START_RE = re.compile(r"^Enriching context for datasource (.+)$")
14-
_SKIP_RE = re.compile(r"^Skipping disabled datasource (.+)$")
15-
_FAIL_RE = re.compile(r"^Failed to build source at \((.+?)\)")
16-
_FAIL_ENRICH_RE = re.compile(r"^Failed to enrich context for datasource \((.+?)\)")
17-
18-
19-
class _ProgressTrackingHandler(logging.Handler):
20-
"""Intercepts databao_context_engine log messages to drive a Rich progress bar.
21-
22-
The library processes datasources sequentially. It logs "Found datasource..."
23-
at the START of each one. We advance the progress bar when we detect that
24-
a new datasource has started (meaning the previous one finished), and once
25-
more when the context manager exits (for the last datasource).
26-
"""
27-
28-
def __init__(
29-
self,
30-
progress: Any,
31-
overall_task: Any,
32-
datasource_task: Any,
33-
) -> None:
34-
super().__init__()
5+
from dataclasses import dataclass, field
6+
7+
from databao_context_engine.progress.progress import (
8+
ProgressCallback,
9+
ProgressEvent,
10+
ProgressKind,
11+
ProgressStep,
12+
)
13+
from rich.progress import (
14+
BarColumn,
15+
Progress,
16+
SpinnerColumn,
17+
TaskID,
18+
TaskProgressColumn,
19+
TextColumn,
20+
)
21+
from rich.table import Column
22+
23+
from databao_cli.shared.log.console import rich_console
24+
25+
26+
def _noop_progress_cb(_: ProgressEvent) -> None:
27+
return
28+
29+
30+
@dataclass
31+
class _StepProgress:
32+
step: ProgressStep
33+
units_completed: int
34+
units_total: int
35+
36+
37+
@dataclass
38+
class _DatasourceProgress:
39+
datasource_id: str
40+
step_plan: tuple[ProgressStep, ...] = ()
41+
completed_steps: set[ProgressStep] = field(default_factory=set)
42+
current_step_progress: _StepProgress | None = None
43+
finished: bool = False
44+
45+
46+
@dataclass
47+
class _OperationProgress:
48+
total_datasources: int | None = None
49+
completed_datasource_ids: list[str] = field(default_factory=list)
50+
current_datasource: _DatasourceProgress | None = None
51+
52+
53+
def _compute_percent(ds: _DatasourceProgress) -> float | None:
54+
if ds.finished:
55+
return 100.0
56+
if not ds.step_plan:
57+
return None
58+
fraction = 0.0
59+
p = ds.current_step_progress
60+
if p is not None and p.units_total > 0:
61+
fraction = p.units_completed / p.units_total
62+
return ((len(ds.completed_steps) + fraction) / len(ds.step_plan)) * 100.0
63+
64+
65+
def _apply_event(state: _OperationProgress, ev: ProgressEvent) -> _OperationProgress:
    """Fold one progress event into ``state`` and return that same object.

    ``state`` is mutated in place. Step-level events are ignored when no
    datasource is in progress, when they carry no step, or when the step is
    already recorded as completed.

    Args:
        state: Accumulated progress of the running operation.
        ev: Event received through the progress callback.

    Returns:
        The ``state`` object that was passed in, after the update.
    """
    kind = ev.kind

    if kind == ProgressKind.OPERATION_STARTED:
        state.total_datasources = ev.operation_total

    elif kind == ProgressKind.DATASOURCE_STARTED:
        # Keep a total that is already known; otherwise take it from this event.
        if state.total_datasources is None:
            state.total_datasources = ev.datasource_total
        state.current_datasource = _DatasourceProgress(datasource_id=ev.datasource_id or "datasource")

    elif kind == ProgressKind.DATASOURCE_STEP_PLAN_SET:
        # A plan may arrive before any "started" event: create the datasource lazily.
        if state.current_datasource is None:
            state.current_datasource = _DatasourceProgress(datasource_id=ev.datasource_id or "datasource")
        state.current_datasource.step_plan = ev.step_plan or ()

    elif kind == ProgressKind.DATASOURCE_STEP_COMPLETED:
        datasource = state.current_datasource
        if datasource is not None and ev.step is not None and ev.step not in datasource.completed_steps:
            datasource.completed_steps.add(ev.step)
            in_flight = datasource.current_step_progress
            # Drop the partial progress of the step that has just completed.
            if in_flight is not None and in_flight.step == ev.step:
                datasource.current_step_progress = None

    elif kind == ProgressKind.DATASOURCE_STEP_PROGRESS:
        datasource = state.current_datasource
        if datasource is not None and ev.step is not None and ev.step not in datasource.completed_steps:
            previous = datasource.current_step_progress
            units_done = ev.current_units_completed or 0
            # Never move backwards within the same step.
            if previous is not None and previous.step == ev.step:
                units_done = max(previous.units_completed, units_done)
            datasource.current_step_progress = _StepProgress(
                step=ev.step,
                units_completed=units_done,
                units_total=ev.current_units_total or 0,
            )

    elif kind == ProgressKind.DATASOURCE_FINISHED:
        datasource = state.current_datasource
        if datasource is not None:
            datasource.finished = True
            state.completed_datasource_ids.append(datasource.datasource_id)

    # ProgressKind.OPERATION_FINISHED (and any other kind) leaves the state untouched.
    return state
110+
111+
112+
class _ProgressRenderer:
113+
def __init__(self, progress: Progress) -> None:
35114
self._progress = progress
36-
self._overall_task = overall_task
37-
self._datasource_task = datasource_task
38-
self._has_active = False # whether a datasource is currently being processed
39-
40-
def emit(self, record: logging.LogRecord) -> None:
41-
msg = record.getMessage()
42-
43-
# Datasource processing started
44-
m = _BUILD_START_RE.match(msg) or _INDEX_START_RE.match(msg) or _ENRICH_START_RE.match(msg)
45-
if m:
46-
if self._has_active:
47-
# Previous datasource finished — advance
48-
self._progress.advance(self._overall_task)
49-
self._has_active = True
50-
name = m.group(1)
51-
self._progress.update(self._datasource_task, description=f" [dim]{name}[/dim]")
115+
self._overall_task_id: TaskID | None = None
116+
self._datasource_task_id: TaskID | None = None
117+
118+
def render(self, state: _OperationProgress) -> None:
119+
self._render_overall(state)
120+
self._render_datasource(state.current_datasource)
121+
122+
def _render_overall(self, state: _OperationProgress) -> None:
123+
completed = len(state.completed_datasource_ids)
124+
description = f"Datasources {completed}/{state.total_datasources}" if state.total_datasources else "Datasources"
125+
if self._overall_task_id is None:
126+
self._overall_task_id = self._progress.add_task(description, total=state.total_datasources, completed=completed)
127+
else:
128+
self._progress.update(
129+
self._overall_task_id,
130+
description=description,
131+
total=state.total_datasources,
132+
completed=float(completed),
133+
)
134+
135+
def _render_datasource(self, ds: _DatasourceProgress | None) -> None:
136+
if ds is None:
52137
return
53138

54-
# Datasource skipped (immediately done)
55-
if _SKIP_RE.match(msg):
56-
self._progress.advance(self._overall_task)
57-
return
58-
59-
# Datasource failed (after "Found datasource", so active is already True)
60-
if _FAIL_RE.match(msg) or _FAIL_ENRICH_RE.match(msg):
61-
if self._has_active:
62-
self._progress.advance(self._overall_task)
63-
self._has_active = False
64-
return
139+
description = f" {ds.datasource_id}"
140+
percent = _compute_percent(ds)
65141

66-
def finish(self) -> None:
67-
"""Advance for the last datasource that was being processed."""
68-
if self._has_active:
69-
self._progress.advance(self._overall_task)
70-
self._has_active = False
142+
if self._datasource_task_id is None:
143+
self._datasource_task_id = self._progress.add_task(
144+
description,
145+
total=100.0 if percent is not None else None,
146+
completed=int(percent or 0),
147+
)
148+
else:
149+
self._progress.update(
150+
self._datasource_task_id,
151+
description=description,
152+
total=100.0 if percent is not None else None,
153+
completed=percent or 0.0,
154+
)
71155

72156

73157
@contextmanager
74-
def cli_progress(total: int | None = None, label: str = "Datasources") -> Iterator[None]:
75-
"""Show a Rich progress bar during build/index operations.
76-
77-
Intercepts ``databao_context_engine`` log messages to track per-datasource progress.
78-
Redirects library logging through Rich for clean TTY output.
79-
80-
Args:
81-
total: Number of datasources to process.
82-
label: Label for the overall progress bar.
83-
"""
84-
try:
85-
from rich.console import Console
86-
from rich.logging import RichHandler
87-
from rich.progress import (
88-
BarColumn,
89-
MofNCompleteColumn,
90-
Progress,
91-
SpinnerColumn,
92-
TextColumn,
93-
)
94-
from rich.table import Column
95-
except ImportError:
96-
yield
97-
return
98-
99-
console = Console(stderr=True)
100-
158+
def cli_progress() -> Iterator[ProgressCallback]:
101159
# Rich's is_terminal already checks isatty(), NO_COLOR, TERM=dumb, etc.
102160
# This prevents progress bar ANSI output from breaking pexpect-based e2e tests.
103-
if not console.is_terminal:
104-
yield
161+
if not rich_console.is_terminal:
162+
yield _noop_progress_cb
105163
return
106164

107165
progress = Progress(
@@ -111,49 +169,16 @@ def cli_progress(total: int | None = None, label: str = "Datasources") -> Iterat
111169
table_column=Column(width=50, overflow="ellipsis", no_wrap=True),
112170
),
113171
BarColumn(),
114-
MofNCompleteColumn(),
172+
TaskProgressColumn(),
115173
transient=True,
116-
console=console,
117-
)
118-
119-
overall_task = progress.add_task(label, total=total)
120-
datasource_task = progress.add_task(" [dim]starting…[/dim]", total=None)
121-
122-
# --- logging setup ---
123-
engine_logger = logging.getLogger("databao_context_engine")
124-
cli_logger = logging.getLogger("databao_cli")
125-
126-
prev_engine = (list(engine_logger.handlers), engine_logger.propagate)
127-
prev_cli = (list(cli_logger.handlers), cli_logger.propagate)
128-
129-
def _is_console_handler(h: logging.Handler) -> bool:
130-
return isinstance(h, logging.StreamHandler) and getattr(h, "stream", None) in (sys.stderr, sys.stdout)
131-
132-
rich_handler = RichHandler(
133-
console=console,
134-
show_time=False,
135-
show_level=True,
136-
show_path=False,
137-
rich_tracebacks=False,
174+
console=rich_console,
138175
)
139176

140-
tracker = _ProgressTrackingHandler(progress, overall_task, datasource_task)
141-
tracker.setLevel(logging.DEBUG)
142-
143-
try:
144-
for lgr in (engine_logger, cli_logger):
145-
kept = [h for h in lgr.handlers if not _is_console_handler(h)]
146-
lgr.handlers = [*kept, rich_handler]
147-
lgr.propagate = False
148-
149-
engine_logger.addHandler(tracker)
177+
state = _OperationProgress()
178+
renderer = _ProgressRenderer(progress)
150179

151-
with progress:
152-
yield
180+
def on_event(ev: ProgressEvent) -> None:
181+
renderer.render(_apply_event(state, ev))
153182

154-
tracker.finish()
155-
finally:
156-
engine_logger.handlers = prev_engine[0]
157-
engine_logger.propagate = prev_engine[1]
158-
cli_logger.handlers = prev_cli[0]
159-
cli_logger.propagate = prev_cli[1]
183+
with progress:
184+
yield on_event
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from rich.console import Console
2+
3+
# Console instance shared across the CLI; stderr=True makes it write to stderr.
rich_console = Console(stderr=True)

0 commit comments

Comments
 (0)