diff --git a/changelog.md b/changelog.md index 69aee6fc..c7e89e28 100644 --- a/changelog.md +++ b/changelog.md @@ -1,8 +1,16 @@ +1.76.0 (2026/06/20) +============== + +Features +--------- +* Optionally expand whole `${VAR}` values in DSN aliases. + + 1.75.0 (2026/06/20) ============== Features --------- +--------- * Silently accept forward slash to introduce special commands. * `--progress` spinners for setup steps in `--batch` mode. diff --git a/mycli/AUTHORS b/mycli/AUTHORS index d62b9c59..4b63b51d 100644 --- a/mycli/AUTHORS +++ b/mycli/AUTHORS @@ -71,6 +71,7 @@ Contributors: * Morgan Mitchell * mrdeathless * Nathan Huang + * n8himmel * Nicolas Palumbo * Phil Cohen * QiaoHou Peng diff --git a/mycli/cli_runner.py b/mycli/cli_runner.py index 6b35e657..934692c5 100644 --- a/mycli/cli_runner.py +++ b/mycli/cli_runner.py @@ -1,6 +1,7 @@ from __future__ import annotations import os +import re import sys from textwrap import dedent from typing import TYPE_CHECKING, Any, Callable @@ -22,6 +23,75 @@ from mycli.main import CliArgs ClientFactory = Callable[..., Any] +ENV_VAR_PATTERN = re.compile(r'^\$\{([A-Za-z_][A-Za-z0-9_]*)\}$') + + +class DsnAliasEnvVarError(ValueError): + pass + + +def expand_dsn_alias_env_var(value: str | None, alias_name: str) -> str | None: + if value is None: + return None + + match = ENV_VAR_PATTERN.fullmatch(value) + if not match: + return value + + var_name = match.group(1) + try: + return os.environ[var_name] + except KeyError as exc: + raise DsnAliasEnvVarError(f'Environment variable {var_name} referenced by DSN alias {alias_name} is not set.') from exc + + +def split_dsn_netloc(netloc: str) -> tuple[str | None, str | None, str | None, str | None]: + username = None + password = None + host_port = netloc + + if '@' in host_port: + user_info, host_port = host_port.rsplit('@', 1) + username, separator, password = user_info.partition(':') + if not separator: + password = None + + if not host_port: + return username, password, None, None + + if host_port.startswith('['): + end = host_port.find(']') + if end >= 0: + host = host_port[1:end] + port = host_port[end + 2 :] if host_port[end + 1 : end + 2] == ':' else None + return username, password, host, port + + host, separator, port = host_port.partition(':') + return username, password, host or None, port if separator else None + + +def expand_dsn_alias_env_vars( + dsn_uri: str, alias_name: str +) -> tuple[str | None, str | None, str | None, int | None, str, dict[str, list[str]]]: + uri = urlparse(dsn_uri) + username, password, host, port = split_dsn_netloc(uri.netloc) + + expanded_port = expand_dsn_alias_env_var(port, alias_name) + try: + port_number = int(expanded_port) if expanded_port else None + except ValueError as exc: + raise DsnAliasEnvVarError(f'Port in DSN alias {alias_name} must be an integer.') from exc + + params = {key: [expand_dsn_alias_env_var(value, alias_name) or '' for value in values] for key, values in parse_qs(uri.query).items()} + + return ( + expand_dsn_alias_env_var(unquote(username) if username is not None else None, alias_name), + expand_dsn_alias_env_var(unquote(password) if password is not None else None, alias_name), + expand_dsn_alias_env_var(host, alias_name), + port_number, + expand_dsn_alias_env_var(uri.path[1:], alias_name) or '', + params, + ) def run_from_cli_args(cli_args: 'CliArgs', client_factory: ClientFactory) -> None: @@ -162,22 +232,38 @@ def run_from_cli_args(cli_args: 'CliArgs', client_factory: ClientFactory) -> Non if dsn_uri: uri = urlparse(dsn_uri) + env_var_alias_name = None + dsn_alias = getattr(mycli, 'dsn_alias', None) + if dsn_alias and str_to_bool(mycli.config['main'].get('expand_dsn_alias_env_vars', 'False')): + env_var_alias_name = dsn_alias + + if env_var_alias_name: + try: + dsn_user, dsn_password, dsn_host, dsn_port, dsn_database, dsn_params = expand_dsn_alias_env_vars( + dsn_uri, env_var_alias_name + ) + except DsnAliasEnvVarError as exc: + click.secho(str(exc), err=True, fg='red') + sys.exit(1) + else: + dsn_user = unquote(uri.username) if uri.username is not None else None + dsn_password = unquote(uri.password) if uri.password is not None else None + dsn_host = uri.hostname + dsn_port = uri.port + dsn_database = uri.path[1:] + dsn_params = parse_qs(uri.query) if uri.query else {} + if not database: - database = uri.path[1:] # ignore the leading fwd slash - if not cli_args.user and uri.username is not None: - cli_args.user = unquote(uri.username) + database = dsn_database + if not cli_args.user and dsn_user is not None: + cli_args.user = dsn_user # todo: rationalize the behavior of empty-string passwords here - if not cli_args.password and uri.password is not None: - cli_args.password = unquote(uri.password) + if not cli_args.password and dsn_password is not None: + cli_args.password = dsn_password if not cli_args.host: - cli_args.host = uri.hostname + cli_args.host = dsn_host if not cli_args.port: - cli_args.port = uri.port - - if uri.query: - dsn_params = parse_qs(uri.query) - else: - dsn_params = {} + cli_args.port = dsn_port if params := dsn_params.get('ssl'): click.secho( diff --git a/mycli/myclirc b/mycli/myclirc index 76663572..9fadbc63 100644 --- a/mycli/myclirc +++ b/mycli/myclirc @@ -24,6 +24,9 @@ prefetch_schemas_mode = always # prefetch_schemas_mode = listed. Ignored in other modes. prefetch_schemas_list = +# Expand whole DSN alias values in the form ${VAR} from the environment. +expand_dsn_alias_env_vars = False + # Multi-line mode allows breaking up the sql statements into multiple lines. If # this is set to True, then the end of the statements must have a semi-colon. # If this is set to False then sql statements can't be split into multiple diff --git a/test/myclirc b/test/myclirc index 680447e5..176d1156 100644 --- a/test/myclirc +++ b/test/myclirc @@ -24,6 +24,9 @@ prefetch_schemas_mode = always # prefetch_schemas_mode = listed. Ignored in other modes. prefetch_schemas_list = +# Expand whole DSN alias values in the form ${VAR} from the environment. +expand_dsn_alias_env_vars = False + # Multi-line mode allows breaking up the sql statements into multiple lines. If # this is set to True, then the end of the statements must have a semi-colon. # If this is set to False then sql statements can't be split into multiple diff --git a/test/pytests/test_cli_runner.py b/test/pytests/test_cli_runner.py index ed9a44ca..3e491e91 100644 --- a/test/pytests/test_cli_runner.py +++ b/test/pytests/test_cli_runner.py @@ -145,6 +145,119 @@ def test_run_from_cli_args_uses_deprecated_mysql_unix_port_and_database_alias( assert any('MYSQL_UNIX_PORT environment variable is deprecated' in call for call in secho_calls) +def test_run_from_cli_args_leaves_dsn_alias_env_vars_disabled_by_default( + monkeypatch: pytest.MonkeyPatch, +) -> None: + cli_args = make_cli_args() + cli_args.dsn = 'prod' + monkeypatch.setenv('MYCLI_TEST_DSN_USER', 'env_user') + client = DummyMyCli( + config={ + **default_config(), + 'alias_dsn': {'prod': 'mysql://${MYCLI_TEST_DSN_USER}:pass@host:3306/db'}, + } + ) + + run_with_client(monkeypatch, cli_args, client) + + assert client.connect_calls[-1]['user'] == '${MYCLI_TEST_DSN_USER}' + + +def test_run_from_cli_args_expands_whole_dsn_alias_env_vars_when_enabled( + monkeypatch: pytest.MonkeyPatch, +) -> None: + cli_args = make_cli_args() + cli_args.dsn = 'prod' + monkeypatch.setenv('MYCLI_TEST_DSN_USER', 'env_user') + monkeypatch.setenv('MYCLI_TEST_DSN_PASSWORD', 'env_pass') + monkeypatch.setenv('MYCLI_TEST_DSN_HOST', 'env-host') + monkeypatch.setenv('MYCLI_TEST_DSN_PORT', '3308') + monkeypatch.setenv('MYCLI_TEST_DSN_DATABASE', 'env_db') + monkeypatch.setenv('MYCLI_TEST_DSN_CHARSET', 'utf8mb4') + monkeypatch.setenv('MYCLI_TEST_DSN_KEEPALIVE', '9') + config = default_config() + config['main'] = {**config['main'], 'expand_dsn_alias_env_vars': 'true'} + config['alias_dsn'] = { + 'prod': ( + 'mysql://${MYCLI_TEST_DSN_USER}:${MYCLI_TEST_DSN_PASSWORD}' + '@${MYCLI_TEST_DSN_HOST}:${MYCLI_TEST_DSN_PORT}/${MYCLI_TEST_DSN_DATABASE}' + '?character_set=${MYCLI_TEST_DSN_CHARSET}&keepalive_ticks=${MYCLI_TEST_DSN_KEEPALIVE}' + ) + } + client = DummyMyCli(config=config) + + run_with_client(monkeypatch, cli_args, client) + + assert client.connect_calls[-1]['user'] == 'env_user' + assert client.connect_calls[-1]['passwd'] == 'env_pass' + assert client.connect_calls[-1]['host'] == 'env-host' + assert client.connect_calls[-1]['port'] == 3308 + assert client.connect_calls[-1]['database'] == 'env_db' + assert client.connect_calls[-1]['character_set'] == 'utf8mb4' + assert client.connect_calls[-1]['keepalive_ticks'] == 9 + + +def test_run_from_cli_args_does_not_expand_partial_values_or_query_keys( + monkeypatch: pytest.MonkeyPatch, +) -> None: + cli_args = make_cli_args() + cli_args.dsn = 'prod' + monkeypatch.setenv('MYCLI_TEST_DSN_USER', 'env_user') + monkeypatch.setenv('MYCLI_TEST_DSN_QUERY_KEY', 'character_set') + config = default_config() + config['main'] = {**config['main'], 'expand_dsn_alias_env_vars': 'true'} + config['alias_dsn'] = { + 'prod': ('mysql://user-${MYCLI_TEST_DSN_USER}:pass@host:3306/db?${MYCLI_TEST_DSN_QUERY_KEY}=utf8mb4&character_set=utf8') + } + client = DummyMyCli(config=config) + + run_with_client(monkeypatch, cli_args, client) + + assert client.connect_calls[-1]['user'] == 'user-${MYCLI_TEST_DSN_USER}' + assert client.connect_calls[-1]['character_set'] == 'utf8' + + +def test_run_from_cli_args_does_not_expand_unbraced_dsn_alias_env_vars( + monkeypatch: pytest.MonkeyPatch, +) -> None: + cli_args = make_cli_args() + cli_args.dsn = 'prod' + monkeypatch.setenv('MYCLI_TEST_DSN_USER', 'env_user') + config = default_config() + config['main'] = {**config['main'], 'expand_dsn_alias_env_vars': 'true'} + config['alias_dsn'] = {'prod': 'mysql://$MYCLI_TEST_DSN_USER:pass@host:3306/db'} + client = DummyMyCli(config=config) + + run_with_client(monkeypatch, cli_args, client) + + assert client.connect_calls[-1]['user'] == '$MYCLI_TEST_DSN_USER' + + +def test_run_from_cli_args_reports_missing_dsn_alias_env_var( + monkeypatch: pytest.MonkeyPatch, +) -> None: + cli_args = make_cli_args() + cli_args.dsn = 'prod' + config = default_config() + config['main'] = {**config['main'], 'expand_dsn_alias_env_vars': 'true'} + config['alias_dsn'] = {'prod': 'mysql://${MYCLI_TEST_MISSING_DSN_USER}:pass@host:3306/db'} + client = DummyMyCli(config=config) + secho_calls: list[tuple[str, dict[str, Any]]] = [] + monkeypatch.setattr(cli_runner.click, 'secho', lambda text, **kwargs: secho_calls.append((text, kwargs))) + + with pytest.raises(SystemExit) as excinfo: + run_with_client(monkeypatch, cli_args, client) + + assert excinfo.value.code == 1 + assert secho_calls == [ + ( + 'Environment variable MYCLI_TEST_MISSING_DSN_USER referenced by DSN alias prod is not set.', + {'err': True, 'fg': 'red'}, + ) + ] + assert client.connect_calls == [] + + def test_run_from_cli_args_reports_missing_dsn(monkeypatch: pytest.MonkeyPatch) -> None: cli_args = make_cli_args() cli_args.dsn = 'missing'