88import threading
99from contextlib import contextmanager
1010from dataclasses import dataclass
11- from typing import Callable , Iterator , Optional , Protocol , Tuple
11+ from typing import Callable , Iterator , Protocol
1212
1313
1414_PROGRESS_PATTERN = re .compile (r"Progress:\s*(\d+)\s*/\s*(\d+)\s*total iterations" )
1717
1818class _ProgressDisplay (Protocol ):
1919 def update (self , current : int , total : int ) -> None : ...
20-
2120 def close (self ) -> None : ...
2221
2322
2423class _NullProgressDisplay :
2524 def update (self , current : int , total : int ) -> None :
26- return None
27-
28- def close (self ) -> None :
29- return None
30-
31-
32- class _TqdmProgressDisplay :
33- def __init__ (self , total : int ):
34- from tqdm .notebook import tqdm
35-
36- self ._bar = tqdm (total = total , desc = "PySR fit" , leave = True )
37- self ._current = 0
38-
39- def update (self , current : int , total : int ) -> None :
40- if total != self ._bar .total :
41- self ._bar .total = total
42- delta = max (0 , current - self ._current )
43- if delta > 0 :
44- self ._bar .update (delta )
45- self ._current = current
46-
25+ pass
4726 def close (self ) -> None :
48- self . _bar . close ()
27+ pass
4928
5029
5130class _IpywidgetsProgressDisplay :
31+ """Widget display using ipywidgets - more reliable than tqdm in notebooks."""
32+
5233 def __init__ (self , total : int ):
5334 from IPython .display import display
5435 from ipywidgets import HTML , IntProgress , VBox
36+ import ipywidgets as widgets
5537
56- self ._bar = IntProgress (value = 0 , min = 0 , max = max (total , 1 ), description = "PySR fit" )
57- self ._label = HTML (value = f"0 / { total } iterations" )
38+ self ._total = total
39+ self ._current = 0
40+
41+ # Create widget with explicit layout
42+ self ._bar = IntProgress (
43+ value = 0 ,
44+ min = 0 ,
45+ max = max (total , 1 ),
46+ description = "PySR" ,
47+ style = {'description_width' : 'initial' },
48+ layout = widgets .Layout (width = '100%' )
49+ )
50+ self ._label = HTML (value = f"0 / { total } " )
5851 self ._widget = VBox ([self ._bar , self ._label ])
52+
53+ # Display and keep reference
5954 display (self ._widget )
6055
6156 def update (self , current : int , total : int ) -> None :
62- self ._bar .max = max (total , 1 )
57+ """Update widget values."""
58+ import ipywidgets as widgets
59+
60+ self ._current = current
61+ if total != self ._total :
62+ self ._total = total
63+ self ._bar .max = max (total , 1 )
64+
65+ # Direct value assignment
6366 self ._bar .value = min (max (current , 0 ), self ._bar .max )
64- self ._label .value = f"{ current } / { total } iterations "
67+ self ._label .value = f"{ current } / { total } "
6568
6669 def close (self ) -> None :
67- return None
70+ """Close the widget."""
71+ try :
72+ self ._bar .close ()
73+ self ._label .close ()
74+ self ._widget .close ()
75+ except Exception :
76+ pass
6877
6978
7079def _is_notebook_session () -> bool :
71- # Colab does not always expose the same IPython shell metadata as classic Jupyter.
80+ """Detect if we're running in a notebook."""
7281 if "google.colab" in sys .modules or os .environ .get ("COLAB_RELEASE_TAG" ):
7382 return True
74-
7583 try :
7684 from IPython import get_ipython
85+ ipython = get_ipython ()
86+ if ipython is None :
87+ return False
88+ return ipython .__class__ .__name__ == "ZMQInteractiveShell"
7789 except Exception :
7890 return False
7991
80- ipython = get_ipython ()
81- if ipython is None :
82- return False
83- return ipython .__class__ .__name__ == "ZMQInteractiveShell"
84-
85-
86- def _get_iopub_thread ():
87- """Best-effort access to ipykernel's IOPubThread (for thread-safe widget updates)."""
88- try :
89- from ipykernel .ipkernel import IPythonKernel # type: ignore
90-
91- ip = IPythonKernel .instance ()
92- return getattr (ip , "iopub_thread" , None )
93- except Exception :
94- return None
95-
9692
9793def _create_display (total : int ) -> _ProgressDisplay :
94+ """Create progress display - prefer ipywidgets for reliability."""
9895 try :
99- return _TqdmProgressDisplay (total = total )
100- except Exception :
101- pass
102-
103- try :
96+ # Only try ipywidgets, skip tqdm due to threading issues
10497 return _IpywidgetsProgressDisplay (total = total )
10598 except Exception :
10699 return _NullProgressDisplay ()
@@ -128,26 +121,16 @@ def parse_line(self, line: str) -> None:
128121 self .on_progress (current , total )
129122
130123
131- @dataclass
132- class _StreamProxy :
133- _stream : object
134- write : Callable [[str ], int ]
135- flush : Callable [[], None ] | None = None
136-
137- def __getattr__ (self , name : str ):
138- return getattr (self ._stream , name )
139-
140-
141124class _ProgressCaptureStream :
125+ """Captures stdout/stderr and parses progress lines."""
126+
142127 def __init__ (self , target_stream , parser : _ProgressLineParser ):
143128 self ._target = target_stream
144129 self ._parser = parser
145130 self ._buffer = ""
146- self ._lock = threading .RLock ()
131+ self ._lock = threading .Lock ()
147132
148133 def _drain_complete_lines (self ) -> None :
149- # ProgressMeter often updates in-place with carriage returns (`\r`) rather
150- # than newline-terminated lines. Treat both as parse boundaries.
151134 while True :
152135 newline_idx = self ._buffer .find ("\n " )
153136 carriage_idx = self ._buffer .find ("\r " )
@@ -166,8 +149,6 @@ def write(self, text: str) -> int:
166149 written = self ._target .write (text )
167150 self ._buffer += text
168151 self ._drain_complete_lines ()
169- # Also parse the in-flight buffer in case progress updates arrive
170- # without line delimiters (seen in some notebook frontends).
171152 if self ._buffer :
172153 self ._parser .parse_line (self ._buffer )
173154 return written if isinstance (written , int ) else len (text )
@@ -180,137 +161,61 @@ def flush(self) -> None:
180161 if hasattr (self ._target , "flush" ):
181162 self ._target .flush ()
182163
183- def __getattr__ (self , name : str ):
184- return getattr (self ._target , name )
185-
186164
187165class JupyterProgressContext :
188- """Capture text progress lines and render a notebook progress widget ."""
166+ """Context manager for Jupyter progress display ."""
189167
190168 def __init__ (self , total_iterations : int ):
191169 self .total_iterations = max (int (total_iterations ), 1 )
192170 self .display : _ProgressDisplay = _NullProgressDisplay ()
193171 self ._parser = _ProgressLineParser (self ._on_progress )
194172 self ._current = 0
195-
196- # Progress updates often come from non-main threads (e.g., ipykernel watchfd thread,
197- # or Julia threads calling into Python). Some notebook frontends (notably Colab)
198- # are much more reliable if widget updates are sent from ipykernel's IOPub thread.
199- self ._iopub_thread = None
200- self ._update_lock = threading .Lock ()
201- self ._pending_update : Optional [Tuple [int , int ]] = None
202- self ._update_scheduled = False
203- self ._active = False
173+ self ._lock = threading .Lock ()
204174
205175 def _on_progress (self , current : int , total : int ) -> None :
206- self ._current = current
207- self ._queue_update (current , total )
208-
209- def _queue_update (self , current : int , total : int ) -> None :
210- # Coalesce frequent updates to avoid flooding the frontend with widget state
211- # messages. This also lets us route the actual widget update through the
212- # ipykernel IOPub thread when available.
213- with self ._update_lock :
214- if not self ._active :
215- return
216- self ._pending_update = (current , total )
217- if self ._update_scheduled :
218- return
219- self ._update_scheduled = True
220-
221- iopub_thread = self ._iopub_thread
222- if iopub_thread is not None :
223- iopub_thread .schedule (self ._apply_pending_update )
224- else :
225- self ._apply_pending_update ()
226-
227- def _apply_pending_update (self ) -> None :
228- # Runs either on ipykernel's IOPub thread or the current thread.
229- # Process all pending updates but limit iterations to prevent infinite loops.
230- max_iterations = 100
231- for _ in range (max_iterations ):
232- with self ._update_lock :
233- if not self ._active :
234- self ._pending_update = None
235- self ._update_scheduled = False
236- return
237- pending = self ._pending_update
238- self ._pending_update = None
239- if pending is None :
240- self ._update_scheduled = False
241- return
242-
243- current , total = pending
244- try :
245- self .display .update (current , total )
246- except Exception :
247- # Never let UI plumbing crash a model fit.
248- pass
176+ """Called when progress is detected (may be from any thread)."""
177+ with self ._lock :
178+ self ._current = current
179+ # Update display immediately - ipywidgets should handle thread safety
180+ try :
181+ self .display .update (current , total )
182+ except Exception :
183+ pass
249184
250185 @contextmanager
251186 def capture (self ) -> Iterator [None ]:
187+ """Capture stdout/stderr and display progress."""
252188 self .display = _create_display (self .total_iterations )
253189 self .display .update (0 , self .total_iterations )
254-
255- # Mark active *before* any output starts flowing.
256- self ._iopub_thread = _get_iopub_thread ()
257- with self ._update_lock :
258- self ._active = True
259- self ._pending_update = None
260- self ._update_scheduled = False
261-
190+
262191 old_stdout , old_stderr = sys .stdout , sys .stderr
263-
264- stdout_old_write = old_stdout .write
265- stdout_old_flush = getattr (old_stdout , "flush" , None )
266- stderr_old_write = old_stderr .write
267- stderr_old_flush = getattr (old_stderr , "flush" , None )
268-
269- stdout_proxy = _StreamProxy (old_stdout , stdout_old_write , stdout_old_flush )
270- stderr_proxy = _StreamProxy (old_stderr , stderr_old_write , stderr_old_flush )
271- stdout_capture_stream = _ProgressCaptureStream (stdout_proxy , self ._parser )
272- stderr_capture_stream = _ProgressCaptureStream (stderr_proxy , self ._parser )
273-
192+ stdout_capture = _ProgressCaptureStream (old_stdout , self ._parser )
193+ stderr_capture = _ProgressCaptureStream (old_stderr , self ._parser )
194+
274195 try :
275- old_stdout .write = stdout_capture_stream .write
276- if stdout_old_flush is not None :
277- old_stdout .flush = stdout_capture_stream .flush
278-
279- old_stderr .write = stderr_capture_stream .write
280- if stderr_old_flush is not None :
281- old_stderr .flush = stderr_capture_stream .flush
282-
196+ sys .stdout = stdout_capture
197+ sys .stderr = stderr_capture
283198 yield
284199 finally :
285- stdout_capture_stream .flush ()
286- stderr_capture_stream .flush ()
287-
288- old_stdout .write = stdout_old_write
289- if stdout_old_flush is not None :
290- old_stdout .flush = stdout_old_flush
291-
292- old_stderr .write = stderr_old_write
293- if stderr_old_flush is not None :
294- old_stderr .flush = stderr_old_flush
295-
296- # Prevent any late/asynchronous progress callbacks from trying to update
297- # a closed widget.
298- with self ._update_lock :
299- self ._active = False
300- self ._pending_update = None
301- self ._update_scheduled = False
302-
303- self .display .update (self .total_iterations , self .total_iterations )
200+ stdout_capture .flush ()
201+ stderr_capture .flush ()
202+ sys .stdout , sys .stderr = old_stdout , old_stderr
203+
204+ # Final update
205+ with self ._lock :
206+ final = self ._current
207+ try :
208+ self .display .update (self .total_iterations , self .total_iterations )
209+ except Exception :
210+ pass
304211 self .display .close ()
305212
306213
307214def should_use_jupyter_progress (* , progress : bool , verbosity : int , is_single_output : bool ) -> bool :
308215 """Whether PySR should use Python-side notebook progress handling."""
309216 if not progress or verbosity <= 0 or not is_single_output :
310217 return False
311-
312218 disable_progress = os .environ .get ("PYSR_DISABLE_JUPYTER_PROGRESS" , "" ).lower ()
313219 if disable_progress in {"1" , "true" , "yes" , "on" }:
314220 return False
315-
316221 return _is_notebook_session ()
0 commit comments