From 7336214e15e0fab8eec79ecf4f0f0565c1a66c8f Mon Sep 17 00:00:00 2001 From: Roland Walker Date: Fri, 19 Jun 2026 15:29:24 -0400 Subject: [PATCH] spinners for setup steps with --progress --batch For large --batch scripts (the very kind of SQL scripts for which the progress bar is useful) the initial steps of pre-parsing the batch and optionally replaying the checkpoint log can take substantial time. Therefore, add text and spinners showing the stage and progress of the setup steps. A lightweight new dependency is added due to limitations in the spinners from prompt-toolkit: the prompt-toolkit spinners must have a known upper bound to which to iterate, and they leave behind random "mid-spin" characters on completion like pre-parsing batch \ which creates visual confusion about whether the step completed successfully. yaspin spinners can integrate with prompt-toolkit more deeply than is done here, and could also for instance be used during long interactive queries (though that particular usecase may have little value). Another option would be spinners from the "rich" library, which we may try in the future, if adopting "rich" for tabular output. --- changelog.md | 3 +- mycli/main_modes/batch.py | 31 +++++- pyproject.toml | 1 + test/pytests/test_main_modes_batch.py | 145 +++++++++++++++++++++++++- 4 files changed, 174 insertions(+), 6 deletions(-) diff --git a/changelog.md b/changelog.md index 33fb9aa4..5e1febe9 100644 --- a/changelog.md +++ b/changelog.md @@ -2,8 +2,9 @@ Upcoming (TBD) ============== Features ---------- +-------- * Silently accept forward slash to introduce special commands. +* `--progress` spinners for setup steps in `--batch` mode. Internal diff --git a/mycli/main_modes/batch.py b/mycli/main_modes/batch.py index d87a27e4..b009486b 100644 --- a/mycli/main_modes/batch.py +++ b/mycli/main_modes/batch.py @@ -10,6 +10,7 @@ from prompt_toolkit.shortcuts import ProgressBar from prompt_toolkit.shortcuts.progress_bar import formatters as progress_bar_formatters import pymysql +from yaspin import yaspin from mycli.packages.batch_utils import statements_from_filehandle from mycli.packages.interactive_utils import confirm_destructive_query @@ -28,6 +29,7 @@ def replay_checkpoint_file( batch_path: str, checkpoint_path: str | None, resume: bool, + progress: bool = False, ) -> int: if not resume: return 0 @@ -42,27 +44,46 @@ def replay_checkpoint_file( raise CheckpointReplayError('--resume is incompatible with reading from the standard input.') completed_count = 0 + if progress: + spinner = yaspin(text='replaying checkpoint', side='right', stream=sys.stderr) + spinner.start() try: with click.open_file(batch_path) as batch_h, click.open_file(checkpoint_path, mode='r', encoding='utf-8') as checkpoint_h: try: batch_gen = statements_from_filehandle(batch_h) except ValueError as e: + if progress: + spinner.fail('✘') raise CheckpointReplayError(f'Error reading --batch file: {batch_path}: {e}') from None for checkpoint_statement, _checkpoint_counter in statements_from_filehandle(checkpoint_h): try: batch_statement, _batch_counter = next(batch_gen) except StopIteration: + if progress: + spinner.fail('✘') raise CheckpointReplayError('Checkpoint script longer than batch script.') from None except ValueError as e: + if progress: + spinner.fail('✘') raise CheckpointReplayError(f'Error reading --batch file: {batch_path}: {e}') from None if checkpoint_statement != batch_statement: + if progress: + spinner.fail('✘') raise CheckpointReplayError(f'Statement mismatch: {checkpoint_statement}.') completed_count += 1 + if progress: + spinner.ok('✔') except ValueError as e: + if progress: + spinner.fail('✘') raise CheckpointReplayError(f'Error reading --checkpoint file: {checkpoint_path}: {e}') from None except FileNotFoundError as e: + if progress: + spinner.fail('✘') raise CheckpointReplayError(f'FileNotFoundError: {e}') from None except OSError as e: + if progress: + spinner.fail('✘') raise CheckpointReplayError(f'OSError: {e}') from None return completed_count @@ -119,10 +140,12 @@ def main_batch_with_progress_bar(mycli: 'MyCli', cli_args: 'CliArgs') -> int: click.secho('--progress is only compatible with a plain file.', err=True, fg='red') return 1 try: - completed_statement_count = replay_checkpoint_file(cli_args.batch, cli_args.checkpoint, cli_args.resume) + completed_statement_count = replay_checkpoint_file(cli_args.batch, cli_args.checkpoint, cli_args.resume, progress=True) batch_count_h = click.open_file(cli_args.batch) - for _statement, _counter in statements_from_filehandle(batch_count_h): - goal_statements += 1 + with yaspin(text='validating batch ', side='right', stream=sys.stderr) as spinner: + for _statement, _counter in statements_from_filehandle(batch_count_h): + goal_statements += 1 + spinner.ok('✔') batch_count_h.close() batch_h = click.open_file(cli_args.batch) batch_gen = statements_from_filehandle(batch_h) @@ -140,7 +163,7 @@ def main_batch_with_progress_bar(mycli: 'MyCli', cli_args: 'CliArgs') -> int: if goal_statements: pb_style = prompt_toolkit.styles.Style.from_dict({'bar-a': 'reverse'}) custom_formatters = [ - progress_bar_formatters.Bar(start='[', end=']', sym_a=' ', sym_b=' ', sym_c=' '), + progress_bar_formatters.Bar(start='running queries [', end=']', sym_a=' ', sym_b=' ', sym_c=' '), progress_bar_formatters.Text(' '), progress_bar_formatters.Progress(), progress_bar_formatters.Text(' '), diff --git a/pyproject.toml b/pyproject.toml index 22d870b4..3165a381 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ dependencies = [ "pyfzf ~= 0.3.1", "rapidfuzz ~= 3.14.3", "keyring ~= 25.7.0", + "yaspin ~= 3.4.0", ] [project.urls] diff --git a/test/pytests/test_main_modes_batch.py b/test/pytests/test_main_modes_batch.py index 777f7d67..250e462a 100644 --- a/test/pytests/test_main_modes_batch.py +++ b/test/pytests/test_main_modes_batch.py @@ -68,6 +68,23 @@ def close(self) -> None: self.closed = True +class DummyStream: + def __init__(self, tty: bool = False) -> None: + self.closed = False + self.tty = tty + self.writes: list[str] = [] + + def isatty(self) -> bool: + return self.tty + + def write(self, value: str) -> int: + self.writes.append(value) + return len(value) + + def flush(self) -> None: + return None + + class DummyProgressBar: calls: list[list[int]] = [] @@ -86,6 +103,32 @@ def __call__(self, iterable) -> list[int]: return values +class DummySpinner: + instances: list['DummySpinner'] = [] + + def __init__(self, *args, **kwargs) -> None: + self.fail_calls: list[str] = [] + self.ok_calls: list[str] = [] + self.started = False + DummySpinner.instances.append(self) + + def __enter__(self) -> 'DummySpinner': + self.start() + return self + + def __exit__(self, exc_type, exc, tb) -> Literal[False]: + return False + + def start(self) -> None: + self.started = True + + def fail(self, text: str) -> None: + self.fail_calls.append(text) + + def ok(self, text: str) -> None: + self.ok_calls.append(text) + + def dispatch_batch_statements( mycli: DummyMyCli, cli_args: DummyCliArgs, @@ -108,7 +151,7 @@ def main_batch_from_stdin(mycli: DummyMyCli, cli_args: DummyCliArgs) -> int: def make_fake_sys(stdin_tty: bool, stderr_tty: bool | None = None) -> SimpleNamespace: - stderr = SimpleNamespace(isatty=lambda: stderr_tty) if stderr_tty is not None else object() + stderr = DummyStream(bool(stderr_tty)) return SimpleNamespace( stdin=SimpleNamespace(isatty=lambda: stdin_tty), stderr=stderr, @@ -186,6 +229,21 @@ def test_replay_checkpoint_file_rejects_checkpoint_longer_than_batch(tmp_path: P batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True) +def test_replay_checkpoint_file_marks_progress_failed_when_checkpoint_is_longer( + monkeypatch, + tmp_path: Path, +) -> None: + batch_path = write_batch_file(tmp_path, 'select 1;\n') + checkpoint = write_checkpoint_file(tmp_path, 'select 1;\nselect 2;\n') + DummySpinner.instances.clear() + monkeypatch.setattr(batch_mode, 'yaspin', DummySpinner) + + with pytest.raises(batch_mode.CheckpointReplayError, match='Checkpoint script longer than batch script.'): + batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True, progress=True) + + assert DummySpinner.instances[0].fail_calls == ['✘'] + + @pytest.mark.skipif(os.name == 'nt', reason='todo: unknown') def test_replay_checkpoint_file_rejects_batch_read_error(monkeypatch, tmp_path: Path) -> None: batch_path = write_batch_file(tmp_path, 'select 1;\n') @@ -198,6 +256,20 @@ def test_replay_checkpoint_file_rejects_batch_read_error(monkeypatch, tmp_path: batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True) +@pytest.mark.skipif(os.name == 'nt', reason='todo: unknown') +def test_replay_checkpoint_file_marks_progress_failed_for_batch_read_error(monkeypatch, tmp_path: Path) -> None: + batch_path = write_batch_file(tmp_path, 'select 1;\n') + checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n') + DummySpinner.instances.clear() + monkeypatch.setattr(batch_mode, 'statements_from_filehandle', lambda _handle: (_ for _ in ()).throw(ValueError('bad batch'))) + monkeypatch.setattr(batch_mode, 'yaspin', DummySpinner) + + with pytest.raises(batch_mode.CheckpointReplayError, match=f'Error reading --batch file: {batch_path}: bad batch'): + batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True, progress=True) + + assert DummySpinner.instances[0].fail_calls == ['✘'] + + @pytest.mark.skipif(os.name == 'nt', reason='todo: unknown') def test_replay_checkpoint_file_rejects_batch_iteration_error(monkeypatch, tmp_path: Path) -> None: batch_path = write_batch_file(tmp_path, 'select 1;\n') @@ -219,6 +291,31 @@ def fake_statements_from_filehandle(handle): batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True) +@pytest.mark.skipif(os.name == 'nt', reason='todo: unknown') +def test_replay_checkpoint_file_marks_progress_failed_for_batch_iteration_error(monkeypatch, tmp_path: Path) -> None: + batch_path = write_batch_file(tmp_path, 'select 1;\n') + + def raise_on_next(): + raise ValueError('bad batch iterator') + yield + + def fake_statements_from_filehandle(handle): + if handle.name == batch_path: + return raise_on_next() + return iter([('select 1;', 0)]) + + DummySpinner.instances.clear() + monkeypatch.setattr(batch_mode, 'statements_from_filehandle', fake_statements_from_filehandle) + monkeypatch.setattr(batch_mode, 'yaspin', DummySpinner) + + checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n') + + with pytest.raises(batch_mode.CheckpointReplayError, match=f'Error reading --batch file: {batch_path}: bad batch iterator'): + batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True, progress=True) + + assert DummySpinner.instances[0].fail_calls == ['✘'] + + @pytest.mark.skipif(os.name == 'nt', reason='todo: unknown') def test_replay_checkpoint_file_rejects_checkpoint_read_error(monkeypatch, tmp_path: Path) -> None: batch_path = write_batch_file(tmp_path, 'select 1;\n') @@ -236,6 +333,27 @@ def fake_statements_from_filehandle(handle): batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True) +@pytest.mark.skipif(os.name == 'nt', reason='todo: unknown') +def test_replay_checkpoint_file_marks_progress_failed_for_checkpoint_read_error(monkeypatch, tmp_path: Path) -> None: + batch_path = write_batch_file(tmp_path, 'select 1;\n') + + def fake_statements_from_filehandle(handle): + if handle.name == batch_path: + return iter([('select 1;', 0)]) + return (_ for _ in ()).throw(ValueError('bad checkpoint')) + + DummySpinner.instances.clear() + monkeypatch.setattr(batch_mode, 'statements_from_filehandle', fake_statements_from_filehandle) + monkeypatch.setattr(batch_mode, 'yaspin', DummySpinner) + + checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n') + + with pytest.raises(batch_mode.CheckpointReplayError, match=f'Error reading --checkpoint file: {checkpoint}: bad checkpoint'): + batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True, progress=True) + + assert DummySpinner.instances[0].fail_calls == ['✘'] + + def test_replay_checkpoint_file_rejects_missing_files(tmp_path: Path) -> None: batch_path = str(tmp_path / 'missing.sql') @@ -245,6 +363,18 @@ def test_replay_checkpoint_file_rejects_missing_files(tmp_path: Path) -> None: batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True) +def test_replay_checkpoint_file_marks_progress_failed_for_missing_files(monkeypatch, tmp_path: Path) -> None: + batch_path = str(tmp_path / 'missing.sql') + checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n') + DummySpinner.instances.clear() + monkeypatch.setattr(batch_mode, 'yaspin', DummySpinner) + + with pytest.raises(batch_mode.CheckpointReplayError, match='FileNotFoundError'): + batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True, progress=True) + + assert DummySpinner.instances[0].fail_calls == ['✘'] + + def test_replay_checkpoint_file_rejects_open_errors(monkeypatch, tmp_path: Path) -> None: batch_path = write_batch_file(tmp_path, 'select 1;\n') @@ -256,6 +386,19 @@ def test_replay_checkpoint_file_rejects_open_errors(monkeypatch, tmp_path: Path) batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True) +def test_replay_checkpoint_file_marks_progress_failed_for_open_errors(monkeypatch, tmp_path: Path) -> None: + batch_path = write_batch_file(tmp_path, 'select 1;\n') + checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n') + DummySpinner.instances.clear() + monkeypatch.setattr(batch_mode.click, 'open_file', lambda *_args, **_kwargs: (_ for _ in ()).throw(OSError('open failed'))) + monkeypatch.setattr(batch_mode, 'yaspin', DummySpinner) + + with pytest.raises(batch_mode.CheckpointReplayError, match='OSError'): + batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True, progress=True) + + assert DummySpinner.instances[0].fail_calls == ['✘'] + + @pytest.mark.parametrize( ('format_name', 'batch_counter', 'expected'), (