Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ Upcoming (TBD)
==============

Features
---------
--------
* Silently accept forward slash to introduce special commands.
* `--progress` spinners for setup steps in `--batch` mode.


Internal
Expand Down
31 changes: 27 additions & 4 deletions mycli/main_modes/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from prompt_toolkit.shortcuts import ProgressBar
from prompt_toolkit.shortcuts.progress_bar import formatters as progress_bar_formatters
import pymysql
from yaspin import yaspin

from mycli.packages.batch_utils import statements_from_filehandle
from mycli.packages.interactive_utils import confirm_destructive_query
Expand All @@ -28,6 +29,7 @@ def replay_checkpoint_file(
batch_path: str,
checkpoint_path: str | None,
resume: bool,
progress: bool = False,
) -> int:
if not resume:
return 0
Expand All @@ -42,27 +44,46 @@ def replay_checkpoint_file(
raise CheckpointReplayError('--resume is incompatible with reading from the standard input.')

completed_count = 0
if progress:
spinner = yaspin(text='replaying checkpoint', side='right', stream=sys.stderr)
spinner.start()
try:
with click.open_file(batch_path) as batch_h, click.open_file(checkpoint_path, mode='r', encoding='utf-8') as checkpoint_h:
try:
batch_gen = statements_from_filehandle(batch_h)
except ValueError as e:
if progress:
spinner.fail('✘')
raise CheckpointReplayError(f'Error reading --batch file: {batch_path}: {e}') from None
for checkpoint_statement, _checkpoint_counter in statements_from_filehandle(checkpoint_h):
try:
batch_statement, _batch_counter = next(batch_gen)
except StopIteration:
if progress:
spinner.fail('✘')
raise CheckpointReplayError('Checkpoint script longer than batch script.') from None
except ValueError as e:
if progress:
spinner.fail('✘')
raise CheckpointReplayError(f'Error reading --batch file: {batch_path}: {e}') from None
if checkpoint_statement != batch_statement:
if progress:
spinner.fail('✘')
raise CheckpointReplayError(f'Statement mismatch: {checkpoint_statement}.')
completed_count += 1
if progress:
spinner.ok('✔')
except ValueError as e:
if progress:
spinner.fail('✘')
raise CheckpointReplayError(f'Error reading --checkpoint file: {checkpoint_path}: {e}') from None
except FileNotFoundError as e:
if progress:
spinner.fail('✘')
raise CheckpointReplayError(f'FileNotFoundError: {e}') from None
except OSError as e:
if progress:
spinner.fail('✘')
raise CheckpointReplayError(f'OSError: {e}') from None

return completed_count
Expand Down Expand Up @@ -119,10 +140,12 @@ def main_batch_with_progress_bar(mycli: 'MyCli', cli_args: 'CliArgs') -> int:
click.secho('--progress is only compatible with a plain file.', err=True, fg='red')
return 1
try:
completed_statement_count = replay_checkpoint_file(cli_args.batch, cli_args.checkpoint, cli_args.resume)
completed_statement_count = replay_checkpoint_file(cli_args.batch, cli_args.checkpoint, cli_args.resume, progress=True)
batch_count_h = click.open_file(cli_args.batch)
for _statement, _counter in statements_from_filehandle(batch_count_h):
goal_statements += 1
with yaspin(text='validating batch ', side='right', stream=sys.stderr) as spinner:
for _statement, _counter in statements_from_filehandle(batch_count_h):
goal_statements += 1
spinner.ok('✔')
batch_count_h.close()
batch_h = click.open_file(cli_args.batch)
batch_gen = statements_from_filehandle(batch_h)
Expand All @@ -140,7 +163,7 @@ def main_batch_with_progress_bar(mycli: 'MyCli', cli_args: 'CliArgs') -> int:
if goal_statements:
pb_style = prompt_toolkit.styles.Style.from_dict({'bar-a': 'reverse'})
custom_formatters = [
progress_bar_formatters.Bar(start='[', end=']', sym_a=' ', sym_b=' ', sym_c=' '),
progress_bar_formatters.Bar(start='running queries [', end=']', sym_a=' ', sym_b=' ', sym_c=' '),
progress_bar_formatters.Text(' '),
progress_bar_formatters.Progress(),
progress_bar_formatters.Text(' '),
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ dependencies = [
"pyfzf ~= 0.3.1",
"rapidfuzz ~= 3.14.3",
"keyring ~= 25.7.0",
"yaspin ~= 3.4.0",
]

[project.urls]
Expand Down
145 changes: 144 additions & 1 deletion test/pytests/test_main_modes_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,23 @@ def close(self) -> None:
self.closed = True


class DummyStream:
def __init__(self, tty: bool = False) -> None:
self.closed = False
self.tty = tty
self.writes: list[str] = []

def isatty(self) -> bool:
return self.tty

def write(self, value: str) -> int:
self.writes.append(value)
return len(value)

def flush(self) -> None:
return None


class DummyProgressBar:
calls: list[list[int]] = []

Expand All @@ -86,6 +103,32 @@ def __call__(self, iterable) -> list[int]:
return values


class DummySpinner:
instances: list['DummySpinner'] = []

def __init__(self, *args, **kwargs) -> None:
self.fail_calls: list[str] = []
self.ok_calls: list[str] = []
self.started = False
DummySpinner.instances.append(self)

def __enter__(self) -> 'DummySpinner':
self.start()
return self

def __exit__(self, exc_type, exc, tb) -> Literal[False]:
return False

def start(self) -> None:
self.started = True

def fail(self, text: str) -> None:
self.fail_calls.append(text)

def ok(self, text: str) -> None:
self.ok_calls.append(text)


def dispatch_batch_statements(
mycli: DummyMyCli,
cli_args: DummyCliArgs,
Expand All @@ -108,7 +151,7 @@ def main_batch_from_stdin(mycli: DummyMyCli, cli_args: DummyCliArgs) -> int:


def make_fake_sys(stdin_tty: bool, stderr_tty: bool | None = None) -> SimpleNamespace:
stderr = SimpleNamespace(isatty=lambda: stderr_tty) if stderr_tty is not None else object()
stderr = DummyStream(bool(stderr_tty))
return SimpleNamespace(
stdin=SimpleNamespace(isatty=lambda: stdin_tty),
stderr=stderr,
Expand Down Expand Up @@ -186,6 +229,21 @@ def test_replay_checkpoint_file_rejects_checkpoint_longer_than_batch(tmp_path: P
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True)


def test_replay_checkpoint_file_marks_progress_failed_when_checkpoint_is_longer(
monkeypatch,
tmp_path: Path,
) -> None:
batch_path = write_batch_file(tmp_path, 'select 1;\n')
checkpoint = write_checkpoint_file(tmp_path, 'select 1;\nselect 2;\n')
DummySpinner.instances.clear()
monkeypatch.setattr(batch_mode, 'yaspin', DummySpinner)

with pytest.raises(batch_mode.CheckpointReplayError, match='Checkpoint script longer than batch script.'):
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True, progress=True)

assert DummySpinner.instances[0].fail_calls == ['✘']


@pytest.mark.skipif(os.name == 'nt', reason='todo: unknown')
def test_replay_checkpoint_file_rejects_batch_read_error(monkeypatch, tmp_path: Path) -> None:
batch_path = write_batch_file(tmp_path, 'select 1;\n')
Expand All @@ -198,6 +256,20 @@ def test_replay_checkpoint_file_rejects_batch_read_error(monkeypatch, tmp_path:
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True)


@pytest.mark.skipif(os.name == 'nt', reason='todo: unknown')
def test_replay_checkpoint_file_marks_progress_failed_for_batch_read_error(monkeypatch, tmp_path: Path) -> None:
batch_path = write_batch_file(tmp_path, 'select 1;\n')
checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n')
DummySpinner.instances.clear()
monkeypatch.setattr(batch_mode, 'statements_from_filehandle', lambda _handle: (_ for _ in ()).throw(ValueError('bad batch')))
monkeypatch.setattr(batch_mode, 'yaspin', DummySpinner)

with pytest.raises(batch_mode.CheckpointReplayError, match=f'Error reading --batch file: {batch_path}: bad batch'):
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True, progress=True)

assert DummySpinner.instances[0].fail_calls == ['✘']


@pytest.mark.skipif(os.name == 'nt', reason='todo: unknown')
def test_replay_checkpoint_file_rejects_batch_iteration_error(monkeypatch, tmp_path: Path) -> None:
batch_path = write_batch_file(tmp_path, 'select 1;\n')
Expand All @@ -219,6 +291,31 @@ def fake_statements_from_filehandle(handle):
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True)


@pytest.mark.skipif(os.name == 'nt', reason='todo: unknown')
def test_replay_checkpoint_file_marks_progress_failed_for_batch_iteration_error(monkeypatch, tmp_path: Path) -> None:
batch_path = write_batch_file(tmp_path, 'select 1;\n')

def raise_on_next():
raise ValueError('bad batch iterator')
yield

def fake_statements_from_filehandle(handle):
if handle.name == batch_path:
return raise_on_next()
return iter([('select 1;', 0)])

DummySpinner.instances.clear()
monkeypatch.setattr(batch_mode, 'statements_from_filehandle', fake_statements_from_filehandle)
monkeypatch.setattr(batch_mode, 'yaspin', DummySpinner)

checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n')

with pytest.raises(batch_mode.CheckpointReplayError, match=f'Error reading --batch file: {batch_path}: bad batch iterator'):
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True, progress=True)

