Skip to content

Commit 3066bfe

Browse files
fix: simplify to use ipywidgets directly, skip tqdm
The threading model in notebooks/Colab makes widget updates from background threads unreliable. Use ipywidgets directly instead of tqdm, which has its own threading complexity.
1 parent c945016 commit 3066bfe

1 file changed

Lines changed: 76 additions & 171 deletions

File tree

pysr/jupyter_progress.py

Lines changed: 76 additions & 171 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88
import threading
99
from contextlib import contextmanager
1010
from dataclasses import dataclass
11-
from typing import Callable, Iterator, Optional, Protocol, Tuple
11+
from typing import Callable, Iterator, Protocol
1212

1313

1414
_PROGRESS_PATTERN = re.compile(r"Progress:\s*(\d+)\s*/\s*(\d+)\s*total iterations")
@@ -17,90 +17,83 @@
1717

1818
class _ProgressDisplay(Protocol):
1919
def update(self, current: int, total: int) -> None: ...
20-
2120
def close(self) -> None: ...
2221

2322

2423
class _NullProgressDisplay:
2524
def update(self, current: int, total: int) -> None:
26-
return None
27-
28-
def close(self) -> None:
29-
return None
30-
31-
32-
class _TqdmProgressDisplay:
33-
def __init__(self, total: int):
34-
from tqdm.notebook import tqdm
35-
36-
self._bar = tqdm(total=total, desc="PySR fit", leave=True)
37-
self._current = 0
38-
39-
def update(self, current: int, total: int) -> None:
40-
if total != self._bar.total:
41-
self._bar.total = total
42-
delta = max(0, current - self._current)
43-
if delta > 0:
44-
self._bar.update(delta)
45-
self._current = current
46-
25+
pass
4726
def close(self) -> None:
48-
self._bar.close()
27+
pass
4928

5029

5130
class _IpywidgetsProgressDisplay:
31+
"""Widget display using ipywidgets - more reliable than tqdm in notebooks."""
32+
5233
def __init__(self, total: int):
5334
from IPython.display import display
5435
from ipywidgets import HTML, IntProgress, VBox
36+
import ipywidgets as widgets
5537

56-
self._bar = IntProgress(value=0, min=0, max=max(total, 1), description="PySR fit")
57-
self._label = HTML(value=f"0 / {total} iterations")
38+
self._total = total
39+
self._current = 0
40+
41+
# Create widget with explicit layout
42+
self._bar = IntProgress(
43+
value=0,
44+
min=0,
45+
max=max(total, 1),
46+
description="PySR",
47+
style={'description_width': 'initial'},
48+
layout=widgets.Layout(width='100%')
49+
)
50+
self._label = HTML(value=f"0 / {total}")
5851
self._widget = VBox([self._bar, self._label])
52+
53+
# Display and keep reference
5954
display(self._widget)
6055

6156
def update(self, current: int, total: int) -> None:
62-
self._bar.max = max(total, 1)
57+
"""Update widget values."""
58+
import ipywidgets as widgets
59+
60+
self._current = current
61+
if total != self._total:
62+
self._total = total
63+
self._bar.max = max(total, 1)
64+
65+
# Direct value assignment
6366
self._bar.value = min(max(current, 0), self._bar.max)
64-
self._label.value = f"{current} / {total} iterations"
67+
self._label.value = f"{current} / {total}"
6568

6669
def close(self) -> None:
67-
return None
70+
"""Close the widget."""
71+
try:
72+
self._bar.close()
73+
self._label.close()
74+
self._widget.close()
75+
except Exception:
76+
pass
6877

6978

7079
def _is_notebook_session() -> bool:
71-
# Colab does not always expose the same IPython shell metadata as classic Jupyter.
80+
"""Detect if we're running in a notebook."""
7281
if "google.colab" in sys.modules or os.environ.get("COLAB_RELEASE_TAG"):
7382
return True
74-
7583
try:
7684
from IPython import get_ipython
85+
ipython = get_ipython()
86+
if ipython is None:
87+
return False
88+
return ipython.__class__.__name__ == "ZMQInteractiveShell"
7789
except Exception:
7890
return False
7991

