diff --git a/CHANGES.rst b/CHANGES.rst index 76f2b0066..bf2494c70 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -51,6 +51,12 @@ Unreleased commands. :issue:`3107` :pr:`3228` - Add ``click.get_pager_file`` for file-like access to an output pager. :pr:`1572` +- ``click.prompt`` and ``ParamType`` fully generically typed with the latter + receiving a new optional ``ParamTypeInputValue`` generic type for the + expected input type that defaults to ``Any``. Additionally, + ``click.prompt`` implementation changed slightly so that when a default + value is the same type as the expected type, it does not do a round trip + through the value processor nor the type conversion. :pr:`3407` Version 8.3.3 ------------- diff --git a/pyproject.toml b/pyproject.toml index 5a0e37d04..92a24a116 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ classifiers = [ requires-python = ">=3.10" dependencies = [ "colorama; platform_system == 'Windows'", + "typing_extensions; python_version < '3.13'", ] [project.urls] diff --git a/src/click/termui.py b/src/click/termui.py index 892e4c0bc..9191b22f3 100644 --- a/src/click/termui.py +++ b/src/click/termui.py @@ -1,5 +1,6 @@ from __future__ import annotations +import builtins import collections.abc as cabc import inspect import io @@ -25,7 +26,13 @@ if t.TYPE_CHECKING: from ._termui_impl import ProgressBar + if sys.version_info >= (3, 13): + from typing import TypeIs + else: + from typing_extensions import TypeIs + V = t.TypeVar("V") +C = t.TypeVar("C") # The prompt functions to use. The doc tools currently override these # functions to customize how they work. @@ -83,39 +90,51 @@ def _build_prompt( text: str, suffix: str, show_default: bool | str = False, - default: t.Any | None = None, + default: object | None = None, show_choices: bool = True, - type: ParamType[t.Any] | None = None, + type: object | None = None, ) -> str: prompt = text if type is not None and show_choices and isinstance(type, Choice): prompt += f" ({', '.join(map(str, type.choices))})" - if isinstance(show_default, str): - default = f"({show_default})" - if default is not None and show_default: - prompt = f"{prompt} [{_format_default(default)}]" - return f"{prompt}{suffix}" + default_preview = "" + if show_default: + if isinstance(show_default, str): + default_preview = f" [({show_default})]" + elif default is not None: + default_preview = f" [{_format_default(default)}]" + return f"{prompt}{default_preview}{suffix}" + +def _format_default(default: V) -> V | str: + if isinstance(default, (io.IOBase, LazyFile)): + name = getattr(default, "name", None) -def _format_default(default: t.Any) -> t.Any: - if isinstance(default, (io.IOBase, LazyFile)) and hasattr(default, "name"): - return default.name + if name is not None: + return str(name) return default +def _is_expected_type( + default: object, + type: ParamType[V, t.Any] | V | None, +) -> TypeIs[V]: + return builtins.type(default) is builtins.type(type) + + def prompt( text: str, - default: t.Any | None = None, + default: V | C | str | None = None, hide_input: bool = False, confirmation_prompt: bool | str = False, - type: ParamType[t.Any] | t.Any | None = None, - value_proc: t.Callable[[str], t.Any] | None = None, + type: ParamType[V, C | str] | V | None = None, + value_proc: t.Callable[[C | str], V] | None = None, prompt_suffix: str = ": ", show_default: bool | str = True, err: bool = False, show_choices: bool = True, -) -> t.Any: +) -> V: """Prompts a user for input. This is a convenience function that can be used to prompt a user for input later. @@ -145,6 +164,11 @@ def prompt( show_choices is true and text is "Group by" then the prompt will be "Group by (day, week): ". + .. versionchanged:: 8.4.0 + ``default`` no longer passes through the ``value_proc`` callback, + nor the constructor of the types of ``type`` or ``default`` field, + when it is the same type as ``type``. + .. versionchanged:: 8.3.3 ``show_default`` can be a string to show a custom value instead of the actual default, matching the help text behavior. @@ -192,21 +216,29 @@ def prompt_func(text: str) -> str: confirmation_prompt = _build_prompt(confirmation_prompt, prompt_suffix) while True: + result: V | None = None while True: - value = prompt_func(prompt) + value: C | str = prompt_func(prompt) if value: break elif default is not None: - value = default + if _is_expected_type(default=default, type=type): + # It's the expected type, don't reparse it. + result = default + else: + # It's not the expected type. Pass it through value_proc before + # returning. + value = t.cast(C | str, default) # type: ignore break - try: - result = value_proc(value) - except UsageError as e: - if hide_input: - echo(_("Error: The value you entered was invalid."), err=err) - else: - echo(_("Error: {e.message}").format(e=e), err=err) - continue + if result is None: + try: + result = t.cast(V, value_proc(value)) + except UsageError as e: + if hide_input: + echo(_("Error: The value you entered was invalid."), err=err) + else: + echo(_("Error: {e.message}").format(e=e), err=err) + continue if not confirmation_prompt: return result while True: diff --git a/src/click/types.py b/src/click/types.py index 355e98423..bd7a97d49 100644 --- a/src/click/types.py +++ b/src/click/types.py @@ -19,6 +19,12 @@ from .utils import LazyFile from .utils import safecall +# TypeVar(default=...) support. +if sys.version_info >= (3, 13): + from typing import TypeVar +else: + from typing_extensions import TypeVar + if t.TYPE_CHECKING: import typing_extensions as te @@ -26,7 +32,8 @@ from .core import Parameter from .shell_completion import CompletionItem -ParamTypeValue = t.TypeVar("ParamTypeValue") +ParamTypeValue = TypeVar("ParamTypeValue") +ParamTypeInputValue = TypeVar("ParamTypeInputValue", default=t.Any) class ParamTypeInfoDict(t.TypedDict): @@ -34,7 +41,7 @@ class ParamTypeInfoDict(t.TypedDict): name: str -class ParamType(t.Generic[ParamTypeValue], abc.ABC): +class ParamType(t.Generic[ParamTypeValue, ParamTypeInputValue], abc.ABC): """Represents the type of a parameter. Validates and converts values from the command line or Python into the correct type. @@ -87,9 +94,25 @@ def to_info_dict(self) -> ParamTypeInfoDict: return {"param_type": param_type, "name": name} + @t.overload + def __call__( + self, + value: None, + param: Parameter | None = None, + ctx: Context | None = None, + ) -> None: ... + + @t.overload def __call__( self, - value: t.Any, + value: ParamTypeInputValue, + param: Parameter | None = None, + ctx: Context | None = None, + ) -> ParamTypeValue: ... + + def __call__( + self, + value: ParamTypeInputValue | None, param: Parameter | None = None, ctx: Context | None = None, ) -> ParamTypeValue | None: @@ -108,7 +131,7 @@ def get_missing_message(self, param: Parameter, ctx: Context | None) -> str | No """ def convert( - self, value: t.Any, param: Parameter | None, ctx: Context | None + self, value: ParamTypeInputValue, param: Parameter | None, ctx: Context | None ) -> ParamTypeValue: """Convert the value to the correct type. This is not called if the value is ``None`` (the missing value). diff --git a/tests/test_imports.py b/tests/test_imports.py index 74b78642b..6e8a4261f 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -28,6 +28,7 @@ def tracking_import(module, locals=None, globals=None, fromlist=None, ALLOWED_IMPORTS = { "__future__", "abc", + "builtins", "codecs", "collections", "collections.abc", @@ -50,6 +51,7 @@ def tracking_import(module, locals=None, globals=None, fromlist=None, "threading", "types", "typing", + "typing_extensions", "uuid", "weakref", } diff --git a/uv.lock b/uv.lock index 278506184..b27fe37e5 100644 --- a/uv.lock +++ b/uv.lock @@ -177,6 +177,7 @@ version = "8.3.3" source = { editable = "." } dependencies = [ { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, ] [package.dev-dependencies] @@ -221,7 +222,10 @@ typing = [ ] [package.metadata] -requires-dist = [{ name = "colorama", marker = "sys_platform == 'win32'" }] +requires-dist = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] [package.metadata.requires-dev] dev = [