assert DummySpinner.instances[0].fail_calls == ['✘']


@pytest.mark.skipif(os.name == 'nt', reason='todo: unknown')
def test_replay_checkpoint_file_rejects_checkpoint_read_error(monkeypatch, tmp_path: Path) -> None:
batch_path = write_batch_file(tmp_path, 'select 1;\n')
Expand All @@ -236,6 +333,27 @@ def fake_statements_from_filehandle(handle):
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True)


@pytest.mark.skipif(os.name == 'nt', reason='todo: unknown')
def test_replay_checkpoint_file_marks_progress_failed_for_checkpoint_read_error(monkeypatch, tmp_path: Path) -> None:
batch_path = write_batch_file(tmp_path, 'select 1;\n')

def fake_statements_from_filehandle(handle):
if handle.name == batch_path:
return iter([('select 1;', 0)])
return (_ for _ in ()).throw(ValueError('bad checkpoint'))

DummySpinner.instances.clear()
monkeypatch.setattr(batch_mode, 'statements_from_filehandle', fake_statements_from_filehandle)
monkeypatch.setattr(batch_mode, 'yaspin', DummySpinner)

checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n')

with pytest.raises(batch_mode.CheckpointReplayError, match=f'Error reading --checkpoint file: {checkpoint}: bad checkpoint'):
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True, progress=True)

