Skip to content

Commit 19e1f56

Browse files
fix: parse carriage-return progress updates in notebooks
1 parent daf3dca commit 19e1f56

3 files changed

Lines changed: 42 additions & 3 deletions

File tree

pysr/jupyter_progress.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,14 +120,26 @@ def __init__(self, target_stream, parser: _ProgressLineParser):
120120
self._parser = parser
121121
self._buffer = ""
122122

123+
def _drain_complete_lines(self) -> None:
124+
# ProgressMeter often updates in-place with carriage returns (`\r`) rather
125+
# than newline-terminated lines. Treat both as parse boundaries.
126+
while True:
127+
newline_idx = self._buffer.find("\n")
128+
carriage_idx = self._buffer.find("\r")
129+
candidates = [idx for idx in (newline_idx, carriage_idx) if idx != -1]
130+
if not candidates:
131+
break
132+
split_idx = min(candidates)
133+
line = self._buffer[:split_idx]
134+
self._buffer = self._buffer[split_idx + 1 :]
135+
self._parser.parse_line(line)
136+
123137
def write(self, text: str) -> int:
124138
if not isinstance(text, str):
125139
text = str(text)
126140
written = self._target.write(text)
127141
self._buffer += text
128-
while "\n" in self._buffer:
129-
line, self._buffer = self._buffer.split("\n", 1)
130-
self._parser.parse_line(line)
142+
self._drain_complete_lines()
131143
return written if isinstance(written, int) else len(text)
132144

133145
def flush(self) -> None:

pysr/test/test_jupyter_progress_helpers.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,20 @@ def test_progress_parser_extracts_evolving_percent_lines(self):
4848
parser.parse_line("Evolving for 40 iterations... 11%|██ | ETA: 0:00:05")
4949
self.assertEqual(updates, [(4, 40)])
5050

51+
def test_progress_capture_stream_handles_carriage_return_updates(self):
52+
module = _load_module()
53+
updates = []
54+
parser = module._ProgressLineParser(
55+
lambda current, total: updates.append((current, total))
56+
)
57+
target = io.StringIO()
58+
stream = module._ProgressCaptureStream(target, parser)
59+
stream.write("Evolving for 100 iterations... 1%|\r")
60+
stream.write("Evolving for 100 iterations... 2%|\r")
61+
stream.write("Evolving for 100 iterations... 3%|\r")
62+
stream.flush()
63+
self.assertEqual(updates, [(1, 100), (2, 100), (3, 100)])
64+
5165
def test_should_use_jupyter_progress_gating(self):
5266
module = _load_module()
5367
self.assertFalse(

scripts/test_jupyter_progress_helpers.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,23 @@ def test_parser_extracts_evolving_percent() -> None:
4545
assert updates == [(4, 40)], updates
4646

4747

48+
def test_capture_stream_handles_carriage_return_updates() -> None:
49+
module = _load_module()
50+
updates = []
51+
parser = module._ProgressLineParser(lambda current, total: updates.append((current, total)))
52+
capture = module._ProgressCaptureStream(io.StringIO(), parser)
53+
capture.write("Evolving for 100 iterations... 1%|\r")
54+
capture.write("Evolving for 100 iterations... 2%|\r")
55+
capture.write("Evolving for 100 iterations... 3%|\r")
56+
capture.flush()
57+
assert updates == [(1, 100), (2, 100), (3, 100)], updates
58+
59+
4860
def run() -> None:
4961
test_parser_extracts_progress()
5062
test_capture_stream_handles_split_lines()
5163
test_parser_extracts_evolving_percent()
64+
test_capture_stream_handles_carriage_return_updates()
5265
print("helpers-tests=ok")
5366

5467

0 commit comments

Comments
 (0)