80-
ipython = get_ipython()
81-
if ipython is None:
82-
return False
83-
return ipython.__class__.__name__ == "ZMQInteractiveShell"
84-
85-
86-
def _get_iopub_thread():
87-
"""Best-effort access to ipykernel's IOPubThread (for thread-safe widget updates)."""
88-
try:
89-
from ipykernel.ipkernel import IPythonKernel # type: ignore
90-
91-
ip = IPythonKernel.instance()
92-
return getattr(ip, "iopub_thread", None)
93-
except Exception:
94-
return None
95-
9692

9793
def _create_display(total: int) -> _ProgressDisplay:
94+
"""Create progress display - prefer ipywidgets for reliability."""
9895
try:
99-
return _TqdmProgressDisplay(total=total)
100-
except Exception:
101-
pass
102-
103-
try:
96+
# Only try ipywidgets, skip tqdm due to threading issues
10497
return _IpywidgetsProgressDisplay(total=total)
10598
except Exception:
10699
return _NullProgressDisplay()
@@ -128,26 +121,16 @@ def parse_line(self, line: str) -> None:
128121
self.on_progress(current, total)
129122

130123

131-
@dataclass
132-
class _StreamProxy:
133-
_stream: object
134-
write: Callable[[str], int]
135-
flush: Callable[[], None] | None = None
136-
137-
def __getattr__(self, name: str):
138-
return getattr(self._stream, name)
139-
140-
141124
class _ProgressCaptureStream:
125+
"""Captures stdout/stderr and parses progress lines."""
126+
142127
def __init__(self, target_stream, parser: _ProgressLineParser):
143128
self._target = target_stream
144129
self._parser = parser
145130
self._buffer = ""
146-
self._lock = threading.RLock()
131+
self._lock = threading.Lock()
147132

148133
def _drain_complete_lines(self) -> None:
149-
# ProgressMeter often updates in-place with carriage returns (`\r`) rather
150-
# than newline-terminated lines. Treat both as parse boundaries.
151134
while True:
152135
newline_idx = self._buffer.find("\n")
153136
carriage_idx = self._buffer.find("\r")
@@ -166,8 +149,6 @@ def write(self, text: str) -> int:
166149
written = self._target.write(text)
167150
self._buffer += text
168151
self._drain_complete_lines()
169-
# Also parse the in-flight buffer in case progress updates arrive
170-
# without line delimiters (seen in some notebook frontends).
171152
if self._buffer:
172153
self._parser.parse_line(self._buffer)
173154
return written if isinstance(written, int) else len(text)
@@ -180,137 +161,61 @@ def flush(self) -> None:
180161
if hasattr(self._target, "flush"):
181162
self._target.flush()
182163

183-
def __getattr__(self, name: str):
184-
return getattr(self._target, name)
185-
186164

187165
class JupyterProgressContext:
188-
"""Capture text progress lines and render a notebook progress widget."""
166+
"""Context manager for Jupyter progress display."""
189167

190168
def __init__(self, total_iterations: int):
191169
self.total_iterations = max(int(total_iterations), 1)
192170
self.display: _ProgressDisplay = _NullProgressDisplay()
193171
self._parser = _ProgressLineParser(self._on_progress)
194172
self._current = 0
195-
196-
# Progress updates often come from non-main threads (e.g., ipykernel watchfd thread,
197-
# or Julia threads calling into Python). Some notebook frontends (notably Colab)
198-
# are much more reliable if widget updates are sent from ipykernel's IOPub thread.
199-
self._iopub_thread = None
200-
self._update_lock = threading.Lock()
201-
self._pending_update: Optional[Tuple[int, int]] = None
202-
self._update_scheduled = False
203-
self._active = False
173+
self._lock = threading.Lock()
204174

205175
def _on_progress(self, current: int, total: int) -> None:
206-
self._current = current
207-
self._queue_update(current, total)
208-
209-
def _queue_update(self, current: int, total: int) -> None:
210-
# Coalesce frequent updates to avoid flooding the frontend with widget state
211-
# messages. This also lets us route the actual widget update through the
212-
# ipykernel IOPub thread when available.
213-
with self._update_lock:
214-
if not self._active:
215-
return
216-
self._pending_update = (current, total)
217-
if self._update_scheduled:
218-
return
219-
self._update_scheduled = True
220-
221-
iopub_thread = self._iopub_thread
222-
if iopub_thread is not None:
223-
iopub_thread.schedule(self._apply_pending_update)
224-
else:
225-
self._apply_pending_update()
226-
227-
def _apply_pending_update(self) -> None:
228-
# Runs either on ipykernel's IOPub thread or the current thread.
229-
# Process all pending updates but limit iterations to prevent infinite loops.
230-
max_iterations = 100
231-
for _ in range(max_iterations):
232-
with self._update_lock:
233-
if not self._active:
234-
self._pending_update = None
235-
self._update_scheduled = False
236-
return
237-
pending = self._pending_update
238-
self._pending_update = None
239-
if pending is None:
240-
self._update_scheduled = False
241-
return
242-
243-
current, total = pending
244-
try:
245-
self.display.update(current, total)
246-
except Exception:
247-
# Never let UI plumbing crash a model fit.
248-
pass
176+
"""Called when progress is detected (may be from any thread)."""
177+
with self._lock:
178+
self._current = current
179+
# Update display immediately - ipywidgets should handle thread safety
180+
try:
181+
self.display.update(current, total)
182+
except Exception:
183+
pass
249184