assert DummySpinner.instances[0].fail_calls == ['✘']


def test_replay_checkpoint_file_rejects_missing_files(tmp_path: Path) -> None:
batch_path = str(tmp_path / 'missing.sql')

Expand All @@ -245,6 +363,18 @@ def test_replay_checkpoint_file_rejects_missing_files(tmp_path: Path) -> None:
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True)


def test_replay_checkpoint_file_marks_progress_failed_for_missing_files(monkeypatch, tmp_path: Path) -> None:
batch_path = str(tmp_path / 'missing.sql')
checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n')
DummySpinner.instances.clear()
monkeypatch.setattr(batch_mode, 'yaspin', DummySpinner)

with pytest.raises(batch_mode.CheckpointReplayError, match='FileNotFoundError'):
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True, progress=True)

assert DummySpinner.instances[0].fail_calls == ['✘']


def test_replay_checkpoint_file_rejects_open_errors(monkeypatch, tmp_path: Path) -> None:
batch_path = write_batch_file(tmp_path, 'select 1;\n')

Expand All @@ -256,6 +386,19 @@ def test_replay_checkpoint_file_rejects_open_errors(monkeypatch, tmp_path: Path)
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True)


def test_replay_checkpoint_file_marks_progress_failed_for_open_errors(monkeypatch, tmp_path: Path) -> None:
batch_path = write_batch_file(tmp_path, 'select 1;\n')
checkpoint = write_checkpoint_file(tmp_path, 'select 1;\n')
DummySpinner.instances.clear()
monkeypatch.setattr(batch_mode.click, 'open_file', lambda *_args, **_kwargs: (_ for _ in ()).throw(OSError('open failed')))
monkeypatch.setattr(batch_mode, 'yaspin', DummySpinner)

with pytest.raises(batch_mode.CheckpointReplayError, match='OSError'):
batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True, progress=True)

assert DummySpinner.instances[0].fail_calls == ['✘']


@pytest.mark.parametrize(
('format_name', 'batch_counter', 'expected'),
(
Expand Down
Loading