diff --git a/tests/test_type_conversion.py b/tests/test_type_conversion.py index 83edd1ecb5..40754ba923 100644 --- a/tests/test_type_conversion.py +++ b/tests/test_type_conversion.py @@ -146,6 +146,50 @@ def custom_parser( assert result.exit_code == 0 +def test_custom_parse_with_union_type(): + """parser= should bypass the 'no Union types' assertion.""" + app = typer.Typer() + + @app.command() + def cmd( + value: int | str = typer.Argument( + None, parser=lambda x: int(x) if x.isdigit() else x + ), + ): + print(repr(value)) + + result = runner.invoke(app, ["42"]) + assert result.exit_code == 0 + assert "42" in result.output + + +def test_custom_click_type_with_union_type(): + """click_type= should bypass the 'no Union types' assertion.""" + + class FlexType(click.ParamType): + name = "flex" + + def convert( + self, value: Any, param: click.Parameter | None, ctx: click.Context | None + ) -> Any: + try: + return int(value) + except ValueError: + return value + + app = typer.Typer() + + @app.command() + def cmd( + value: int | str = typer.Argument(None, click_type=FlexType()), + ): + print(repr(value)) + + result = runner.invoke(app, ["hello"]) + assert result.exit_code == 0 + assert "hello" in result.output + + def test_custom_click_type(): class BaseNumberParamType(click.ParamType): name = "base_integer" diff --git a/typer/main.py b/typer/main.py index 6febf2091e..49771d0ddd 100644 --- a/typer/main.py +++ b/typer/main.py @@ -1665,9 +1665,14 @@ def get_click_param( if type_ is NoneType: continue types.append(type_) - assert len(types) == 1, "Typer Currently doesn't support Union types" - main_type = types[0] - origin = get_origin(main_type) + if not ( + parameter_info.parser is not None + or parameter_info.click_type is not None + ): + assert len(types) == 1, "Typer Currently doesn't support Union types" + if len(types) == 1: + main_type = types[0] + origin = get_origin(main_type) # Handle Tuples and Lists if lenient_issubclass(origin, list): main_type = get_args(main_type)[0]