Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions TypeSaveArgParse/autoargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,15 @@
except Exception:
doc_parse = None

from TypeSaveArgParse.utils import cast_all, class_to_str, enum_to_str, extract_sub_annotation, len_checker, translation_enum_to_str
from TypeSaveArgParse.utils import (
cast_all,
class_to_str,
enum_to_str,
extract_sub_annotation,
is_union,
len_checker,
translation_enum_to_str,
)

config_help = "config file path"

Expand Down Expand Up @@ -93,7 +101,8 @@ def data_class_to_arg_parse(
annotation = parameters[name].annotation
# Handling :A |B |...| None (None means Optional argument)
annotations = []
if get_origin(annotation) == Union:

if is_union(annotation): # :
for i in get_args(annotation):
if i == type(None):
can_be_none = True
Expand Down Expand Up @@ -231,7 +240,7 @@ def add_comments_to_yaml(cls, data: ruamel.yaml.CommentedMap, _addendum: str = "

# Handling :A |B |...| None (None means Optional argument)
annotations = []
if get_origin(annotation) == Union:
if is_union(annotation):
for i in get_args(annotation):
if i != type(None):
annotations.append(i)
Expand Down
6 changes: 5 additions & 1 deletion TypeSaveArgParse/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def extract_sub_annotation(annotation):


def _cast_all(val, annotation: types.GenericAlias, enum): # -> tuple[Any, ...] | set[Any] | Any | list[Any]:
if get_origin(annotation) == Union:
if is_union(annotation):
annotation = extract_sub_annotation(annotation)[0]

if isinstance(val, list):
Expand All @@ -87,3 +87,7 @@ def cast_all(val, parameter: Parameter, enum):
return val
val_out = _cast_all(val, parameter.annotation, enum)
return val_out


def is_union(annotation):
return (hasattr(types, "UnionType") and get_origin(annotation) == types.UnionType) or get_origin(annotation) == Union
59 changes: 41 additions & 18 deletions test/test_autoargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# coverage run -m unittest
# coverage report
# coverage html
import platform
import random
import sys
import unittest
Expand All @@ -15,26 +16,14 @@

sys.path.append(str(Path(__file__).parent.parent.parent / "TypeSaveArgParse"))


from TypeSaveArgParse import Class_to_ArgParse

DEFAULT_STR = str(random.randint(0, 1000))
DEFAULT_INT = random.randint(-1000, 1000)
DEFAULT_FLOAT = random.random()


@dataclass
class BASE_CASES(Class_to_ArgParse):
x: str = DEFAULT_STR
y: int = DEFAULT_INT
f: float = DEFAULT_FLOAT
z: Optional[int] = None
p: Optional[Path] = None
l_s: list[str] = field(default_factory=list)
l_i: list[int] = field(default_factory=lambda: [1, 2, 3])
tup: tuple[str, ...] = field(default_factory=tuple)
set_: set[str] = field(default_factory=set)


@dataclass
class TUP_CASES(Class_to_ArgParse):
tup: tuple[int, ...] = field(default_factory=tuple)
Expand All @@ -50,11 +39,45 @@ class Dummy_Enum(Enum):
THIRD = auto()


@dataclass
class ENUM_CASES(Class_to_ArgParse):
enu: Dummy_Enum = Dummy_Enum.ONE
enu_list: Optional[list[Dummy_Enum]] = None
enu_list2: list[Dummy_Enum] = field(default_factory=list)
py_version = int(platform.python_version().split(".")[1])
if py_version <= 9:

@dataclass
class BASE_CASES(Class_to_ArgParse): # type: ignore
x: str = DEFAULT_STR
y: int = DEFAULT_INT
f: float = DEFAULT_FLOAT
z: Optional[int] = None
p: Optional[Path] = None
l_s: list[str] = field(default_factory=list)
l_i: list[int] = field(default_factory=lambda: [1, 2, 3])
tup: tuple[str, ...] = field(default_factory=tuple)
set_: set[str] = field(default_factory=set)

@dataclass
class ENUM_CASES(Class_to_ArgParse): # type: ignore
enu: Dummy_Enum = Dummy_Enum.ONE
enu_list: Optional[list[Dummy_Enum]] = None
enu_list2: list[Dummy_Enum] = field(default_factory=list)
else:

@dataclass
class BASE_CASES(Class_to_ArgParse):
x: str = DEFAULT_STR
y: int = DEFAULT_INT
f: float = DEFAULT_FLOAT
z: int | None = None
p: Path | None = None
l_s: list[str] = field(default_factory=list)
l_i: list[int] = field(default_factory=lambda: [1, 2, 3])
tup: tuple[str, ...] = field(default_factory=tuple)
set_: set[str] = field(default_factory=set)

@dataclass
class ENUM_CASES(Class_to_ArgParse):
enu: Dummy_Enum = Dummy_Enum.ONE
enu_list: list[Dummy_Enum] | None = None
enu_list2: list[Dummy_Enum] = field(default_factory=list)


def assert_(value, x, type_):
Expand Down
149 changes: 94 additions & 55 deletions test/test_recursive.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# coverage run -m unittest
# coverage report
# coverage html
import platform
import random
import sys
import unittest
Expand All @@ -21,34 +22,103 @@
DEFAULT_STR = str(random.randint(0, 1000))
DEFAULT_INT = random.randint(-1000, 1000)
DEFAULT_FLOAT = random.random()
py_version = int(platform.python_version().split(".")[1])


@dataclass
class BASE_CASES(Class_to_ArgParse):
"""
Class representing base cases for argument parsing.
class Dummy_Enum(Enum):
ONE = auto()
SECOND = auto()
THIRD = auto()

Attributes:
x (str): String attribute with default value DEFAULT_STR.
y (int): Integer attribute with default value DEFAULT_INT.
f (float): Float attribute with default value DEFAULT_FLOAT.
z (Optional[int]): Optional integer attribute.
p (Optional[Path]): Optional Path attribute.
l_s (List[str]): List of strings.
l_i (List[int]): List of integers with default value [1, 2, 3].
tup (Tuple[str, ...]): Tuple of strings.
set_ (Set[str]): Set of strings.
"""

x: str = DEFAULT_STR
y: int = DEFAULT_INT
f: float = DEFAULT_FLOAT
z: Optional[int] = None
p: Optional[Path] = None
l_s: list[str] = field(default_factory=list)
l_i: list[int] = field(default_factory=lambda: [1, 2, 3])
tup: tuple[str, ...] = field(default_factory=tuple)
set_: set[str] = field(default_factory=set)
if py_version <= 9:

@dataclass
class BASE_CASES(Class_to_ArgParse): # type: ignore
"""
Class representing base cases for argument parsing.

Attributes:
x (str): String attribute with default value DEFAULT_STR.
y (int): Integer attribute with default value DEFAULT_INT.
f (float): Float attribute with default value DEFAULT_FLOAT.
z (Optional[int]): Optional integer attribute.
p (Optional[Path]): Optional Path attribute.
l_s (List[str]): List of strings.
l_i (List[int]): List of integers with default value [1, 2, 3].
tup (Tuple[str, ...]): Tuple of strings.
set_ (Set[str]): Set of strings.
"""

x: str = DEFAULT_STR
y: int = DEFAULT_INT
f: float = DEFAULT_FLOAT
z: Optional[int] = None
p: Optional[Path] = None
l_s: list[str] = field(default_factory=list)
l_i: list[int] = field(default_factory=lambda: [1, 2, 3])
tup: tuple[str, ...] = field(default_factory=tuple)
set_: set[str] = field(default_factory=set)

@dataclass
class ENUM_CASES(Class_to_ArgParse): # type: ignore
"""
Class representing enum cases for argument parsing.

Attributes:
enu (Dummy_Enum): Enum attribute with default value Dummy_Enum.ONE.
enu_list (Optional[List[Dummy_Enum]]): Optional list of enum instances.
enu_list2 (List[Dummy_Enum]): List of enum instances.
"""

enu: Dummy_Enum = Dummy_Enum.ONE
enu_list: Optional[list[Dummy_Enum]] = None
enu_list2: list[Dummy_Enum] = field(default_factory=list)


else:

@dataclass
class BASE_CASES(Class_to_ArgParse):
"""
Class representing base cases for argument parsing.

Attributes:
x (str): String attribute with default value DEFAULT_STR.
y (int): Integer attribute with default value DEFAULT_INT.
f (float): Float attribute with default value DEFAULT_FLOAT.
z (Optional[int]): Optional integer attribute.
p (Optional[Path]): Optional Path attribute.
l_s (List[str]): List of strings.
l_i (List[int]): List of integers with default value [1, 2, 3].
tup (Tuple[str, ...]): Tuple of strings.
set_ (Set[str]): Set of strings.
"""

x: str = DEFAULT_STR
y: int = DEFAULT_INT
f: float = DEFAULT_FLOAT
z: int | None = None
p: Path | None = None
l_s: list[str] = field(default_factory=list)
l_i: list[int] = field(default_factory=lambda: [1, 2, 3])
tup: tuple[str, ...] = field(default_factory=tuple)
set_: set[str] = field(default_factory=set)

@dataclass
class ENUM_CASES(Class_to_ArgParse):
"""
Class representing enum cases for argument parsing.

Attributes:
enu (Dummy_Enum): Enum attribute with default value Dummy_Enum.ONE.
enu_list (Optional[List[Dummy_Enum]]): Optional list of enum instances.
enu_list2 (List[Dummy_Enum]): List of enum instances.
"""

enu: Dummy_Enum = Dummy_Enum.ONE
enu_list: list[Dummy_Enum] | None = None
enu_list2: list[Dummy_Enum] = field(default_factory=list)


@dataclass
Expand Down Expand Up @@ -85,37 +155,6 @@ class TUP_CASES(Class_to_ArgParse):
tup4: tuple[int, int, int, int] = field(default_factory=tuple)


class Dummy_Enum(Enum):
"""
Class representing enum cases for argument parsing.

Attributes:
enu (Dummy_Enum): Enum attribute with default value Dummy_Enum.ONE.
enu_list (Optional[List[Dummy_Enum]]): Optional list of enum instances.
enu_list2 (List[Dummy_Enum]): List of enum instances.
"""

ONE = auto()
SECOND = auto()
THIRD = auto()


@dataclass
class ENUM_CASES(Class_to_ArgParse):
"""
Class representing enum cases for argument parsing.

Attributes:
enu (Dummy_Enum): Enum attribute with default value Dummy_Enum.ONE.
enu_list (Optional[List[Dummy_Enum]]): Optional list of enum instances.
enu_list2 (List[Dummy_Enum]): List of enum instances.
"""

enu: Dummy_Enum = Dummy_Enum.ONE
enu_list: Optional[list[Dummy_Enum]] = None
enu_list2: list[Dummy_Enum] = field(default_factory=list)


@dataclass()
class REC(Class_to_ArgParse):
"""
Expand Down
52 changes: 37 additions & 15 deletions test/test_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# coverage run -m unittest
# coverage report
# coverage html
import platform
import random
import sys
import unittest
Expand All @@ -22,28 +23,49 @@
DEFAULT_INT = random.randint(-1000, 1000)
DEFAULT_FLOAT = random.random()

py_version = int(platform.python_version().split(".")[1])


class Dummy_Enum(Enum):
ONE = auto()
SECOND = auto()
THIRD = auto()


@dataclass
class BASE_CASES(Class_to_ArgParse):
x: str = ""
y: int = -1000000
f: float = -0.3
enum: Dummy_Enum = Dummy_Enum.ONE
enum_tuple: tuple[Dummy_Enum, Dummy_Enum] = field(default_factory=lambda: (Dummy_Enum.ONE, Dummy_Enum.ONE))
z: Optional[int] = None
p: Optional[Path] = None
l_s: list[str] = field(default_factory=lambda: ["Wam", "Bam"])
l_i: list[int] = field(default_factory=lambda: [1, 2, 3])
tup: tuple[str, ...] = field(default_factory=tuple)
set_s: set[str] = field(default_factory=set)
set_i: set[int] = field(default_factory=set)
b: bool = False
if py_version <= 9:

@dataclass
class BASE_CASES(Class_to_ArgParse): # type: ignore
x: str = ""
y: int = -1000000
f: float = -0.3
enum: Dummy_Enum = Dummy_Enum.ONE
enum_tuple: tuple[Dummy_Enum, Dummy_Enum] = field(default_factory=lambda: (Dummy_Enum.ONE, Dummy_Enum.ONE))
z: Optional[int] = None
p: Optional[Path] = None
l_s: list[str] = field(default_factory=lambda: ["Wam", "Bam"])
l_i: list[int] = field(default_factory=lambda: [1, 2, 3])
tup: tuple[str, ...] = field(default_factory=tuple)
set_s: set[str] = field(default_factory=set)
set_i: set[int] = field(default_factory=set)
b: bool = False
else:

@dataclass
class BASE_CASES(Class_to_ArgParse):
x: str = ""
y: int = -1000000
f: float = -0.3
enum: Dummy_Enum = Dummy_Enum.ONE
enum_tuple: tuple[Dummy_Enum, Dummy_Enum] = field(default_factory=lambda: (Dummy_Enum.ONE, Dummy_Enum.ONE))
z: int | None = None
p: Path | None = None
l_s: list[str] = field(default_factory=lambda: ["Wam", "Bam"])
l_i: list[int] = field(default_factory=lambda: [1, 2, 3])
tup: tuple[str, ...] = field(default_factory=tuple)
set_s: set[str] = field(default_factory=set)
set_i: set[int] = field(default_factory=set)
b: bool = False


class Test_save_load(unittest.TestCase):
Expand Down