diff --git a/.changes/next-release/enhancement-clientcontextparameters-24961.json b/.changes/next-release/enhancement-clientcontextparameters-24961.json new file mode 100644 index 000000000000..ea4b5f5dbe9c --- /dev/null +++ b/.changes/next-release/enhancement-clientcontextparameters-24961.json @@ -0,0 +1,5 @@ +{ + "type": "enhancement", + "category": "client context parameters", + "description": "Add support for client context parameters as per-service CLI flags. Services that define clientContextParams in their model now automatically expose them as CLI options (e.g. ``--disable-s3-express-session-auth``, ``--force-path-style``). Boolean parameters support ``--flag`` / ``--no-flag`` syntax." +} diff --git a/awscli/customizations/clientcontextparams.py b/awscli/customizations/clientcontextparams.py new file mode 100644 index 000000000000..0ec297813452 --- /dev/null +++ b/awscli/customizations/clientcontextparams.py @@ -0,0 +1,172 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import logging +from functools import partial + +from awscli.arguments import BaseCLIArgument +from awscli.botocore import xform_name +from awscli.botocore.config import Config + +logger = logging.getLogger(__name__) + +_SUPPORTED_TYPES = ('boolean', 'string') + + +def register_client_context_params(event_handlers): + event_handlers.register( + 'building-argument-table', inject_client_context_params + ) + + +def inject_client_context_params( + argument_table, operation_model, event_name, session, **kwargs +): + service_model = operation_model.service_model + context_params = getattr(service_model, 'client_context_parameters', None) + if not context_params: + return + + parsed_args_event = event_name.replace( + 'building-argument-table.', 'operation-args-parsed.' + ) + param_defs = [] + for param in context_params: + cli_name = xform_name(param.name, '-') + # Skip if an operation input member has the same CLI name; + # the model validation test will catch this before release. + if cli_name in argument_table: + logger.debug( + 'Skipping client context param %s for %s: ' + 'collision with existing argument', + param.name, + service_model.service_name, + ) + continue + # Skip types we don't handle yet; the model validation test + # will catch new types before they reach customers. + if param.type not in _SUPPORTED_TYPES: + logger.debug( + 'Skipping client context param %s for %s: ' + 'unsupported type %r', + param.name, + service_model.service_name, + param.type, + ) + continue + arg = ClientContextParamArgument( + name=cli_name, + context_param_name=param.name, + param_type=param.type, + documentation=getattr(param, 'documentation', ''), + action='store_true' if param.type == 'boolean' else None, + group_name=cli_name if param.type == 'boolean' else None, + ) + argument_table[cli_name] = arg + if param.type == 'boolean': + negative = ClientContextParamArgument( + name='no-' + cli_name, + context_param_name=param.name, + param_type=param.type, + action='store_false', + dest=cli_name.replace('-', '_'), + group_name=cli_name, + ) + argument_table['no-' + cli_name] = negative + param_defs.append((cli_name, param.name)) + + if param_defs: + session.register( + parsed_args_event, + partial(_apply_client_context_params, param_defs, session), + ) + + +def _apply_client_context_params(param_defs, session, parsed_args, **kwargs): + context_params = {} + for cli_name, original_name in param_defs: + py_name = cli_name.replace('-', '_') + value = getattr(parsed_args, py_name, None) + if value is not None: + context_params[original_name] = value + if not context_params: + return + new_config = Config(client_context_params=context_params) + existing = session.get_default_client_config() + if existing is not None: + new_config = existing.merge(new_config) + session.set_default_client_config(new_config) + + +class ClientContextParamArgument(BaseCLIArgument): + def __init__( + self, + name, + context_param_name, + param_type, + documentation='', + action=None, + dest=None, + group_name=None, + ): + self.argument_model = None + self._name = name + self._context_param_name = context_param_name + self._param_type = param_type + self._documentation = documentation + self._required = False + self._action = action + self._dest = dest or name.replace('-', '_') + self._group_name = group_name + + @property + def cli_name(self): + return '--' + self._name + + @property + def cli_type_name(self): + return self._param_type + + @property + def required(self): + return self._required + + @required.setter + def required(self, value): + self._required = value + + @property + def documentation(self): + return self._documentation + + @property + def group_name(self): + return self._group_name + + def add_to_parser(self, parser): + if self._param_type == 'boolean': + parser.add_argument( + self.cli_name, + dest=self._dest, + action=self._action, + default=None, + ) + else: + parser.add_argument( + self.cli_name, + dest=self._dest, + ) + + def add_to_params(self, parameters, value): + # Client context params are not operation parameters; + # they are applied via _apply_client_context_params. + pass diff --git a/awscli/handlers.py b/awscli/handlers.py index 9c605dc37c33..b991cc9a61f8 100644 --- a/awscli/handlers.py +++ b/awscli/handlers.py @@ -26,6 +26,9 @@ from awscli.customizations.assumerole import register_assume_role_provider from awscli.customizations.awslambda import register_lambda_create_function from awscli.customizations.binaryformat import add_binary_formatter +from awscli.customizations.clientcontextparams import ( + register_client_context_params, +) from awscli.customizations.cliinput import register_cli_input_args from awscli.customizations.cloudformation import ( initialize as cloudformation_init, @@ -165,6 +168,7 @@ def awscli_initialize(event_handlers): ) register_parse_global_args(event_handlers) register_pagination(event_handlers) + register_client_context_params(event_handlers) register_secgroup(event_handlers) register_bundleinstance(event_handlers) s3_plugin_initialize(event_handlers) diff --git a/tests/functional/test_client_context_params.py b/tests/functional/test_client_context_params.py new file mode 100644 index 000000000000..218b45fd6284 --- /dev/null +++ b/tests/functional/test_client_context_params.py @@ -0,0 +1,71 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import pytest + +from awscli.botocore import xform_name +from awscli.clidriver import create_clidriver +from awscli.customizations.clientcontextparams import _SUPPORTED_TYPES + +_SESSION = create_clidriver().session + +_TYPE_TESTS = [] +_COLLISION_TESTS = [] + +for _svc_name in _SESSION.get_available_services(): + _model = _SESSION.get_service_model(_svc_name) + if not hasattr(_model, 'client_context_parameters'): + continue + _ctx_params = _model.client_context_parameters + for _param in _ctx_params: + _TYPE_TESTS.append((_svc_name, _param.name, _param.type)) + _ctx_names = {xform_name(p.name, '-') for p in _ctx_params} + for _op_name in _model.operation_names: + _op = _model.operation_model(_op_name) + if _op.input_shape is None: + continue + _collisions = _ctx_names & { + xform_name(m, '-') for m in _op.input_shape.members + } + if _collisions: + _COLLISION_TESTS.append((_svc_name, _op_name, _collisions)) + + +@pytest.mark.validates_models +@pytest.mark.parametrize("service_name, param_name, param_type", _TYPE_TESTS) +def test_client_context_param_types_are_supported( + service_name, param_name, param_type, record_property +): + if param_type not in _SUPPORTED_TYPES: + record_property('aws_service', service_name) + raise AssertionError( + f'Client context param {param_name!r} on service ' + f'{service_name!r} has unsupported type {param_type!r}. ' + f'Supported types: {_SUPPORTED_TYPES}' + ) + + +@pytest.mark.validates_models +@pytest.mark.parametrize( + "service_name, operation_name, collisions", _COLLISION_TESTS +) +def test_client_context_params_do_not_collide_with_operation_inputs( + service_name, operation_name, collisions, record_property +): + # Only runs when a collision exists; unconditional failure is intentional. + record_property('aws_service', service_name) + record_property('aws_operation', operation_name) + raise AssertionError( + f'Client context param name(s) {collisions} on service ' + f'{service_name!r} collide with input members of ' + f'{operation_name!r}.' + ) diff --git a/tests/functional/test_clientcontextparams.py b/tests/functional/test_clientcontextparams.py new file mode 100644 index 000000000000..2715bc97e762 --- /dev/null +++ b/tests/functional/test_clientcontextparams.py @@ -0,0 +1,41 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from awscli.testutils import BaseAWSCommandParamsTest + + +class TestClientContextParams(BaseAWSCommandParamsTest): + def setUp(self): + super().setUp() + self.parsed_responses = [{'Buckets': [], 'Owner': {}}] + + def test_boolean_flag_sets_client_context_params(self): + self.run_cmd('s3api list-buckets --disable-s3-express-session-auth') + config = self.driver.session.get_default_client_config() + self.assertEqual( + config.client_context_params, + {'DisableS3ExpressSessionAuth': True}, + ) + + def test_negative_flag_sets_false(self): + self.run_cmd('s3api list-buckets --no-disable-s3-express-session-auth') + config = self.driver.session.get_default_client_config() + self.assertEqual( + config.client_context_params, + {'DisableS3ExpressSessionAuth': False}, + ) + + def test_no_flag_does_not_set_client_context_params(self): + self.run_cmd('s3api list-buckets') + config = self.driver.session.get_default_client_config() + if config is not None: + self.assertIsNone(config.client_context_params) diff --git a/tests/unit/customizations/test_clientcontextparams.py b/tests/unit/customizations/test_clientcontextparams.py new file mode 100644 index 000000000000..5b5d1ca19abd --- /dev/null +++ b/tests/unit/customizations/test_clientcontextparams.py @@ -0,0 +1,197 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import argparse +from unittest.mock import Mock + +import pytest + +from awscli.botocore.config import Config +from awscli.customizations.clientcontextparams import ( + ClientContextParamArgument, + _apply_client_context_params, + inject_client_context_params, +) + + +def _make_context_param(name, param_type='boolean', documentation=''): + param = Mock() + param.name = name + param.type = param_type + param.documentation = documentation + return param + + +def _make_operation_model(context_params): + operation_model = Mock() + operation_model.service_model.service_name = 's3' + operation_model.service_model.client_context_parameters = context_params + return operation_model + + +@pytest.fixture +def session(): + s = Mock() + s.get_default_client_config.return_value = None + return s + + +@pytest.fixture +def argument_table(): + return {} + + +def _inject(argument_table, session, context_params): + inject_client_context_params( + argument_table=argument_table, + operation_model=_make_operation_model(context_params), + event_name='building-argument-table.s3api.list-buckets', + session=session, + ) + + +def test_no_context_params_does_nothing(argument_table, session): + _inject(argument_table, session, []) + assert argument_table == {} + session.register.assert_not_called() + + +def test_no_client_context_parameters_attr_does_nothing( + argument_table, session +): + operation_model = Mock() + del operation_model.service_model.client_context_parameters + inject_client_context_params( + argument_table=argument_table, + operation_model=operation_model, + event_name='building-argument-table.s3api.list-buckets', + session=session, + ) + assert argument_table == {} + + +def test_boolean_param_injects_positive_and_negative_with_group( + argument_table, session +): + _inject(argument_table, session, [_make_context_param('ForcePathStyle')]) + pos = argument_table['force-path-style'] + neg = argument_table['no-force-path-style'] + assert isinstance(pos, ClientContextParamArgument) + assert isinstance(neg, ClientContextParamArgument) + assert pos.group_name == 'force-path-style' + assert neg.group_name == 'force-path-style' + + +def test_string_param_injected_without_negative_or_group( + argument_table, session +): + _inject( + argument_table, + session, + [_make_context_param('Endpoint', param_type='string')], + ) + assert 'endpoint' in argument_table + assert 'no-endpoint' not in argument_table + assert argument_table['endpoint'].group_name is None + + +def test_collision_skips_param(argument_table, session): + argument_table['accelerate'] = Mock() + _inject(argument_table, session, [_make_context_param('Accelerate')]) + assert not isinstance( + argument_table['accelerate'], ClientContextParamArgument + ) + assert 'no-accelerate' not in argument_table + + +def test_unsupported_type_skips_param(argument_table, session): + _inject( + argument_table, + session, + [_make_context_param('Count', param_type='integer')], + ) + assert 'count' not in argument_table + + +def test_registers_operation_args_parsed_handler(argument_table, session): + _inject(argument_table, session, [_make_context_param('ForcePathStyle')]) + event_name = session.register.call_args[0][0] + assert event_name == 'operation-args-parsed.s3api.list-buckets' + + +def _apply(session, param_defs, **attr_values): + parsed_args = argparse.Namespace(**attr_values) + _apply_client_context_params(param_defs, session, parsed_args) + + +def test_flag_passed_sets_config(session): + _apply( + session, + [('force-path-style', 'ForcePathStyle')], + force_path_style=True, + ) + config = session.set_default_client_config.call_args[0][0] + assert config.client_context_params == {'ForcePathStyle': True} + + +def test_no_flag_passed_does_not_set_config(session): + _apply( + session, + [('force-path-style', 'ForcePathStyle')], + force_path_style=None, + ) + session.set_default_client_config.assert_not_called() + + +def test_negative_flag_sends_false(session): + _apply( + session, + [('force-path-style', 'ForcePathStyle')], + force_path_style=False, + ) + config = session.set_default_client_config.call_args[0][0] + assert config.client_context_params == {'ForcePathStyle': False} + + +def test_merges_with_existing_config(session): + session.get_default_client_config.return_value = Config(read_timeout=30) + _apply( + session, + [('force-path-style', 'ForcePathStyle')], + force_path_style=True, + ) + config = session.set_default_client_config.call_args[0][0] + assert config.client_context_params == {'ForcePathStyle': True} + assert config.read_timeout == 30 + + +def test_string_add_to_parser(): + arg = ClientContextParamArgument( + name='endpoint', + context_param_name='Endpoint', + param_type='string', + ) + parser = argparse.ArgumentParser() + arg.add_to_parser(parser) + result = parser.parse_args(['--endpoint', 'custom.example.com']) + assert result.endpoint == 'custom.example.com' + + +def test_add_to_params_is_noop(): + arg = ClientContextParamArgument( + name='force-path-style', + context_param_name='ForcePathStyle', + param_type='boolean', + ) + params = {} + arg.add_to_params(params, True) + assert params == {}