Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
Upcoming (TBD)
==============

Bug Fixes
---------
* Ensure that `--batch` and `--checkpoint` files are distinct.


Internal
--------
* Improve test coverage for `completion_refresher.py`.
Expand Down
1 change: 1 addition & 0 deletions mycli/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def __init__(
self._keepalive_counter = 0
self.keepalive_ticks: int | None = 0
self.sandbox_mode: bool = False
self.checkpoint: IO | None = None

# self.cnf_files is a class variable that stores the list of mysql
# config files to read in at launch.
Expand Down
14 changes: 8 additions & 6 deletions mycli/client_query.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

from io import TextIOWrapper
from typing import TYPE_CHECKING, Any
from typing import IO, TYPE_CHECKING, Any

import click
from pymysql.cursors import Cursor
Expand All @@ -26,6 +25,7 @@ class ClientQueryMixin:
numeric_alignment: str | None
binary_display: str | None
query_history: list[Any]
checkpoint: IO | None

def log_query(self, query: str) -> None: ...
def log_output(self, output: str) -> None: ...
Expand Down Expand Up @@ -74,12 +74,14 @@ def _on_completions_refreshed(self, new_completer: SQLCompleter) -> None:
def run_query(
self,
query: str,
checkpoint: TextIOWrapper | None = None,
checkpoint: str | None = None,
new_line: bool = True,
) -> None:
"""Runs *query*."""
assert self.sqlexecute is not None
self.log_query(query)
if checkpoint and not self.checkpoint:
self.checkpoint = click.open_file(checkpoint, mode='a')
results = self.sqlexecute.run(query)
for result in results:
self.main_formatter.query = query
Expand Down Expand Up @@ -111,9 +113,9 @@ def run_query(
)
for line in output:
click.echo(line, nl=new_line)
if checkpoint:
checkpoint.write(query.rstrip('\n') + '\n')
checkpoint.flush()
if self.checkpoint:
self.checkpoint.write(query.rstrip('\n') + '\n')
self.checkpoint.flush()

def get_last_query(self) -> str | None:
"""Get the last query executed or None."""
Expand Down
17 changes: 14 additions & 3 deletions mycli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,9 @@ class CliArgs:
type=click.File(mode='a', encoding='utf-8'),
help='Log every query and its results to a file.',
)
checkpoint: TextIOWrapper | None = clickdc.option(
type=click.File(mode='a', encoding='utf-8'),
help='In batch or --execute mode, log successful queries to a file, and skipped with --resume.',
checkpoint: str | None = clickdc.option(
type=str,
help='In batch or --execute mode, log successful queries to a file, and skip them with --resume.',
)
resume: bool = clickdc.option(
'--resume',
Expand Down Expand Up @@ -369,6 +369,17 @@ def preprocess_cli_args(
click.secho('Error: --resume requires a --batch file.', err=True, fg='red')
sys.exit(1)

if (
cli_args.checkpoint
and os.path.exists(cli_args.checkpoint)
and cli_args.batch
and cli_args.batch != '-'
and os.path.exists(cli_args.batch)
):
if os.stat(cli_args.batch) == os.stat(cli_args.checkpoint):
click.secho('Error: --batch and --checkpoint must be different files.', err=True, fg='red')
sys.exit(1)

if cli_args.verbose and cli_args.quiet:
click.secho('Error: --verbose and --quiet are incompatible.', err=True, fg='red')
sys.exit(1)
Expand Down
18 changes: 9 additions & 9 deletions mycli/main_modes/batch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

from io import TextIOWrapper
import os
import sys
import time
Expand All @@ -27,23 +26,24 @@ class CheckpointReplayError(Exception):

def replay_checkpoint_file(
batch_path: str,
checkpoint: TextIOWrapper | None,
checkpoint_path: str | None,
resume: bool,
) -> int:
if not resume:
return 0

if checkpoint is None:
if checkpoint_path is None:
return 0

if not os.path.exists(checkpoint_path):
return 0

if batch_path == '-':
raise CheckpointReplayError('--resume is incompatible with reading from the standard input.')

checkpoint_name = checkpoint.name
checkpoint.flush()
completed_count = 0
try:
with click.open_file(batch_path) as batch_h, click.open_file(checkpoint_name, mode='r', encoding='utf-8') as checkpoint_h:
with click.open_file(batch_path) as batch_h, click.open_file(checkpoint_path, mode='r', encoding='utf-8') as checkpoint_h:
try:
batch_gen = statements_from_filehandle(batch_h)
except ValueError as e:
Expand All @@ -59,7 +59,7 @@ def replay_checkpoint_file(
raise CheckpointReplayError(f'Statement mismatch: {checkpoint_statement}.')
completed_count += 1
except ValueError as e:
raise CheckpointReplayError(f'Error reading --checkpoint file: {checkpoint.name}: {e}') from None
raise CheckpointReplayError(f'Error reading --checkpoint file: {checkpoint_path}: {e}') from None
except FileNotFoundError as e:
raise CheckpointReplayError(f'FileNotFoundError: {e}') from None
except OSError as e:
Expand Down Expand Up @@ -133,7 +133,7 @@ def main_batch_with_progress_bar(mycli: 'MyCli', cli_args: 'CliArgs') -> int:
click.secho(f'Error reading --batch file: {cli_args.batch}: {e}', err=True, fg='red')
return 1
except CheckpointReplayError as e:
name = cli_args.checkpoint.name if cli_args.checkpoint else 'None'
name = cli_args.checkpoint if cli_args.checkpoint else 'None'
click.secho(f'Error replaying --checkpoint file: {name}: {e}', err=True, fg='red')
return 1
try:
Expand Down Expand Up @@ -175,7 +175,7 @@ def main_batch_without_progress_bar(mycli: 'MyCli', cli_args: 'CliArgs') -> int:
click.secho(f'Failed to open --batch file: {cli_args.batch}', err=True, fg='red')
return 1
except CheckpointReplayError as e:
name = cli_args.checkpoint.name if cli_args.checkpoint else 'None'
name = cli_args.checkpoint if cli_args.checkpoint else 'None'
click.secho(f'Error replaying --checkpoint file: {name}: {e}', err=True, fg='red')
return 1
try:
Expand Down
6 changes: 1 addition & 5 deletions test/pytests/test_client_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,11 +202,7 @@ def format_sqlresult(result: SQLResult, **kwargs: Any) -> list[str]:
cli.log_query = lambda query: state['logged_queries'].append(query)
cli.log_output = lambda line: state['logged_output'].append(line)
cli.format_sqlresult = format_sqlresult
checkpoint = state['checkpoint_path'].open('w+', encoding='utf-8')
try:
main.MyCli.run_query(cli, 'select 1;\n', checkpoint=checkpoint, new_line=False)
finally:
checkpoint.close()
main.MyCli.run_query(cli, 'select 1;\n', checkpoint=str(state['checkpoint_path']), new_line=False)
state['cli'] = cli
return state

Expand Down
17 changes: 17 additions & 0 deletions test/pytests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2711,6 +2711,23 @@ def test_preprocess_cli_args_validates_resume_requirements(
assert expected in capsys.readouterr().err


def test_preprocess_cli_args_rejects_same_batch_and_checkpoint_file(
capsys: pytest.CaptureFixture[str],
tmp_path: Path,
) -> None:
batch_path = tmp_path / 'batch.sql'
batch_path.write_text('select 1;\n', encoding='utf-8')
cli_args = CliArgs()
cli_args.batch = str(batch_path)
cli_args.checkpoint = str(batch_path)

with pytest.raises(SystemExit) as excinfo:
preprocess_cli_args(cli_args, valid_connection_scheme)

assert excinfo.value.code == 1
assert 'Error: --batch and --checkpoint must be different files.' in capsys.readouterr().err


def test_preprocess_cli_args_rejects_verbose_and_quiet(capsys: pytest.CaptureFixture[str]) -> None:
cli_args = CliArgs()
cli_args.verbose = 1
Expand Down
98 changes: 56 additions & 42 deletions test/pytests/test_main_modes_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,28 +153,37 @@ def write_batch_file(tmp_path: Path, contents: str) -> str:
return str(batch_path)


def open_checkpoint_file(tmp_path: Path, contents: str) -> TextIOWrapper:
def write_checkpoint_file(tmp_path: Path, contents: str) -> str:
checkpoint_path = tmp_path / 'checkpoint.sql'
checkpoint_path.write_text(contents, encoding='utf-8')
return checkpoint_path.open('a', encoding='utf-8')
return str(checkpoint_path)


def test_replay_checkpoint_file_returns_zero_without_replayable_batch(tmp_path: Path) -> None:
batch_path = write_batch_file(tmp_path, 'select 1;\n')

assert batch_mode.replay_checkpoint_file(batch_path, None, resume=True) == 0

with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint:
with pytest.raises(batch_mode.CheckpointReplayError, match='incompatible with reading from the standard input'):
batch_mode.replay_checkpoint_file('-', checkpoint, resume=True)
checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n')

with pytest.raises(batch_mode.CheckpointReplayError, match='incompatible with reading from the standard input'):
batch_mode.replay_checkpoint_file('-', checkpoint, resume=True)


def test_replay_checkpoint_file_returns_zero_when_checkpoint_is_missing(tmp_path: Path) -> None:
batch_path = write_batch_file(tmp_path, 'select 1;\n')
checkpoint_path = str(tmp_path / 'missing-checkpoint.sql')

assert batch_mode.replay_checkpoint_file(batch_path, checkpoint_path, resume=True) == 0


def test_replay_checkpoint_file_rejects_checkpoint_longer_than_batch(tmp_path: Path) -> None:
batch_path = write_batch_file(tmp_path, 'select 1;\n')

with open_checkpoint_file(tmp_path, 'select 1;\nselect 2;\n') as checkpoint:
with pytest.raises(batch_mode.CheckpointReplayError, match='Checkpoint script longer than batch script.'):
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True)
checkpoint = write_checkpoint_file(tmp_path, 'select 1;\nselect 2;\n')

with pytest.raises(batch_mode.CheckpointReplayError, match='Checkpoint script longer than batch script.'):
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True)


@pytest.mark.skipif(os.name == 'nt', reason='todo: unknown')
Expand All @@ -183,9 +192,10 @@ def test_replay_checkpoint_file_rejects_batch_read_error(monkeypatch, tmp_path:

monkeypatch.setattr(batch_mode, 'statements_from_filehandle', lambda _handle: (_ for _ in ()).throw(ValueError('bad batch')))

with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint:
with pytest.raises(batch_mode.CheckpointReplayError, match=f'Error reading --batch file: {batch_path}: bad batch'):
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True)
checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n')

with pytest.raises(batch_mode.CheckpointReplayError, match=f'Error reading --batch file: {batch_path}: bad batch'):
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True)


@pytest.mark.skipif(os.name == 'nt', reason='todo: unknown')
Expand All @@ -203,9 +213,10 @@ def fake_statements_from_filehandle(handle):

monkeypatch.setattr(batch_mode, 'statements_from_filehandle', fake_statements_from_filehandle)

with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint:
with pytest.raises(batch_mode.CheckpointReplayError, match=f'Error reading --batch file: {batch_path}: bad batch iterator'):
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True)
checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n')

with pytest.raises(batch_mode.CheckpointReplayError, match=f'Error reading --batch file: {batch_path}: bad batch iterator'):
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True)


