@@ -187,62 +187,61 @@ def make_func(
187187 return new_func
188188
189189
190- def _get_closure_types (af : types .FunctionType ) -> dict [str , type ]:
191- # Generate a fallback mapping of closure classes.
192- # This is needed for locally defined generic types which reference
193- # themselves in their type annotations.
194- if not af .__closure__ :
195- return {}
196- return {
197- name : variable .cell_contents
198- for name , variable in zip (
199- af .__code__ .co_freevars , af .__closure__ , strict = True
200- )
201- }
202-
203-
204190EXCLUDED_ATTRIBUTES = typing .EXCLUDED_ATTRIBUTES - {'__init__' } # type: ignore[attr-defined]
205191
206192
207- def get_local_defns (boxed : Boxed ) -> tuple [dict [str , Any ], dict [str , Any ]]:
208- annos : dict [str , Any ] = {}
209- dct : dict [str , Any ] = {}
210-
211- if af := typing .cast (
212- types .FunctionType , getattr (boxed .cls , "__annotate__" , None )
213- ):
214- # Class has annotations, let's resolve generic arguments
215-
216- closure_types = _get_closure_types (af )
217- args = tuple (
218- types .CellType (
219- boxed .cls .__dict__
220- if name == "__classdict__"
221- else boxed .str_args [name ]
222- if name in boxed .str_args
223- else closure_types [name ]
193+ def get_annotations (
194+ obj : object ,
195+ args : dict [str , object ],
196+ key : str = '__annotate__' ,
197+ ) -> Any | None :
198+ """Get the annotations on an object, substituting in type vars."""
199+
200+ # Copy in any __type_params__ that aren't provided for, so that if
201+ # we have to eval, we have them.
202+ if params := getattr (obj , "__type_params__" , None ):
203+ args = args .copy ()
204+ for param in params :
205+ if str (param ) not in args :
206+ args [str (param )] = param
207+
208+ rr = None
209+ if af := typing .cast (types .FunctionType , getattr (obj , key , None )):
210+ # Substitute in names that are provided but keep the existing
211+ # values for everything else.
212+ closure = tuple (
213+ types .CellType (args [name ]) if name in args else orig_value
214+ for name , orig_value in zip (
215+ af .__code__ .co_freevars , af .__closure__ or (), strict = True
224216 )
225- for name in af .__code__ .co_freevars
226217 )
227218
228219 ff = types .FunctionType (
229- af .__code__ , af .__globals__ , af .__name__ , None , args
220+ af .__code__ , af .__globals__ , af .__name__ , None , closure
230221 )
231222 rr = ff (annotationlib .Format .VALUE )
232223
233- if rr :
234- for k , v in rr .items ():
235- if isinstance (v , str ):
236- # Handle cases where annotation is explicitly a string,
237- # e.g.:
238- #
239- # class Foo[X]:
240- # x: "Foo[X | None]"
224+ if isinstance ( rr , dict ) :
225+ for k , v in rr .items ():
226+ if isinstance (v , str ):
227+ # Handle cases where annotation is explicitly a string,
228+ # e.g.:
229+ #
230+ # class Foo[X]:
231+ # x: "Foo[X | None]"
241232
242- annos [k ] = eval (v , af .__globals__ , boxed .str_args )
243- else :
244- annos [k ] = v
245- elif af := getattr (boxed .cls , "__annotations__" , None ):
233+ rr [k ] = eval (v , af .__globals__ , args )
234+
235+ return rr
236+
237+
238+ def get_local_defns (boxed : Boxed ) -> tuple [dict [str , Any ], dict [str , Any ]]:
239+ annos : dict [str , Any ] = {}
240+ dct : dict [str , Any ] = {}
241+
242+ if (rr := get_annotations (boxed .cls , boxed .str_args )) is not None :
243+ annos .update (rr )
244+ elif anns := getattr (boxed .cls , "__annotations__" , None ):
246245 # TODO: substitute vars in this case
247246 _globals = {}
248247 if mod := sys .modules .get (boxed .cls .__module__ ):
@@ -252,7 +251,7 @@ def get_local_defns(boxed: Boxed) -> tuple[dict[str, Any], dict[str, Any]]:
252251 _locals = dict (boxed .cls .__dict__ )
253252 _locals .update (boxed .str_args )
254253
255- for k , v in af .items ():
254+ for k , v in anns .items ():
256255 if isinstance (v , str ):
257256 result = eval (v , _globals , _locals )
258257 # Handle cases where annotation is explicitly a string
@@ -273,42 +272,13 @@ def get_local_defns(boxed: Boxed) -> tuple[dict[str, Any], dict[str, Any]]:
273272 stuff = inspect .unwrap (orig )
274273
275274 if isinstance (stuff , types .FunctionType ):
276- local_fn : types .FunctionType | classmethod | staticmethod | None = (
277- None
278- )
279-
280- if af := typing .cast (
281- types .FunctionType , getattr (stuff , "__annotate__" , None )
282- ):
283- params = dict (
284- zip (
285- map (str , stuff .__type_params__ ),
286- stuff .__type_params__ ,
287- strict = True ,
288- )
289- )
290-
291- closure_types = _get_closure_types (af )
292- args = tuple (
293- types .CellType (
294- boxed .cls .__dict__
295- if name == "__classdict__"
296- else params [name ]
297- if name in params
298- else boxed .str_args [name ]
299- if name in boxed .str_args
300- else closure_types [name ]
301- )
302- for name in af .__code__ .co_freevars
303- )
304-
305- ff = types .FunctionType (
306- af .__code__ , af .__globals__ , af .__name__ , None , args
307- )
308- rr = ff (annotationlib .Format .VALUE )
275+ local_fn : Any = None
309276
277+ if (rr := get_annotations (stuff , boxed .str_args )) is not None :
310278 local_fn = make_func (orig , rr )
311- elif af := getattr (stuff , "__annotations__" , None ):
279+ elif anns := getattr (stuff , "__annotations__" , None ):
280+ # XXX: This is totally wrong; we still need to do
281+ # substitute in class vars
312282 local_fn = stuff
313283
314284 if local_fn is not None :
0 commit comments