Skip to content

Commit a10f412

Browse files
committed
Add TapIgnore test and cleanup
1 parent ab2e6a9 commit a10f412

File tree

3 files changed

+106
-1
lines changed

3 files changed

+106
-1
lines changed

src/tap/tap.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,7 @@ def parse_args(
505505
return self
506506

507507
@classmethod
508-
def _get_from_self_and_super(cls, extract_func: Callable[[type], dict]) -> dict[str, Any] | dict:
508+
def _get_from_self_and_super(cls, extract_func: Callable[[type], dict]) -> dict[str, Any]:
509509
"""Returns a dictionary mapping variable names to values.
510510
511511
Variables and values are extracted from classes using key starting

src/tap/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -618,5 +618,7 @@ class _TapIgnoreMarker:
618618
619619
class Args(Tap):
620620
a: int
621+
622+
# TapIgnore is generic and preserves the type of the ignored attribute
621623
e: TapIgnore[int] = 5
622624
"""

tests/test_tap_ignore.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
import unittest
2+
from typing import Annotated
3+
from tap import Tap, TapIgnore
4+
5+
6+
class TapIgnoreTests(unittest.TestCase):
7+
def test_tap_ignore(self):
8+
class Args(Tap):
9+
a: int
10+
b: TapIgnore[int] = 2
11+
c: Annotated[int, "metadata"] = 3
12+
d: Annotated[TapIgnore[int], "metadata"] = 4
13+
e: TapIgnore[Annotated[int, "metadata"]] = 5
14+
15+
args = Args().parse_args(["--a", "1"])
16+
17+
self.assertEqual(args.a, 1)
18+
self.assertEqual(args.b, 2)
19+
self.assertEqual(args.c, 3)
20+
self.assertEqual(args.d, 4)
21+
self.assertEqual(args.e, 5)
22+
23+
# Check that b is not in the help message (indirectly checking it's not an argument)
24+
# Or check _actions
25+
26+
actions = {a.dest for a in args._actions}
27+
self.assertIn("a", actions)
28+
self.assertNotIn("b", actions)
29+
self.assertIn("c", actions)
30+
self.assertNotIn("d", actions)
31+
self.assertNotIn("e", actions)
32+
33+
def test_tap_ignore_no_default(self):
34+
class Args(Tap):
35+
a: int
36+
b: TapIgnore[int]
37+
38+
# If b is ignored, it shouldn't be required by argparse
39+
# But if it has no default, accessing it might raise AttributeError if not set?
40+
# Tap doesn't set it if it's not in arguments.
41+
42+
args = Args().parse_args(["--a", "1"])
43+
self.assertEqual(args.a, 1)
44+
45+
# b should not be set
46+
with self.assertRaises(AttributeError):
47+
_ = args.b
48+
49+
def test_tap_ignore_annotated_unwrapping(self):
50+
class Args(Tap):
51+
a: Annotated[int, "some metadata"]
52+
53+
args = Args().parse_args(["--a", "1"])
54+
self.assertEqual(args.a, 1)
55+
56+
def test_tap_ignore_subclass(self):
57+
class BaseArgs(Tap):
58+
base_keep: int
59+
base_ignore: TapIgnore[str] = "ignore_me"
60+
61+
class SubArgs(BaseArgs):
62+
sub_keep: float
63+
sub_ignore: TapIgnore[bool] = True
64+
65+
args = SubArgs().parse_args(["--base_keep", "1", "--sub_keep", "2.5"])
66+
67+
self.assertEqual(args.base_keep, 1)
68+
self.assertEqual(args.base_ignore, "ignore_me")
69+
self.assertEqual(args.sub_keep, 2.5)
70+
self.assertEqual(args.sub_ignore, True)
71+
72+
actions = {a.dest for a in args._actions}
73+
self.assertIn("base_keep", actions)
74+
self.assertNotIn("base_ignore", actions)
75+
self.assertIn("sub_keep", actions)
76+
self.assertNotIn("sub_ignore", actions)
77+
78+
def test_tap_ignore_subclass_override(self):
79+
# Case 1: Override ignored with argument
80+
class Base1(Tap):
81+
a: TapIgnore[int] = 1
82+
83+
class Sub1(Base1):
84+
a: int = 2
85+
86+
args1 = Sub1().parse_args([])
87+
self.assertEqual(args1.a, 2)
88+
self.assertIn("a", {a.dest for a in args1._actions})
89+
90+
# Case 2: Override argument with ignored
91+
class Base2(Tap):
92+
b: int = 3
93+
94+
class Sub2(Base2):
95+
b: TapIgnore[int] = 4
96+
97+
args2 = Sub2().parse_args([])
98+
self.assertEqual(args2.b, 4)
99+
self.assertNotIn("b", {a.dest for a in args2._actions})
100+
101+
102+
if __name__ == "__main__":
103+
unittest.main()

0 commit comments

Comments
 (0)