@pytest.mark.skipif(os.name == 'nt', reason='todo: unknown')
Expand All @@ -219,27 +230,30 @@ def fake_statements_from_filehandle(handle):

monkeypatch.setattr(batch_mode, 'statements_from_filehandle', fake_statements_from_filehandle)

with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint:
with pytest.raises(batch_mode.CheckpointReplayError, match=f'Error reading --checkpoint file: {checkpoint.name}: bad checkpoint'):
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True)
checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n')

with pytest.raises(batch_mode.CheckpointReplayError, match=f'Error reading --checkpoint file: {checkpoint}: bad checkpoint'):
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True)


def test_replay_checkpoint_file_rejects_missing_files(tmp_path: Path) -> None:
batch_path = str(tmp_path / 'missing.sql')

with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint:
with pytest.raises(batch_mode.CheckpointReplayError, match='FileNotFoundError'):
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True)
checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n')

with pytest.raises(batch_mode.CheckpointReplayError, match='FileNotFoundError'):
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True)


def test_replay_checkpoint_file_rejects_open_errors(monkeypatch, tmp_path: Path) -> None:
batch_path = write_batch_file(tmp_path, 'select 1;\n')

monkeypatch.setattr(batch_mode.click, 'open_file', lambda *_args, **_kwargs: (_ for _ in ()).throw(OSError('open failed')))

