Skip to content

Commit 23a7a5c

Browse files
committed
Prevent recursive types from interfering with each other's evaluation.
1 parent 11797fe commit 23a7a5c

2 files changed

Lines changed: 62 additions & 4 deletions

File tree

typemap/type_eval/_eval_operators.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,10 @@ def _eval_Members(tp, *, ctx):
232232

233233
@type_eval.register_evaluator(FromUnion)
234234
def _eval_FromUnion(tp, *, ctx):
235-
return tuple[*_union_elems(tp, ctx)]
235+
if tp in ctx.known_recursive_types:
236+
return tuple[*_union_elems(ctx.known_recursive_types[tp], ctx)]
237+
else:
238+
return tuple[*_union_elems(tp, ctx)]
236239

237240

238241
##################################################################

typemap/type_eval/_eval_typing.py

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ class _EvalProxy:
4444

4545
@dataclasses.dataclass
4646
class EvalContext:
47+
# Fully resolved types
48+
resolved: dict[Any, Any] = dataclasses.field(default_factory=dict)
49+
# Types that have been seen, but may not be fully resolved
4750
seen: dict[Any, Any] = dataclasses.field(default_factory=dict)
4851
# The typing.Any is really a types.FunctionType, but mypy gets
4952
# confused and wants to treat it as a MethodType.
@@ -52,6 +55,14 @@ class EvalContext:
5255
)
5356
current_alias: types.GenericAlias | typing.Any | None = None
5457

58+
unwind_stack: set[typing.TypeAliasType | types.GenericAlias] = (
59+
dataclasses.field(default_factory=set)
60+
)
61+
unwinding_until: typing.TypeAliasType | types.GenericAlias | None = None
62+
known_recursive_types: dict[
63+
typing.TypeAliasType | types.GenericAlias, typing.Any
64+
] = dataclasses.field(default_factory=dict)
65+
5566

5667
# `eval_types()` calls can be nested, context must be preserved
5768
_current_context: contextvars.ContextVar[EvalContext | None] = (
@@ -100,9 +111,22 @@ def _child_context() -> typing.Iterator[EvalContext]:
100111

101112
try:
102113
child_ctx = EvalContext(
114+
resolved={
115+
# Drop resolved recursive types.
116+
# This is to allow other recursive types to expand them out
117+
# independently. For example, if we have a recursive types
118+
# A = B|C and B = A|D, we want B to expand even if we already
119+
# know A.
120+
k: v
121+
for k, v in ctx.resolved.items()
122+
if k not in ctx.known_recursive_types
123+
},
103124
seen=ctx.seen.copy(),
104125
current_alias_stack=ctx.current_alias_stack.copy(),
105126
current_alias=ctx.current_alias,
127+
unwind_stack=ctx.unwind_stack.copy(),
128+
unwinding_until=ctx.unwinding_until,
129+
known_recursive_types=ctx.known_recursive_types.copy(),
106130
)
107131
_current_context.set(child_ctx)
108132
yield child_ctx
@@ -112,21 +136,52 @@ def _child_context() -> typing.Iterator[EvalContext]:
112136

113137
def eval_typing(obj: typing.Any):
114138
with _ensure_context() as ctx:
115-
return _eval_types(obj, ctx)
139+
result = _eval_types(obj, ctx)
140+
if result in ctx.known_recursive_types:
141+
result = ctx.known_recursive_types[result]
142+
return result
116143

117144

118145
def _eval_types(obj: typing.Any, ctx: EvalContext):
146+
# Found a recursive type, we need to unwind it
147+
if obj in ctx.unwind_stack:
148+
ctx.unwinding_until = obj
149+
return obj
150+
119151
# Don't recurse into any pending alias expansion
120152
if obj in ctx.current_alias_stack:
121153
return obj
122-
# strings match
154+
155+
# Already resolved or seen, return the result
156+
if obj in ctx.resolved:
157+
return ctx.resolved[obj]
123158
if obj in ctx.seen:
124159
return ctx.seen[obj]
125160

126161
with _child_context() as child_ctx:
162+
child_ctx.unwind_stack.add(obj)
127163
evaled = _eval_types_impl(obj, child_ctx)
128164

129-
ctx.seen[obj] = evaled
165+
# If we have identified a recursive type, discard evaluation results.
166+
# This prevents external evaluations from being polluted by partial
167+
# evaluations.
168+
keep_intermediate = True
169+
if child_ctx.unwinding_until:
170+
if child_ctx.unwinding_until == obj:
171+
# Finished unwinding.
172+
ctx.known_recursive_types[obj] = evaled
173+
evaled = obj
174+
keep_intermediate = False
175+
176+
else:
177+
ctx.unwinding_until = child_ctx.unwinding_until
178+
179+
if keep_intermediate:
180+
ctx.resolved |= child_ctx.resolved
181+
ctx.seen |= child_ctx.seen
182+
ctx.known_recursive_types |= child_ctx.known_recursive_types
183+
184+
ctx.resolved[obj] = evaled
130185
return evaled
131186

132187

0 commit comments

Comments
 (0)