11from __future__ import annotations
22
33import argparse
4+ import json
45
5- from typing import List , Literal , Union , Any , Type , TypeVar
6+ from typing import List , Literal , Union , Any , Type , TypeVar , Dict
67
78from pydantic import BaseModel
89
@@ -40,6 +41,17 @@ def _contains_list_type(annotation: Type[Any] | None) -> bool:
4041 return False
4142
4243
44+ def _contains_dict_type (annotation : Type [Any ] | None ) -> bool :
45+ origin = getattr (annotation , "__origin__" , None )
46+
47+ if origin is dict or origin is Dict :
48+ return True
49+ elif origin in (Literal , Union ):
50+ return any (_contains_dict_type (arg ) for arg in annotation .__args__ ) # type: ignore
51+ else :
52+ return False
53+
54+
4355def _parse_bool_arg (arg : str | bytes | bool ) -> bool :
4456 if isinstance (arg , bytes ):
4557 arg = arg .decode ("utf-8" )
@@ -57,6 +69,16 @@ def _parse_bool_arg(arg: str | bytes | bool) -> bool:
5769 raise ValueError (f"Invalid boolean argument: { arg } " )
5870
5971
72+ def _parse_json_object_arg (arg : str | bytes ) -> dict [str , Any ]:
73+ if isinstance (arg , bytes ):
74+ arg = arg .decode ("utf-8" )
75+
76+ value = json .loads (arg )
77+ if not isinstance (value , dict ):
78+ raise ValueError (f"Invalid JSON object argument: { arg } " )
79+ return value
80+
81+
6082def add_args_from_model (parser : argparse .ArgumentParser , model : Type [BaseModel ]):
6183 """Add arguments from a pydantic model to an argparse parser."""
6284
@@ -68,7 +90,15 @@ def add_args_from_model(parser: argparse.ArgumentParser, model: Type[BaseModel])
6890 _get_base_type (field .annotation ) if field .annotation is not None else str
6991 )
7092 list_type = _contains_list_type (field .annotation )
71- if base_type is not bool :
93+ dict_type = _contains_dict_type (field .annotation )
94+ if dict_type :
95+ parser .add_argument (
96+ f"--{ name } " ,
97+ dest = name ,
98+ type = _parse_json_object_arg ,
99+ help = description ,
100+ )
101+ elif base_type is not bool :
72102 parser .add_argument (
73103 f"--{ name } " ,
74104 dest = name ,
0 commit comments