with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint:
with pytest.raises(batch_mode.CheckpointReplayError, match='OSError'):
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True)
checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n')

with pytest.raises(batch_mode.CheckpointReplayError, match='OSError'):
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -514,10 +528,10 @@ def test_main_batch_without_progress_bar_skips_checkpoint_prefix(monkeypatch, tm
)
monkeypatch.setattr(batch_mode, 'sys', make_fake_sys(stdin_tty=True))

with open_checkpoint_file(tmp_path, 'select 1;\nselect 2;\n') as checkpoint:
cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True)
checkpoint = write_checkpoint_file(tmp_path, 'select 1;\nselect 2;\n')

result = main_batch_without_progress_bar(DummyMyCli(), cli_args)
cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True)
result = main_batch_without_progress_bar(DummyMyCli(), cli_args)

assert result == 0
assert dispatch_calls == [('select 3;', 2)]
Expand All @@ -534,10 +548,10 @@ def test_main_batch_without_progress_bar_skips_only_matching_duplicate_prefix(mo
)
monkeypatch.setattr(batch_mode, 'sys', make_fake_sys(stdin_tty=True))

with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint:
cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True)
checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n')

result = main_batch_without_progress_bar(DummyMyCli(), cli_args)
cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True)
result = main_batch_without_progress_bar(DummyMyCli(), cli_args)

