Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Upcoming (TBD)
Internal
--------
* Add test coverage for `client_commands.py`.
* Add test coverage for `cli_runner.py`.


1.74.1 (2026/06/18)
Expand Down
234 changes: 234 additions & 0 deletions test/pytests/test_cli_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
from __future__ import annotations

from types import SimpleNamespace
from typing import Any

import pytest

from mycli import cli_runner, main


class DummyLogger:
def __init__(self) -> None:
self.debug_calls: list[tuple[tuple[Any, ...], dict[str, Any]]] = []

def debug(self, *args: Any, **kwargs: Any) -> None:
self.debug_calls.append((args, kwargs))


class DummyMyCli:
def __init__(
self,
*,
config: dict[str, Any] | None = None,
my_cnf: dict[str, Any] | None = None,
config_without_package_defaults: dict[str, Any] | None = None,
) -> None:
self.config = config or default_config()
self.my_cnf = my_cnf or {'client': {}, 'mysqld': {}}
self.config_without_package_defaults = config_without_package_defaults or {}
self.default_keepalive_ticks = 5
self.ssl_mode: str | None = None
self.logger = DummyLogger()
self.dsn_alias: str | None = None
self.connect_calls: list[dict[str, Any]] = []
self.run_cli_called = False
self.close_called = False

def connect(self, **kwargs: Any) -> None:
self.connect_calls.append(dict(kwargs))

def run_cli(self) -> None:
self.run_cli_called = True

def close(self) -> None:
self.close_called = True


def default_config() -> dict[str, Any]:
return {
'main': {'use_keyring': 'false', 'my_cnf_transition_done': 'true'},
'connection': {'default_keepalive_ticks': 0},
'alias_dsn': {},
'init-commands': {},
'alias_dsn.init-commands': {},
}


def make_cli_args() -> main.CliArgs:
cli_args = main.CliArgs()
cli_args.format = None
cli_args.ssh_config_path = '/dev/null'
return cli_args


def run_with_client(
monkeypatch: pytest.MonkeyPatch,
cli_args: main.CliArgs,
client: DummyMyCli,
) -> DummyMyCli:
monkeypatch.setattr(main, 'preprocess_cli_args', lambda args, scheme_validator: 4)
monkeypatch.setattr(cli_runner.sys, 'stdin', SimpleNamespace(isatty=lambda: True))
monkeypatch.setattr(cli_runner.sys.stderr, 'isatty', lambda: False)
cli_runner.run_from_cli_args(cli_args, lambda **_kwargs: client)
return client


def test_run_from_cli_args_checkup_exits_zero(monkeypatch: pytest.MonkeyPatch) -> None:
cli_args = make_cli_args()
cli_args.checkup = True
client = DummyMyCli()
checkup_calls: list[DummyMyCli] = []
monkeypatch.setattr(cli_runner, 'main_checkup', lambda value: checkup_calls.append(value))
monkeypatch.setattr(main, 'preprocess_cli_args', lambda args, scheme_validator: 0)

with pytest.raises(SystemExit) as excinfo:
cli_runner.run_from_cli_args(cli_args, lambda **_kwargs: client)

assert excinfo.value.code == 0
assert checkup_calls == [client]


@pytest.mark.parametrize(
('csv', 'table', 'format_name', 'message'),
(
(True, False, 'table', 'Conflicting --csv and --format arguments.'),
(False, True, 'csv', 'Conflicting --table and --format arguments.'),
),
)
def test_run_from_cli_args_rejects_conflicting_format_flags(
monkeypatch: pytest.MonkeyPatch,
csv: bool,
table: bool,
format_name: str,
message: str,
) -> None:
cli_args = make_cli_args()
cli_args.csv = csv
cli_args.table = table
cli_args.format = format_name
secho_calls: list[tuple[str, dict[str, Any]]] = []
monkeypatch.setattr(cli_runner.click, 'secho', lambda text, **kwargs: secho_calls.append((text, kwargs)))
monkeypatch.setattr(main, 'preprocess_cli_args', lambda args, scheme_validator: 0)

with pytest.raises(SystemExit) as excinfo:
cli_runner.run_from_cli_args(cli_args, lambda **_kwargs: DummyMyCli())

assert excinfo.value.code == 1
assert secho_calls == [(message, {'err': True, 'fg': 'red'})]


def test_run_from_cli_args_uses_deprecated_mysql_unix_port_and_database_alias(
monkeypatch: pytest.MonkeyPatch,
) -> None:
cli_args = make_cli_args()
cli_args.database = 'prod'
client = DummyMyCli(
config={
**default_config(),
'alias_dsn': {'prod': 'mysql://dsn_user:dsn_pass@dsn_host:3307/dsn_db'},
}
)
secho_calls: list[str] = []
monkeypatch.setenv('MYSQL_UNIX_PORT', '/tmp/mysql.sock')
monkeypatch.setattr(cli_runner.click, 'secho', lambda text, **_kwargs: secho_calls.append(text))

