diff --git a/changelog.md b/changelog.md index 33fb9aa4..5e1febe9 100644 --- a/changelog.md +++ b/changelog.md @@ -2,8 +2,9 @@ Upcoming (TBD) ============== Features ---------- +-------- * Silently accept forward slash to introduce special commands. +* `--progress` spinners for setup steps in `--batch` mode. Internal diff --git a/mycli/main_modes/batch.py b/mycli/main_modes/batch.py index d87a27e4..b009486b 100644 --- a/mycli/main_modes/batch.py +++ b/mycli/main_modes/batch.py @@ -10,6 +10,7 @@ from prompt_toolkit.shortcuts import ProgressBar from prompt_toolkit.shortcuts.progress_bar import formatters as progress_bar_formatters import pymysql +from yaspin import yaspin from mycli.packages.batch_utils import statements_from_filehandle from mycli.packages.interactive_utils import confirm_destructive_query @@ -28,6 +29,7 @@ def replay_checkpoint_file( batch_path: str, checkpoint_path: str | None, resume: bool, + progress: bool = False, ) -> int: if not resume: return 0 @@ -42,27 +44,46 @@ def replay_checkpoint_file( raise CheckpointReplayError('--resume is incompatible with reading from the standard input.') completed_count = 0 + if progress: + spinner = yaspin(text='replaying checkpoint', side='right', stream=sys.stderr) + spinner.start() try: with click.open_file(batch_path) as batch_h, click.open_file(checkpoint_path, mode='r', encoding='utf-8') as checkpoint_h: try: batch_gen = statements_from_filehandle(batch_h) except ValueError as e: + if progress: + spinner.fail('✘') raise CheckpointReplayError(f'Error reading --batch file: {batch_path}: {e}') from None for checkpoint_statement, _checkpoint_counter in statements_from_filehandle(checkpoint_h): try: batch_statement, _batch_counter = next(batch_gen) except StopIteration: + if progress: + spinner.fail('✘') raise CheckpointReplayError('Checkpoint script longer than batch script.') from None except ValueError as e: + if progress: + spinner.fail('✘') raise CheckpointReplayError(f'Error reading --batch file: {batch_path}: {e}') from None if checkpoint_statement != batch_statement: + if progress: + spinner.fail('✘') raise CheckpointReplayError(f'Statement mismatch: {checkpoint_statement}.') completed_count += 1 + if progress: + spinner.ok('✔') except ValueError as e: + if progress: + spinner.fail('✘') raise CheckpointReplayError(f'Error reading --checkpoint file: {checkpoint_path}: {e}') from None except FileNotFoundError as e: + if progress: + spinner.fail('✘') raise CheckpointReplayError(f'FileNotFoundError: {e}') from None except OSError as e: + if progress: + spinner.fail('✘') raise CheckpointReplayError(f'OSError: {e}') from None return completed_count @@ -119,10 +140,12 @@ def main_batch_with_progress_bar(mycli: 'MyCli', cli_args: 'CliArgs') -> int: click.secho('--progress is only compatible with a plain file.', err=True, fg='red') return 1 try: - completed_statement_count = replay_checkpoint_file(cli_args.batch, cli_args.checkpoint, cli_args.resume) + completed_statement_count = replay_checkpoint_file(cli_args.batch, cli_args.checkpoint, cli_args.resume, progress=True) batch_count_h = click.open_file(cli_args.batch) - for _statement, _counter in statements_from_filehandle(batch_count_h): - goal_statements += 1 + with yaspin(text='validating batch ', side='right', stream=sys.stderr) as spinner: + for _statement, _counter in statements_from_filehandle(batch_count_h): + goal_statements += 1 + spinner.ok('✔') batch_count_h.close() batch_h = click.open_file(cli_args.batch) batch_gen = statements_from_filehandle(batch_h) @@ -140,7 +163,7 @@ def main_batch_with_progress_bar(mycli: 'MyCli', cli_args: 'CliArgs') -> int: if goal_statements: pb_style = prompt_toolkit.styles.Style.from_dict({'bar-a': 'reverse'}) custom_formatters = [ - progress_bar_formatters.Bar(start='[', end=']', sym_a=' ', sym_b=' ', sym_c=' '), + progress_bar_formatters.Bar(start='running queries [', end=']', sym_a=' ', sym_b=' ', sym_c=' '), progress_bar_formatters.Text(' '), progress_bar_formatters.Progress(), progress_bar_formatters.Text(' '), diff --git a/pyproject.toml b/pyproject.toml index 22d870b4..3165a381 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ dependencies = [ "pyfzf ~= 0.3.1", "rapidfuzz ~= 3.14.3", "keyring ~= 25.7.0", + "yaspin ~= 3.4.0", ] [project.urls] diff --git a/test/pytests/test_main_modes_batch.py b/test/pytests/test_main_modes_batch.py index 777f7d67..250e462a 100644 --- a/test/pytests/test_main_modes_batch.py +++ b/test/pytests/test_main_modes_batch.py @@ -68,6 +68,23 @@ def close(self) -> None: self.closed = True +class DummyStream: + def __init__(self, tty: bool = False) -> None: + self.closed = False + self.tty = tty + self.writes: list[str] = [] + + def isatty(self) -> bool: + return self.tty + + def write(self, value: str) -> int: + self.writes.append(value) + return len(value) + + def flush(self) -> None: + return None + + class DummyProgressBar: calls: list[list[int]] = [] @@ -86,6 +103,32 @@ def __call__(self, iterable) -> list[int]: return values +class DummySpinner: + instances: list['DummySpinner'] = [] + + def __init__(self, *args, **kwargs) -> None: + self.fail_calls: list[str] = [] + self.ok_calls: list[str] = [] + self.started = False + DummySpinner.instances.append(self) + + def __enter__(self) -> 'DummySpinner': + self.start() + return self + + def __exit__(self, exc_type, exc, tb) -> Literal[False]: + return False + + def start(self) -> None: + self.started = True + + def fail(self, text: str) -> None: + self.fail_calls.append(text) + + def ok(self, text: str) -> None: + self.ok_calls.append(text) + + def dispatch_batch_statements( mycli: DummyMyCli, cli_args: DummyCliArgs, @@ -108,7 +151,7 @@ def main_batch_from_stdin(mycli: DummyMyCli, cli_args: DummyCliArgs) -> int: def make_fake_sys(stdin_tty: bool, stderr_tty: bool | None = None) -> SimpleNamespace: - stderr = SimpleNamespace(isatty=lambda: stderr_tty) if stderr_tty is not None else object() + stderr = DummyStream(bool(stderr_tty)) return SimpleNamespace( stdin=SimpleNamespace(isatty=lambda: stdin_tty), stderr=stderr, @@ -186,6 +229,21 @@ def test_replay_checkpoint_file_rejects_checkpoint_longer_than_batch(tmp_path: P batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True) +def test_replay_checkpoint_file_marks_progress_failed_when_checkpoint_is_longer( + monkeypatch, + tmp_path: Path, +) -> None: + batch_path = write_batch_file(tmp_path, 'select 1;\n') + checkpoint = write_checkpoint_file(tmp_path, 'select 1;\nselect 2;\n') + DummySpinner.instances.clear() + monkeypatch.setattr(batch_mode, 'yaspin', DummySpinner) + + with pytest.raises(batch_mode.CheckpointReplayError, match='Checkpoint script longer than batch script.'): + batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True, progress=True) + + assert DummySpinner.instances[0].fail_calls == ['✘'] + + @pytest.mark.skipif(os.name == 'nt', reason='todo: unknown') def test_replay_checkpoint_file_rejects_batch_read_error(monkeypatch, tmp_path: Path) -> None: batch_path = write_batch_file(tmp_path, 'select 1;\n') @@ -198,6 +256,20 @@ def test_replay_checkpoint_file_rejects_batch_read_error(monkeypatch, tmp_path: batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True) +@pytest.mark.skipif(os.name == 'nt', reason='todo: unknown') +def test_replay_checkpoint_file_marks_progress_failed_for_batch_read_error(monkeypatch, tmp_path: Path) -> None: + batch_path = write_batch_file(tmp_path, 'select 1;\n') + checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n') + DummySpinner.instances.clear() + monkeypatch.setattr(batch_mode, 'statements_from_filehandle', lambda _handle: (_ for _ in ()).throw(ValueError('bad batch'))) + monkeypatch.setattr(batch_mode, 'yaspin', DummySpinner) + + with pytest.raises(batch_mode.CheckpointReplayError, match=f'Error reading --batch file: {batch_path}: bad batch'): + batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True, progress=True) + + assert DummySpinner.instances[0].fail_calls == ['✘'] + + @pytest.mark.skipif(os.name == 'nt', reason='todo: unknown') def test_replay_checkpoint_file_rejects_batch_iteration_error(monkeypatch, tmp_path: Path) -> None: batch_path = write_batch_file(tmp_path, 'select 1;\n') @@ -219,6 +291,31 @@ def fake_statements_from_filehandle(handle): batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True) +@pytest.mark.skipif(os.name == 'nt', reason='todo: unknown') +def test_replay_checkpoint_file_marks_progress_failed_for_batch_iteration_error(monkeypatch, tmp_path: Path) -> None: + batch_path = write_batch_file(tmp_path, 'select 1;\n') + + def raise_on_next(): + raise ValueError('bad batch iterator') + yield + + def fake_statements_from_filehandle(handle): + if handle.name == batch_path: + return raise_on_next() + return iter([('select 1;', 0)]) + + DummySpinner.instances.clear() + monkeypatch.setattr(batch_mode, 'statements_from_filehandle', fake_statements_from_filehandle) + monkeypatch.setattr(batch_mode, 'yaspin', DummySpinner) + + checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n') + + with pytest.raises(batch_mode.CheckpointReplayError, match=f'Error reading --batch file: {batch_path}: bad batch iterator'): + batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True, progress=True) + + assert DummySpinner.instances[0].fail_calls == ['✘'] + + @pytest.mark.skipif(os.name == 'nt', reason='todo: unknown') def test_replay_checkpoint_file_rejects_checkpoint_read_error(monkeypatch, tmp_path: Path) -> None: batch_path = write_batch_file(tmp_path, 'select 1;\n') @@ -236,6 +333,27 @@ def fake_statements_from_filehandle(handle): batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True) +@pytest.mark.skipif(os.name == 'nt', reason='todo: unknown') +def test_replay_checkpoint_file_marks_progress_failed_for_checkpoint_read_error(monkeypatch, tmp_path: Path) -> None: + batch_path = write_batch_file(tmp_path, 'select 1;\n') + + def fake_statements_from_filehandle(handle): + if handle.name == batch_path: + return iter([('select 1;', 0)]) + return (_ for _ in ()).throw(ValueError('bad checkpoint')) + + DummySpinner.instances.clear() + monkeypatch.setattr(batch_mode, 'statements_from_filehandle', fake_statements_from_filehandle) + monkeypatch.setattr(batch_mode, 'yaspin', DummySpinner) + + checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n') + + with pytest.raises(batch_mode.CheckpointReplayError, match=f'Error reading --checkpoint file: {checkpoint}: bad checkpoint'): + batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True, progress=True) + + assert DummySpinner.instances[0].fail_calls == ['✘'] + + def test_replay_checkpoint_file_rejects_missing_files(tmp_path: Path) -> None: batch_path = str(tmp_path / 'missing.sql') @@ -245,6 +363,18 @@ def test_replay_checkpoint_file_rejects_missing_files(tmp_path: Path) -> None: batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True) +def test_replay_checkpoint_file_marks_progress_failed_for_missing_files(monkeypatch, tmp_path: Path) -> None: + batch_path = str(tmp_path / 'missing.sql') + checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n') + DummySpinner.instances.clear() + monkeypatch.setattr(batch_mode, 'yaspin', DummySpinner) + + with pytest.raises(batch_mode.CheckpointReplayError, match='FileNotFoundError'): + batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True, progress=True) + + assert DummySpinner.instances[0].fail_calls == ['✘'] + + def test_replay_checkpoint_file_rejects_open_errors(monkeypatch, tmp_path: Path) -> None: batch_path = write_batch_file(tmp_path, 'select 1;\n') @@ -256,6 +386,19 @@ def test_replay_checkpoint_file_rejects_open_errors(monkeypatch, tmp_path: Path) batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True) +def test_replay_checkpoint_file_marks_progress_failed_for_open_errors(monkeypatch, tmp_path: Path) -> None: + batch_path = write_batch_file(tmp_path, 'select 1;\n') + checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n') + DummySpinner.instances.clear() + monkeypatch.setattr(batch_mode.click, 'open_file', lambda *_args, **_kwargs: (_ for _ in ()).throw(OSError('open failed'))) + monkeypatch.setattr(batch_mode, 'yaspin', DummySpinner) + + with pytest.raises(batch_mode.CheckpointReplayError, match='OSError'): + batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True, progress=True) + + assert DummySpinner.instances[0].fail_calls == ['✘'] + + @pytest.mark.parametrize( ('format_name', 'batch_counter', 'expected'), (