250185
@contextmanager
251186
def capture(self) -> Iterator[None]:
187+
"""Capture stdout/stderr and display progress."""
252188
self.display = _create_display(self.total_iterations)
253189
self.display.update(0, self.total_iterations)
254-
255-
# Mark active *before* any output starts flowing.
256-
self._iopub_thread = _get_iopub_thread()
257-
with self._update_lock:
258-
self._active = True
259-
self._pending_update = None
260-
self._update_scheduled = False
261-
190+
262191
old_stdout, old_stderr = sys.stdout, sys.stderr
263-
264-
stdout_old_write = old_stdout.write
265-
stdout_old_flush = getattr(old_stdout, "flush", None)
266-
stderr_old_write = old_stderr.write
267-
stderr_old_flush = getattr(old_stderr, "flush", None)
268-
269-
stdout_proxy = _StreamProxy(old_stdout, stdout_old_write, stdout_old_flush)
270-
stderr_proxy = _StreamProxy(old_stderr, stderr_old_write, stderr_old_flush)
271-
stdout_capture_stream = _ProgressCaptureStream(stdout_proxy, self._parser)
272-
stderr_capture_stream = _ProgressCaptureStream(stderr_proxy, self._parser)
273-
192+
stdout_capture = _ProgressCaptureStream(old_stdout, self._parser)
193+
stderr_capture = _ProgressCaptureStream(old_stderr, self._parser)
194+
274195
try:
275-
old_stdout.write = stdout_capture_stream.write
276-
if stdout_old_flush is not None:
277-
old_stdout.flush = stdout_capture_stream.flush
278-
279-
old_stderr.write = stderr_capture_stream.write
280-
if stderr_old_flush is not None:
281-
old_stderr.flush = stderr_capture_stream.flush
282-
196+
sys.stdout = stdout_capture
197+
sys.stderr = stderr_capture
283198
yield
284199
finally:
285-
stdout_capture_stream.flush()
286-
stderr_capture_stream.flush()
287-
288-
old_stdout.write = stdout_old_write
289-
if stdout_old_flush is not None:
290-
old_stdout.flush = stdout_old_flush
291-
292-
old_stderr.write = stderr_old_write
293-
if stderr_old_flush is not None:
294-
old_stderr.flush = stderr_old_flush
295-
296-
# Prevent any late/asynchronous progress callbacks from trying to update
297-
# a closed widget.
298-
with self._update_lock:
299-
self._active = False
300-
self._pending_update = None
301-
self._update_scheduled = False
302-
303-
self.display.update(self.total_iterations, self.total_iterations)
200+
stdout_capture.flush()
201+
stderr_capture.flush()
202+
sys.stdout, sys.stderr = old_stdout, old_stderr
203+
204+
# Final update
205+
with self._lock:
206+
final = self._current
207+
try:
208+
self.display.update(self.total_iterations, self.total_iterations)
209+
except Exception:
210+
pass
304211
self.display.close()
305212

306213

307214
def should_use_jupyter_progress(*, progress: bool, verbosity: int, is_single_output: bool) -> bool:
308215
"""Whether PySR should use Python-side notebook progress handling."""
309216
if not progress or verbosity <= 0 or not is_single_output:
310217
return False
311-
312218
disable_progress = os.environ.get("PYSR_DISABLE_JUPYTER_PROGRESS", "").lower()
313219
if disable_progress in {"1", "true", "yes", "on"}:
314220
return False
315-
316221
return _is_notebook_session()

0 commit comments

Comments
 (0)