run_with_client(monkeypatch, cli_args, client)

assert client.dsn_alias == 'prod'
assert client.connect_calls[-1]['database'] == 'dsn_db'
assert client.connect_calls[-1]['user'] == 'dsn_user'
assert client.connect_calls[-1]['passwd'] == 'dsn_pass'
assert client.connect_calls[-1]['host'] == 'dsn_host'
assert client.connect_calls[-1]['port'] == 3307
assert client.connect_calls[-1]['socket'] == '/tmp/mysql.sock'
assert any('MYSQL_UNIX_PORT environment variable is deprecated' in call for call in secho_calls)


def test_run_from_cli_args_reports_missing_dsn(monkeypatch: pytest.MonkeyPatch) -> None:
cli_args = make_cli_args()
cli_args.dsn = 'missing'
secho_calls: list[tuple[str, dict[str, Any]]] = []
monkeypatch.setattr(cli_runner, 'is_valid_connection_scheme', lambda value: (False, None))
monkeypatch.setattr(cli_runner.click, 'secho', lambda text, **kwargs: secho_calls.append((text, kwargs)))
monkeypatch.setattr(main, 'preprocess_cli_args', lambda args, scheme_validator: 0)

with pytest.raises(SystemExit) as excinfo:
cli_runner.run_from_cli_args(cli_args, lambda **_kwargs: DummyMyCli())

assert excinfo.value.code == 1
assert secho_calls == [
(
'Could not find the specified DSN in the config file. Please check the "[alias_dsn]" section in your myclirc.',
{'err': True, 'fg': 'red'},
)
]


def test_run_from_cli_args_maps_dsn_ssl_parameters(monkeypatch: pytest.MonkeyPatch) -> None:
cli_args = make_cli_args()
cli_args.dsn = (
'mysql://user:pass@host:3306/db?ssl=true&ssl_ca=~/ca.pem&ssl_capath=/capath'
'&ssl_cert=~/cert.pem&ssl_key=~/key.pem&ssl_cipher=AES256&tls_version=TLSv1.3'
'&ssl_verify_server_cert=true'
)
client = DummyMyCli()
secho_calls: list[str] = []
monkeypatch.setattr(cli_runner.click, 'secho', lambda text, **_kwargs: secho_calls.append(text))

run_with_client(monkeypatch, cli_args, client)

ssl = client.connect_calls[-1]['ssl']
assert ssl == {
'mode': 'on',
'ca': cli_runner.os.path.expanduser('~/ca.pem'),
'capath': '/capath',
'cert': cli_runner.os.path.expanduser('~/cert.pem'),
'key': cli_runner.os.path.expanduser('~/key.pem'),
'cipher': 'AES256',
'tls_version': 'TLSv1.3',
'check_hostname': True,
}
assert any('"ssl" DSN URI parameter is deprecated' in call for call in secho_calls)


def test_run_from_cli_args_merges_global_list_and_alias_scalar_init_commands(
monkeypatch: pytest.MonkeyPatch,
) -> None:
cli_args = make_cli_args()
cli_args.dsn = 'prod'
cli_args.init_command = 'set cli=1'
client = DummyMyCli(
config={
**default_config(),
'alias_dsn': {'prod': 'mysql://u:p@h/db'},
'init-commands': {'first': ['set global=1', 'set global=2']},
'alias_dsn.init-commands': {'prod': 'set alias=1'},
}
)

run_with_client(monkeypatch, cli_args, client)

assert client.connect_calls[-1]['init_command'] == 'set global=1; set global=2; set alias=1; set cli=1'


def test_run_from_cli_args_resets_keyring(monkeypatch: pytest.MonkeyPatch) -> None:
cli_args = make_cli_args()
cli_args.use_keyring = 'reset'
client = DummyMyCli()

run_with_client(monkeypatch, cli_args, client)

assert client.connect_calls[-1]['use_keyring'] is True
assert client.connect_calls[-1]['reset_keyring'] is True


def test_run_from_cli_args_uses_explicit_keyring_flag(monkeypatch: pytest.MonkeyPatch) -> None:
cli_args = make_cli_args()
cli_args.use_keyring = 'true'
client = DummyMyCli()

run_with_client(monkeypatch, cli_args, client)

assert client.connect_calls[-1]['use_keyring'] is True
assert client.connect_calls[-1]['reset_keyring'] is False
Loading