assert result == 0
assert dispatch_calls == [('select 1;', 1), ('select 2;', 2)]
Expand All @@ -554,10 +568,10 @@ def test_main_batch_without_progress_bar_fails_on_mismatched_checkpoint(monkeypa
)
monkeypatch.setattr(batch_mode, 'sys', make_fake_sys(stdin_tty=True))

with open_checkpoint_file(tmp_path, 'select 9;\n') as checkpoint:
cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True)
checkpoint = write_checkpoint_file(tmp_path, 'select 9;\n')

result = main_batch_without_progress_bar(DummyMyCli(), cli_args)
cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True)
result = main_batch_without_progress_bar(DummyMyCli(), cli_args)

assert result == 1
assert dispatch_calls == []
Expand All @@ -574,10 +588,10 @@ def test_main_batch_without_progress_bar_succeeds_when_checkpoint_skips_all(monk
)
monkeypatch.setattr(batch_mode, 'sys', make_fake_sys(stdin_tty=True))

with open_checkpoint_file(tmp_path, 'select 1;\nselect 2;\n') as checkpoint:
cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True)
checkpoint = write_checkpoint_file(tmp_path, 'select 1;\nselect 2;\n')

result = main_batch_without_progress_bar(DummyMyCli(), cli_args)
cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True)
result = main_batch_without_progress_bar(DummyMyCli(), cli_args)

assert result == 0
assert dispatch_calls == []
Expand All @@ -597,10 +611,10 @@ def test_main_batch_with_progress_bar_skips_checkpoint_prefix_and_counts_all_sta
)
monkeypatch.setattr(batch_mode, 'sys', make_fake_sys(stdin_tty=True))

with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint:
cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True)
checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n')

result = main_batch_with_progress_bar(DummyMyCli(), cli_args)
cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True)
result = main_batch_with_progress_bar(DummyMyCli(), cli_args)

assert result == 0
assert dispatch_calls == [('select 2;', 1), ('select 3;', 2)]
Expand All @@ -614,13 +628,13 @@ def test_main_batch_with_progress_bar_returns_error_when_checkpoint_replay_fails
monkeypatch.setattr(batch_mode.click, 'secho', lambda message, err, fg: messages.append((message, err, fg)))
monkeypatch.setattr(batch_mode, 'sys', make_fake_sys(stdin_tty=True))

with open_checkpoint_file(tmp_path, 'select 9;\n') as checkpoint:
cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True)
checkpoint = write_checkpoint_file(tmp_path, 'select 9;\n')

result = main_batch_with_progress_bar(DummyMyCli(), cli_args)
cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True)
result = main_batch_with_progress_bar(DummyMyCli(), cli_args)

assert result == 1
assert messages == [(f'Error replaying --checkpoint file: {checkpoint.name}: Statement mismatch: select 9;.', True, 'red')]
assert messages == [(f'Error replaying --checkpoint file: {checkpoint}: Statement mismatch: select 9;.', True, 'red')]


def test_main_batch_without_progress_bar_returns_error_when_iteration_fails(monkeypatch) -> None:
Expand Down
Loading
Loading