diff --git a/CHANGES.rst b/CHANGES.rst index d5b05c526e..00a9a5ef66 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,5 +1,23 @@ .. currentmodule:: click +Version 8.4.0 +------------- + +Unreleased + +- :class:`ParamType` typing improvements. :pr:`3371` + + - :class:`ParamType` is now a generic abstract base class, + parameterized by its converted value type. + - :meth:`~ParamType.convert` return types are narrowed on all + concrete types (``str`` for :class:`STRING`, ``int`` for + :class:`INT`, etc.). + - :meth:`~ParamType.to_info_dict` returns specific + :class:`~typing.TypedDict` subclasses instead of + ``dict[str, Any]``. + - :class:`CompositeParamType` and the number-range base are now + generic with abstract methods. + Version 8.3.3 ------------- diff --git a/docs/parameter-types.md b/docs/parameter-types.md index e9ac206f2c..cdf0b5cc0a 100644 --- a/docs/parameter-types.md +++ b/docs/parameter-types.md @@ -70,6 +70,12 @@ The resulting value from an option will always be one of the originally passed c regardless of `case_sensitive`. ``` +```{versionchanged} 8.4.0 +{class}`Choice` is now generic. Parameterize it with the choice value type +({class}`!Choice[HashType]` for an enum, {class}`!Choice[str]` for plain +strings) to enable type-checked consumers. +``` + (ranges)= ### Int and Float Ranges @@ -153,16 +159,21 @@ To implement a custom type, you need to subclass the {class}`ParamType` class. F function that fails with a `ValueError` is also supported, though discouraged. Override the {meth}`~ParamType.convert` method to convert the value from a string to the correct type. +{class}`ParamType` is generic in the converted value type: parameterize it with +the type returned by `convert` so that consumers (and type checkers) can rely +on the narrowed return type. + The following code implements an integer type that accepts hex and octal numbers in addition to normal integers, and converts them into regular integers. ```python import click -class BasedIntParamType(click.ParamType): + +class BasedIntParamType(click.ParamType[int]): name = "integer" - def convert(self, value, param, ctx): + def convert(self, value, param, ctx) -> int: if isinstance(value, int): return value @@ -175,6 +186,7 @@ class BasedIntParamType(click.ParamType): except ValueError: self.fail(f"{value!r} is not a valid integer", param, ctx) + BASED_INT = BasedIntParamType() ``` @@ -184,3 +196,10 @@ conversion fails. The `param` and `ctx` arguments may be `None` in some cases su Values from user input or the command line will be strings, but default values and Python arguments may already be the correct type. The custom type should check at the top if the value is already valid and pass it through to support those cases. + +```{versionchanged} 8.4.0 +{class}`ParamType` is now a generic abstract base class. Parameterize it with +the converted value type ({class}`!ParamType[int]` for an integer-returning +type) so that {meth}`~ParamType.convert` and downstream consumers carry the +narrowed type. +``` diff --git a/docs/shell-completion.md b/docs/shell-completion.md index a4fedc6fb8..a8bc941ce2 100644 --- a/docs/shell-completion.md +++ b/docs/shell-completion.md @@ -120,7 +120,7 @@ indicate special handling for paths, and `help` for shells that support showing In this example, the type will suggest environment variables that start with the incomplete value. ```python -class EnvVarType(ParamType): +class EnvVarType(ParamType[str]): name = "envvar" def shell_complete(self, ctx, param, incomplete): diff --git a/docs/support-multiple-versions.md b/docs/support-multiple-versions.md index faa50ad133..c7263549be 100644 --- a/docs/support-multiple-versions.md +++ b/docs/support-multiple-versions.md @@ -55,7 +55,7 @@ def add_ctx_arg(f: F) -> F: Here's an example ``ParamType`` subclass which uses this: ```python -class CommaDelimitedString(click.ParamType): +class CommaDelimitedString(click.ParamType[str]): @add_ctx_arg def get_metavar(self, param: click.Parameter, ctx: click.Context | None) -> str: return "TEXT,TEXT,..." diff --git a/examples/validation/validation.py b/examples/validation/validation.py index 3f78df0e7f..32dafa8382 100644 --- a/examples/validation/validation.py +++ b/examples/validation/validation.py @@ -9,10 +9,10 @@ def validate_count(ctx, param, value): return value -class URL(click.ParamType): +class URL(click.ParamType[urlparse.ParseResult]): name = "url" - def convert(self, value, param, ctx): + def convert(self, value, param, ctx) -> urlparse.ParseResult: if not isinstance(value, tuple): value = urlparse.urlparse(value) if value.scheme not in ("http", "https"): diff --git a/src/click/core.py b/src/click/core.py index d940dd80e1..13d8841da1 100644 --- a/src/click/core.py +++ b/src/click/core.py @@ -2149,7 +2149,7 @@ class Parameter: def __init__( self, param_decls: cabc.Sequence[str] | None = None, - type: types.ParamType | t.Any | None = None, + type: types.ParamType[t.Any] | t.Any | None = None, required: bool = False, # XXX The default historically embed two concepts: # - the declaration of a Parameter object carrying the default (handy to @@ -2181,7 +2181,7 @@ def __init__( self.name, self.opts, self.secondary_opts = self._parse_decls( param_decls or (), expose_value ) - self.type: types.ParamType = types.convert_type(type, default) + self.type: types.ParamType[t.Any] = types.convert_type(type, default) # Default nargs to what the type tells us if we have that # information available. @@ -2648,7 +2648,7 @@ def shell_complete(self, ctx: Context, incomplete: str) -> list[CompletionItem]: """Return a list of completions for the incomplete value. If a ``shell_complete`` function was given during init, it is used. Otherwise, the :attr:`type` - :meth:`~click.types.ParamType.shell_complete` function is used. + :meth:`~click.types.ParamType[t.Any].shell_complete` function is used. :param ctx: Invocation context for this command. :param incomplete: Value being completed. May be empty. @@ -2749,7 +2749,7 @@ def __init__( multiple: bool = False, count: bool = False, allow_from_autoenv: bool = True, - type: types.ParamType | t.Any | None = None, + type: types.ParamType[t.Any] | t.Any | None = None, help: str | None = None, hidden: bool = False, show_choices: bool = True, @@ -2825,7 +2825,7 @@ def __init__( if type is None: # A flag without a flag_value is a boolean flag. if flag_value is UNSET: - self.type: types.ParamType = types.BoolParamType() + self.type: types.ParamType[t.Any] = types.BoolParamType() # If the flag value is a boolean, use BoolParamType. elif isinstance(flag_value, bool): self.type = types.BoolParamType() diff --git a/src/click/termui.py b/src/click/termui.py index 48f671b217..6801e30fa4 100644 --- a/src/click/termui.py +++ b/src/click/termui.py @@ -63,7 +63,7 @@ def _build_prompt( show_default: bool | str = False, default: t.Any | None = None, show_choices: bool = True, - type: ParamType | None = None, + type: ParamType[t.Any] | None = None, ) -> str: prompt = text if type is not None and show_choices and isinstance(type, Choice): @@ -87,7 +87,7 @@ def prompt( default: t.Any | None = None, hide_input: bool = False, confirmation_prompt: bool | str = False, - type: ParamType | t.Any | None = None, + type: ParamType[t.Any] | t.Any | None = None, value_proc: t.Callable[[str], t.Any] | None = None, prompt_suffix: str = ": ", show_default: bool | str = True, diff --git a/src/click/types.py b/src/click/types.py index e71c1c21e4..bf047d6862 100644 --- a/src/click/types.py +++ b/src/click/types.py @@ -1,11 +1,13 @@ from __future__ import annotations +import abc import collections.abc as cabc import enum import os import stat import sys import typing as t +import uuid from datetime import datetime from gettext import gettext as _ from gettext import ngettext @@ -27,7 +29,12 @@ ParamTypeValue = t.TypeVar("ParamTypeValue") -class ParamType: +class ParamTypeInfoDict(t.TypedDict): + param_type: str + name: str + + +class ParamType(t.Generic[ParamTypeValue], abc.ABC): """Represents the type of a parameter. Validates and converts values from the command line or Python into the correct type. @@ -59,7 +66,7 @@ class ParamType: #: Windows). envvar_list_splitter: t.ClassVar[str | None] = None - def to_info_dict(self) -> dict[str, t.Any]: + def to_info_dict(self) -> ParamTypeInfoDict: """Gather information that could be useful for a tool generating user-facing documentation. @@ -85,9 +92,10 @@ def __call__( value: t.Any, param: Parameter | None = None, ctx: Context | None = None, - ) -> t.Any: + ) -> ParamTypeValue | None: if value is not None: return self.convert(value, param, ctx) + return None def get_metavar(self, param: Parameter, ctx: Context) -> str | None: """Returns the metavar default for this param if it provides one.""" @@ -101,7 +109,7 @@ def get_missing_message(self, param: Parameter, ctx: Context | None) -> str | No def convert( self, value: t.Any, param: Parameter | None, ctx: Context | None - ) -> t.Any: + ) -> ParamTypeValue: """Convert the value to the correct type. This is not called if the value is ``None`` (the missing value). @@ -121,7 +129,9 @@ def convert( :param ctx: The current context that arrived at this value. May be ``None``. """ - return value + # The default returns the value as-is so subclasses that only customize + # metadata are not forced to redeclare ``convert``. + return t.cast("ParamTypeValue", value) def split_envvar_value(self, rv: str) -> cabc.Sequence[str]: """Given a value from an environment variable this splits it up @@ -160,27 +170,29 @@ def shell_complete( return [] -class CompositeParamType(ParamType): +class CompositeParamType(ParamType[ParamTypeValue]): is_composite = True @property - def arity(self) -> int: # type: ignore - raise NotImplementedError() + @abc.abstractmethod + def arity(self) -> int: ... # type: ignore[override] + + +class FuncParamTypeInfoDict(ParamTypeInfoDict): + func: t.Callable[[t.Any], t.Any] -class FuncParamType(ParamType): - def __init__(self, func: t.Callable[[t.Any], t.Any]) -> None: +class FuncParamType(ParamType[ParamTypeValue]): + def __init__(self, func: t.Callable[[t.Any], ParamTypeValue]) -> None: self.name: str = func.__name__ self.func = func - def to_info_dict(self) -> dict[str, t.Any]: - info_dict = super().to_info_dict() - info_dict["func"] = self.func - return info_dict + def to_info_dict(self) -> FuncParamTypeInfoDict: + return {"func": self.func, **super().to_info_dict()} def convert( self, value: t.Any, param: Parameter | None, ctx: Context | None - ) -> t.Any: + ) -> ParamTypeValue: try: return self.func(value) except ValueError: @@ -192,7 +204,7 @@ def convert( self.fail(value, param, ctx) -class UnprocessedParamType(ParamType): +class UnprocessedParamType(ParamType[t.Any]): name = "text" def convert( @@ -204,12 +216,12 @@ def __repr__(self) -> str: return "UNPROCESSED" -class StringParamType(ParamType): +class StringParamType(ParamType[str]): name = "text" def convert( self, value: t.Any, param: Parameter | None, ctx: Context | None - ) -> t.Any: + ) -> str: if isinstance(value, bytes): enc = _get_argv_encoding() try: @@ -223,14 +235,19 @@ def convert( value = value.decode("utf-8", "replace") else: value = value.decode("utf-8", "replace") - return value + return value # type: ignore[no-any-return] return str(value) def __repr__(self) -> str: return "STRING" -class Choice(ParamType, t.Generic[ParamTypeValue]): +class ChoiceInfoDict(ParamTypeInfoDict): + choices: cabc.Sequence[t.Any] + case_sensitive: bool + + +class Choice(ParamType[ParamTypeValue], t.Generic[ParamTypeValue]): """The choice type allows a value to be checked against a fixed set of supported values. @@ -261,11 +278,12 @@ def __init__( self.choices: cabc.Sequence[ParamTypeValue] = tuple(choices) self.case_sensitive = case_sensitive - def to_info_dict(self) -> dict[str, t.Any]: - info_dict = super().to_info_dict() - info_dict["choices"] = self.choices - info_dict["case_sensitive"] = self.case_sensitive - return info_dict + def to_info_dict(self) -> ChoiceInfoDict: + return { + "choices": self.choices, + "case_sensitive": self.case_sensitive, + **super().to_info_dict(), + } def _normalized_mapping( self, ctx: Context | None = None @@ -398,7 +416,11 @@ def shell_complete( return [CompletionItem(c) for c in matched] -class DateTime(ParamType): +class DateTimeInfoDict(ParamTypeInfoDict): + formats: cabc.Sequence[str] + + +class DateTime(ParamType[datetime]): """The DateTime type converts date strings into `datetime` objects. The format strings which are checked are configurable, but default to some @@ -428,10 +450,8 @@ def __init__(self, formats: cabc.Sequence[str] | None = None): "%Y-%m-%d %H:%M:%S", ] - def to_info_dict(self) -> dict[str, t.Any]: - info_dict = super().to_info_dict() - info_dict["formats"] = self.formats - return info_dict + def to_info_dict(self) -> DateTimeInfoDict: + return {"formats": self.formats, **super().to_info_dict()} def get_metavar(self, param: Parameter, ctx: Context) -> str | None: return f"[{'|'.join(self.formats)}]" @@ -444,7 +464,7 @@ def _try_to_convert_date(self, value: t.Any, format: str) -> datetime | None: def convert( self, value: t.Any, param: Parameter | None, ctx: Context | None - ) -> t.Any: + ) -> datetime: if isinstance(value, datetime): return value @@ -469,12 +489,12 @@ def __repr__(self) -> str: return "DateTime" -class _NumberParamTypeBase(ParamType): - _number_class: t.ClassVar[type[t.Any]] +class _NumberParamTypeBase(ParamType[ParamTypeValue]): + _number_class: t.Callable[[t.Any], ParamTypeValue] def convert( self, value: t.Any, param: Parameter | None, ctx: Context | None - ) -> t.Any: + ) -> ParamTypeValue: try: return self._number_class(value) except ValueError: @@ -487,7 +507,15 @@ def convert( ) -class _NumberRangeBase(_NumberParamTypeBase): +class NumberRangeInfoDict(ParamTypeInfoDict): + min: float | None + max: float | None + min_open: bool + max_open: bool + clamp: bool + + +class _NumberRangeBase(_NumberParamTypeBase[ParamTypeValue]): def __init__( self, min: float | None = None, @@ -502,36 +530,37 @@ def __init__( self.max_open = max_open self.clamp = clamp - def to_info_dict(self) -> dict[str, t.Any]: - info_dict = super().to_info_dict() - info_dict.update( - min=self.min, - max=self.max, - min_open=self.min_open, - max_open=self.max_open, - clamp=self.clamp, - ) - return info_dict + def to_info_dict(self) -> NumberRangeInfoDict: + return { + "min": self.min, + "max": self.max, + "min_open": self.min_open, + "max_open": self.max_open, + "clamp": self.clamp, + **super().to_info_dict(), + } def convert( self, value: t.Any, param: Parameter | None, ctx: Context | None - ) -> t.Any: + ) -> ParamTypeValue: import operator rv = super().convert(value, param, ctx) - lt_min: bool = self.min is not None and ( + min = self.min + max = self.max + lt_min: bool = min is not None and ( operator.le if self.min_open else operator.lt - )(rv, self.min) - gt_max: bool = self.max is not None and ( + )(rv, min) # type: ignore[arg-type] + gt_max: bool = max is not None and ( operator.ge if self.max_open else operator.gt - )(rv, self.max) + )(rv, max) # type: ignore[arg-type] if self.clamp: - if lt_min: - return self._clamp(self.min, 1, self.min_open) # type: ignore + if min is not None and lt_min: + return self._clamp(min, 1, self.min_open) # type: ignore[arg-type] - if gt_max: - return self._clamp(self.max, -1, self.max_open) # type: ignore + if max is not None and gt_max: + return self._clamp(max, -1, self.max_open) # type: ignore[arg-type] if lt_min or gt_max: self.fail( @@ -544,7 +573,10 @@ def convert( return rv - def _clamp(self, bound: float, dir: t.Literal[1, -1], open: bool) -> float: + @abc.abstractmethod + def _clamp( + self, bound: ParamTypeValue, dir: t.Literal[1, -1], open: bool + ) -> ParamTypeValue: """Find the valid value to clamp to bound in the given direction. @@ -552,7 +584,7 @@ def _clamp(self, bound: float, dir: t.Literal[1, -1], open: bool) -> float: :param dir: 1 or -1 indicating the direction to move. :param open: If true, the range does not include the bound. """ - raise NotImplementedError + ... def _describe_range(self) -> str: """Describe the range for use in help text.""" @@ -573,7 +605,7 @@ def __repr__(self) -> str: return f"<{type(self).__name__} {self._describe_range()}{clamp}>" -class IntParamType(_NumberParamTypeBase): +class IntParamType(_NumberParamTypeBase[int]): name = "integer" _number_class = int @@ -581,7 +613,7 @@ def __repr__(self) -> str: return "INT" -class IntRange(_NumberRangeBase, IntParamType): +class IntRange(_NumberRangeBase[int], IntParamType): """Restrict an :data:`click.INT` value to a range of accepted values. See :ref:`ranges`. @@ -598,16 +630,14 @@ class IntRange(_NumberRangeBase, IntParamType): name = "integer range" - def _clamp( # type: ignore - self, bound: int, dir: t.Literal[1, -1], open: bool - ) -> int: + def _clamp(self, bound: int, dir: t.Literal[1, -1], open: bool) -> int: if not open: return bound return bound + dir -class FloatParamType(_NumberParamTypeBase): +class FloatParamType(_NumberParamTypeBase[float]): name = "float" _number_class = float @@ -615,7 +645,7 @@ def __repr__(self) -> str: return "FLOAT" -class FloatRange(_NumberRangeBase, FloatParamType): +class FloatRange(_NumberRangeBase[float], FloatParamType): """Restrict a :data:`click.FLOAT` value to a range of accepted values. See :ref:`ranges`. @@ -658,7 +688,7 @@ def _clamp(self, bound: float, dir: t.Literal[1, -1], open: bool) -> float: raise RuntimeError("Clamping is not supported for open bounds.") -class BoolParamType(ParamType): +class BoolParamType(ParamType[bool]): name = "boolean" bool_states: dict[str, bool] = { @@ -727,14 +757,12 @@ def __repr__(self) -> str: return "BOOL" -class UUIDParameterType(ParamType): +class UUIDParameterType(ParamType[uuid.UUID]): name = "uuid" def convert( self, value: t.Any, param: Parameter | None, ctx: Context | None - ) -> t.Any: - import uuid - + ) -> uuid.UUID: if isinstance(value, uuid.UUID): return value @@ -751,7 +779,12 @@ def __repr__(self) -> str: return "UUID" -class File(ParamType): +class FileInfoDict(ParamTypeInfoDict): + mode: str + encoding: str | None + + +class File(ParamType[t.IO[t.Any]]): """Declares a parameter to be a file for reading or writing. The file is automatically closed once the context tears down (after the command finished working). @@ -798,10 +831,12 @@ def __init__( self.lazy = lazy self.atomic = atomic - def to_info_dict(self) -> dict[str, t.Any]: - info_dict = super().to_info_dict() - info_dict.update(mode=self.mode, encoding=self.encoding) - return info_dict + def to_info_dict(self) -> FileInfoDict: + return { + "mode": self.mode, + "encoding": self.encoding, + **super().to_info_dict(), + } def resolve_lazy_flag(self, value: str | os.PathLike[str]) -> bool: if self.lazy is not None: @@ -876,7 +911,16 @@ def _is_file_like(value: t.Any) -> te.TypeGuard[t.IO[t.Any]]: return hasattr(value, "read") or hasattr(value, "write") -class Path(ParamType): +class PathInfoDict(ParamTypeInfoDict): + exists: bool + file_okay: bool + dir_okay: bool + writable: bool + readable: bool + allow_dash: bool + + +class Path(ParamType[str | bytes | os.PathLike[str]]): """The ``Path`` type is similar to the :class:`File` type, but returns the filename instead of an open file. Various checks can be enabled to validate the type of file and permissions. @@ -940,17 +984,16 @@ def __init__( else: self.name = _("path") - def to_info_dict(self) -> dict[str, t.Any]: - info_dict = super().to_info_dict() - info_dict.update( - exists=self.exists, - file_okay=self.file_okay, - dir_okay=self.dir_okay, - writable=self.writable, - readable=self.readable, - allow_dash=self.allow_dash, - ) - return info_dict + def to_info_dict(self) -> PathInfoDict: + return { + "exists": self.exists, + "file_okay": self.file_okay, + "dir_okay": self.dir_okay, + "writable": self.writable, + "readable": self.readable, + "allow_dash": self.allow_dash, + **super().to_info_dict(), + } def coerce_path_result( self, value: str | os.PathLike[str] @@ -1057,7 +1100,11 @@ def shell_complete( return [CompletionItem(incomplete, type=type)] -class Tuple(CompositeParamType): +class TupleInfoDict(ParamTypeInfoDict): + types: cabc.Sequence[ParamTypeInfoDict] + + +class Tuple(CompositeParamType[tuple[t.Any, ...]]): """The default behavior of Click is to apply a type on a value directly. This works well in most cases, except for when `nargs` is set to a fixed count and different types should be used for different items. In this @@ -1071,25 +1118,26 @@ class Tuple(CompositeParamType): :param types: a list of types that should be used for the tuple items. """ - def __init__(self, types: cabc.Sequence[type[t.Any] | ParamType]) -> None: - self.types: cabc.Sequence[ParamType] = [convert_type(ty) for ty in types] + def __init__(self, types: cabc.Sequence[type[t.Any] | ParamType[t.Any]]) -> None: + self.types: cabc.Sequence[ParamType[t.Any]] = [convert_type(ty) for ty in types] - def to_info_dict(self) -> dict[str, t.Any]: - info_dict = super().to_info_dict() - info_dict["types"] = [t.to_info_dict() for t in self.types] - return info_dict + def to_info_dict(self) -> TupleInfoDict: + return { + "types": [ty.to_info_dict() for ty in self.types], + **super().to_info_dict(), + } @property - def name(self) -> str: # type: ignore + def name(self) -> str: # type: ignore[override] return f"<{' '.join(ty.name for ty in self.types)}>" @property - def arity(self) -> int: # type: ignore + def arity(self) -> int: # type: ignore[override] return len(self.types) def convert( self, value: t.Any, param: Parameter | None, ctx: Context | None - ) -> t.Any: + ) -> tuple[t.Any, ...]: len_type = len(self.types) len_value = len(value) @@ -1109,7 +1157,7 @@ def convert( ) -def convert_type(ty: t.Any | None, default: t.Any | None = None) -> ParamType: +def convert_type(ty: t.Any | None, default: t.Any | None = None) -> ParamType[t.Any]: """Find the most appropriate :class:`ParamType` for the given Python type. If the type isn't provided, it can be inferred from a default value. diff --git a/tests/test_imports.py b/tests/test_imports.py index 917b245f29..74b78642bc 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -27,6 +27,7 @@ def tracking_import(module, locals=None, globals=None, fromlist=None, ALLOWED_IMPORTS = { "__future__", + "abc", "codecs", "collections", "collections.abc", @@ -49,6 +50,7 @@ def tracking_import(module, locals=None, globals=None, fromlist=None, "threading", "types", "typing", + "uuid", "weakref", }