diff --git a/changelog.md b/changelog.md index 5a53fb00..d926a8eb 100644 --- a/changelog.md +++ b/changelog.md @@ -1,6 +1,11 @@ Upcoming (TBD) ============== +Bug Fixes +--------- +* Ensure that `--batch` and `--checkpoint` files are distinct. + + Internal -------- * Improve test coverage for `completion_refresher.py`. diff --git a/mycli/client.py b/mycli/client.py index cc8d199a..073544f5 100644 --- a/mycli/client.py +++ b/mycli/client.py @@ -90,6 +90,7 @@ def __init__( self._keepalive_counter = 0 self.keepalive_ticks: int | None = 0 self.sandbox_mode: bool = False + self.checkpoint: IO | None = None # self.cnf_files is a class variable that stores the list of mysql # config files to read in at launch. diff --git a/mycli/client_query.py b/mycli/client_query.py index 34e29ea5..09e4291a 100644 --- a/mycli/client_query.py +++ b/mycli/client_query.py @@ -1,7 +1,6 @@ from __future__ import annotations -from io import TextIOWrapper -from typing import TYPE_CHECKING, Any +from typing import IO, TYPE_CHECKING, Any import click from pymysql.cursors import Cursor @@ -26,6 +25,7 @@ class ClientQueryMixin: numeric_alignment: str | None binary_display: str | None query_history: list[Any] + checkpoint: IO | None def log_query(self, query: str) -> None: ... def log_output(self, output: str) -> None: ... @@ -74,12 +74,14 @@ def _on_completions_refreshed(self, new_completer: SQLCompleter) -> None: def run_query( self, query: str, - checkpoint: TextIOWrapper | None = None, + checkpoint: str | None = None, new_line: bool = True, ) -> None: """Runs *query*.""" assert self.sqlexecute is not None self.log_query(query) + if checkpoint and not self.checkpoint: + self.checkpoint = click.open_file(checkpoint, mode='a') results = self.sqlexecute.run(query) for result in results: self.main_formatter.query = query @@ -111,9 +113,9 @@ def run_query( ) for line in output: click.echo(line, nl=new_line) - if checkpoint: - checkpoint.write(query.rstrip('\n') + '\n') - checkpoint.flush() + if self.checkpoint: + self.checkpoint.write(query.rstrip('\n') + '\n') + self.checkpoint.flush() def get_last_query(self) -> str | None: """Get the last query executed or None.""" diff --git a/mycli/main.py b/mycli/main.py index 9d2eadfe..6bd1b253 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -206,9 +206,9 @@ class CliArgs: type=click.File(mode='a', encoding='utf-8'), help='Log every query and its results to a file.', ) - checkpoint: TextIOWrapper | None = clickdc.option( - type=click.File(mode='a', encoding='utf-8'), - help='In batch or --execute mode, log successful queries to a file, and skipped with --resume.', + checkpoint: str | None = clickdc.option( + type=str, + help='In batch or --execute mode, log successful queries to a file, and skip them with --resume.', ) resume: bool = clickdc.option( '--resume', @@ -369,6 +369,17 @@ def preprocess_cli_args( click.secho('Error: --resume requires a --batch file.', err=True, fg='red') sys.exit(1) + if ( + cli_args.checkpoint + and os.path.exists(cli_args.checkpoint) + and cli_args.batch + and cli_args.batch != '-' + and os.path.exists(cli_args.batch) + ): + if os.stat(cli_args.batch) == os.stat(cli_args.checkpoint): + click.secho('Error: --batch and --checkpoint must be different files.', err=True, fg='red') + sys.exit(1) + if cli_args.verbose and cli_args.quiet: click.secho('Error: --verbose and --quiet are incompatible.', err=True, fg='red') sys.exit(1) diff --git a/mycli/main_modes/batch.py b/mycli/main_modes/batch.py index af14dd5f..d87a27e4 100644 --- a/mycli/main_modes/batch.py +++ b/mycli/main_modes/batch.py @@ -1,6 +1,5 @@ from __future__ import annotations -from io import TextIOWrapper import os import sys import time @@ -27,23 +26,24 @@ class CheckpointReplayError(Exception): def replay_checkpoint_file( batch_path: str, - checkpoint: TextIOWrapper | None, + checkpoint_path: str | None, resume: bool, ) -> int: if not resume: return 0 - if checkpoint is None: + if checkpoint_path is None: + return 0 + + if not os.path.exists(checkpoint_path): return 0 if batch_path == '-': raise CheckpointReplayError('--resume is incompatible with reading from the standard input.') - checkpoint_name = checkpoint.name - checkpoint.flush() completed_count = 0 try: - with click.open_file(batch_path) as batch_h, click.open_file(checkpoint_name, mode='r', encoding='utf-8') as checkpoint_h: + with click.open_file(batch_path) as batch_h, click.open_file(checkpoint_path, mode='r', encoding='utf-8') as checkpoint_h: try: batch_gen = statements_from_filehandle(batch_h) except ValueError as e: @@ -59,7 +59,7 @@ def replay_checkpoint_file( raise CheckpointReplayError(f'Statement mismatch: {checkpoint_statement}.') completed_count += 1 except ValueError as e: - raise CheckpointReplayError(f'Error reading --checkpoint file: {checkpoint.name}: {e}') from None + raise CheckpointReplayError(f'Error reading --checkpoint file: {checkpoint_path}: {e}') from None except FileNotFoundError as e: raise CheckpointReplayError(f'FileNotFoundError: {e}') from None except OSError as e: @@ -133,7 +133,7 @@ def main_batch_with_progress_bar(mycli: 'MyCli', cli_args: 'CliArgs') -> int: click.secho(f'Error reading --batch file: {cli_args.batch}: {e}', err=True, fg='red') return 1 except CheckpointReplayError as e: - name = cli_args.checkpoint.name if cli_args.checkpoint else 'None' + name = cli_args.checkpoint if cli_args.checkpoint else 'None' click.secho(f'Error replaying --checkpoint file: {name}: {e}', err=True, fg='red') return 1 try: @@ -175,7 +175,7 @@ def main_batch_without_progress_bar(mycli: 'MyCli', cli_args: 'CliArgs') -> int: click.secho(f'Failed to open --batch file: {cli_args.batch}', err=True, fg='red') return 1 except CheckpointReplayError as e: - name = cli_args.checkpoint.name if cli_args.checkpoint else 'None' + name = cli_args.checkpoint if cli_args.checkpoint else 'None' click.secho(f'Error replaying --checkpoint file: {name}: {e}', err=True, fg='red') return 1 try: diff --git a/test/pytests/test_client_query.py b/test/pytests/test_client_query.py index 6b4ba800..8d8fd9ff 100644 --- a/test/pytests/test_client_query.py +++ b/test/pytests/test_client_query.py @@ -202,11 +202,7 @@ def format_sqlresult(result: SQLResult, **kwargs: Any) -> list[str]: cli.log_query = lambda query: state['logged_queries'].append(query) cli.log_output = lambda line: state['logged_output'].append(line) cli.format_sqlresult = format_sqlresult - checkpoint = state['checkpoint_path'].open('w+', encoding='utf-8') - try: - main.MyCli.run_query(cli, 'select 1;\n', checkpoint=checkpoint, new_line=False) - finally: - checkpoint.close() + main.MyCli.run_query(cli, 'select 1;\n', checkpoint=str(state['checkpoint_path']), new_line=False) state['cli'] = cli return state diff --git a/test/pytests/test_main.py b/test/pytests/test_main.py index 8f59ac09..b16d3495 100644 --- a/test/pytests/test_main.py +++ b/test/pytests/test_main.py @@ -2711,6 +2711,23 @@ def test_preprocess_cli_args_validates_resume_requirements( assert expected in capsys.readouterr().err +def test_preprocess_cli_args_rejects_same_batch_and_checkpoint_file( + capsys: pytest.CaptureFixture[str], + tmp_path: Path, +) -> None: + batch_path = tmp_path / 'batch.sql' + batch_path.write_text('select 1;\n', encoding='utf-8') + cli_args = CliArgs() + cli_args.batch = str(batch_path) + cli_args.checkpoint = str(batch_path) + + with pytest.raises(SystemExit) as excinfo: + preprocess_cli_args(cli_args, valid_connection_scheme) + + assert excinfo.value.code == 1 + assert 'Error: --batch and --checkpoint must be different files.' in capsys.readouterr().err + + def test_preprocess_cli_args_rejects_verbose_and_quiet(capsys: pytest.CaptureFixture[str]) -> None: cli_args = CliArgs() cli_args.verbose = 1 diff --git a/test/pytests/test_main_modes_batch.py b/test/pytests/test_main_modes_batch.py index 27488fbe..777f7d67 100644 --- a/test/pytests/test_main_modes_batch.py +++ b/test/pytests/test_main_modes_batch.py @@ -153,10 +153,10 @@ def write_batch_file(tmp_path: Path, contents: str) -> str: return str(batch_path) -def open_checkpoint_file(tmp_path: Path, contents: str) -> TextIOWrapper: +def write_checkpoint_file(tmp_path: Path, contents: str) -> str: checkpoint_path = tmp_path / 'checkpoint.sql' checkpoint_path.write_text(contents, encoding='utf-8') - return checkpoint_path.open('a', encoding='utf-8') + return str(checkpoint_path) def test_replay_checkpoint_file_returns_zero_without_replayable_batch(tmp_path: Path) -> None: @@ -164,17 +164,26 @@ def test_replay_checkpoint_file_returns_zero_without_replayable_batch(tmp_path: assert batch_mode.replay_checkpoint_file(batch_path, None, resume=True) == 0 - with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint: - with pytest.raises(batch_mode.CheckpointReplayError, match='incompatible with reading from the standard input'): - batch_mode.replay_checkpoint_file('-', checkpoint, resume=True) + checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n') + + with pytest.raises(batch_mode.CheckpointReplayError, match='incompatible with reading from the standard input'): + batch_mode.replay_checkpoint_file('-', checkpoint, resume=True) + + +def test_replay_checkpoint_file_returns_zero_when_checkpoint_is_missing(tmp_path: Path) -> None: + batch_path = write_batch_file(tmp_path, 'select 1;\n') + checkpoint_path = str(tmp_path / 'missing-checkpoint.sql') + + assert batch_mode.replay_checkpoint_file(batch_path, checkpoint_path, resume=True) == 0 def test_replay_checkpoint_file_rejects_checkpoint_longer_than_batch(tmp_path: Path) -> None: batch_path = write_batch_file(tmp_path, 'select 1;\n') - with open_checkpoint_file(tmp_path, 'select 1;\nselect 2;\n') as checkpoint: - with pytest.raises(batch_mode.CheckpointReplayError, match='Checkpoint script longer than batch script.'): - batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True) + checkpoint = write_checkpoint_file(tmp_path, 'select 1;\nselect 2;\n') + + with pytest.raises(batch_mode.CheckpointReplayError, match='Checkpoint script longer than batch script.'): + batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True) @pytest.mark.skipif(os.name == 'nt', reason='todo: unknown') @@ -183,9 +192,10 @@ def test_replay_checkpoint_file_rejects_batch_read_error(monkeypatch, tmp_path: monkeypatch.setattr(batch_mode, 'statements_from_filehandle', lambda _handle: (_ for _ in ()).throw(ValueError('bad batch'))) - with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint: - with pytest.raises(batch_mode.CheckpointReplayError, match=f'Error reading --batch file: {batch_path}: bad batch'): - batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True) + checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n') + + with pytest.raises(batch_mode.CheckpointReplayError, match=f'Error reading --batch file: {batch_path}: bad batch'): + batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True) @pytest.mark.skipif(os.name == 'nt', reason='todo: unknown') @@ -203,9 +213,10 @@ def fake_statements_from_filehandle(handle): monkeypatch.setattr(batch_mode, 'statements_from_filehandle', fake_statements_from_filehandle) - with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint: - with pytest.raises(batch_mode.CheckpointReplayError, match=f'Error reading --batch file: {batch_path}: bad batch iterator'): - batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True) + checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n') + + with pytest.raises(batch_mode.CheckpointReplayError, match=f'Error reading --batch file: {batch_path}: bad batch iterator'): + batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True) @pytest.mark.skipif(os.name == 'nt', reason='todo: unknown') @@ -219,17 +230,19 @@ def fake_statements_from_filehandle(handle): monkeypatch.setattr(batch_mode, 'statements_from_filehandle', fake_statements_from_filehandle) - with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint: - with pytest.raises(batch_mode.CheckpointReplayError, match=f'Error reading --checkpoint file: {checkpoint.name}: bad checkpoint'): - batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True) + checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n') + + with pytest.raises(batch_mode.CheckpointReplayError, match=f'Error reading --checkpoint file: {checkpoint}: bad checkpoint'): + batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True) def test_replay_checkpoint_file_rejects_missing_files(tmp_path: Path) -> None: batch_path = str(tmp_path / 'missing.sql') - with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint: - with pytest.raises(batch_mode.CheckpointReplayError, match='FileNotFoundError'): - batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True) + checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n') + + with pytest.raises(batch_mode.CheckpointReplayError, match='FileNotFoundError'): + batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True) def test_replay_checkpoint_file_rejects_open_errors(monkeypatch, tmp_path: Path) -> None: @@ -237,9 +250,10 @@ def test_replay_checkpoint_file_rejects_open_errors(monkeypatch, tmp_path: Path) monkeypatch.setattr(batch_mode.click, 'open_file', lambda *_args, **_kwargs: (_ for _ in ()).throw(OSError('open failed'))) - with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint: - with pytest.raises(batch_mode.CheckpointReplayError, match='OSError'): - batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True) + checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n') + + with pytest.raises(batch_mode.CheckpointReplayError, match='OSError'): + batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True) @pytest.mark.parametrize( @@ -514,10 +528,10 @@ def test_main_batch_without_progress_bar_skips_checkpoint_prefix(monkeypatch, tm ) monkeypatch.setattr(batch_mode, 'sys', make_fake_sys(stdin_tty=True)) - with open_checkpoint_file(tmp_path, 'select 1;\nselect 2;\n') as checkpoint: - cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True) + checkpoint = write_checkpoint_file(tmp_path, 'select 1;\nselect 2;\n') - result = main_batch_without_progress_bar(DummyMyCli(), cli_args) + cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True) + result = main_batch_without_progress_bar(DummyMyCli(), cli_args) assert result == 0 assert dispatch_calls == [('select 3;', 2)] @@ -534,10 +548,10 @@ def test_main_batch_without_progress_bar_skips_only_matching_duplicate_prefix(mo ) monkeypatch.setattr(batch_mode, 'sys', make_fake_sys(stdin_tty=True)) - with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint: - cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True) + checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n') - result = main_batch_without_progress_bar(DummyMyCli(), cli_args) + cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True) + result = main_batch_without_progress_bar(DummyMyCli(), cli_args) assert result == 0 assert dispatch_calls == [('select 1;', 1), ('select 2;', 2)] @@ -554,10 +568,10 @@ def test_main_batch_without_progress_bar_fails_on_mismatched_checkpoint(monkeypa ) monkeypatch.setattr(batch_mode, 'sys', make_fake_sys(stdin_tty=True)) - with open_checkpoint_file(tmp_path, 'select 9;\n') as checkpoint: - cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True) + checkpoint = write_checkpoint_file(tmp_path, 'select 9;\n') - result = main_batch_without_progress_bar(DummyMyCli(), cli_args) + cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True) + result = main_batch_without_progress_bar(DummyMyCli(), cli_args) assert result == 1 assert dispatch_calls == [] @@ -574,10 +588,10 @@ def test_main_batch_without_progress_bar_succeeds_when_checkpoint_skips_all(monk ) monkeypatch.setattr(batch_mode, 'sys', make_fake_sys(stdin_tty=True)) - with open_checkpoint_file(tmp_path, 'select 1;\nselect 2;\n') as checkpoint: - cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True) + checkpoint = write_checkpoint_file(tmp_path, 'select 1;\nselect 2;\n') - result = main_batch_without_progress_bar(DummyMyCli(), cli_args) + cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True) + result = main_batch_without_progress_bar(DummyMyCli(), cli_args) assert result == 0 assert dispatch_calls == [] @@ -597,10 +611,10 @@ def test_main_batch_with_progress_bar_skips_checkpoint_prefix_and_counts_all_sta ) monkeypatch.setattr(batch_mode, 'sys', make_fake_sys(stdin_tty=True)) - with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint: - cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True) + checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n') - result = main_batch_with_progress_bar(DummyMyCli(), cli_args) + cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True) + result = main_batch_with_progress_bar(DummyMyCli(), cli_args) assert result == 0 assert dispatch_calls == [('select 2;', 1), ('select 3;', 2)] @@ -614,13 +628,13 @@ def test_main_batch_with_progress_bar_returns_error_when_checkpoint_replay_fails monkeypatch.setattr(batch_mode.click, 'secho', lambda message, err, fg: messages.append((message, err, fg))) monkeypatch.setattr(batch_mode, 'sys', make_fake_sys(stdin_tty=True)) - with open_checkpoint_file(tmp_path, 'select 9;\n') as checkpoint: - cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True) + checkpoint = write_checkpoint_file(tmp_path, 'select 9;\n') - result = main_batch_with_progress_bar(DummyMyCli(), cli_args) + cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True) + result = main_batch_with_progress_bar(DummyMyCli(), cli_args) assert result == 1 - assert messages == [(f'Error replaying --checkpoint file: {checkpoint.name}: Statement mismatch: select 9;.', True, 'red')] + assert messages == [(f'Error replaying --checkpoint file: {checkpoint}: Statement mismatch: select 9;.', True, 'red')] def test_main_batch_without_progress_bar_returns_error_when_iteration_fails(monkeypatch) -> None: diff --git a/test/utils.py b/test/utils.py index 1a92963d..ec78f874 100644 --- a/test/utils.py +++ b/test/utils.py @@ -198,6 +198,7 @@ def make_bare_mycli() -> Any: cli.configure_pager = lambda: None # type: ignore[assignment] cli.refresh_completions = lambda reset=False: [SQLResult(status='refresh')] # type: ignore[assignment] cli.reconnect = lambda database='': False # type: ignore[assignment] + cli.checkpoint = None return cli