diff --git a/pyrefly/lib/alt/class/class_field.rs b/pyrefly/lib/alt/class/class_field.rs index e2793f3fa9..1915c6294c 100644 --- a/pyrefly/lib/alt/class/class_field.rs +++ b/pyrefly/lib/alt/class/class_field.rs @@ -4089,7 +4089,22 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { } else { Instance::of_class(cls) }; - Arc::unwrap_or_clone(new_member.value).as_raw_special_method_type(self.heap, &instance) + let assume_self_return = new_member.value.is_function_without_return_annotation(); + let mut new_ty = Arc::unwrap_or_clone(new_member.value) + .as_raw_special_method_type(self.heap, &instance)?; + if assume_self_return { + // Per the constructor typing spec, unannotated `__new__` may be assumed to + // return `Self` for constructor analysis. + let ret = if preserve_self { + self.heap.mk_self_type(cls.clone()) + } else { + self.heap.mk_class_type(cls.clone()) + }; + new_ty.transform_toplevel_callable(&mut |callable: &mut Callable| { + callable.ret = ret.clone(); + }); + } + Some(new_ty) } } diff --git a/pyrefly/lib/test/attributes.rs b/pyrefly/lib/test/attributes.rs index c2ca21d44d..b899f8e4ac 100644 --- a/pyrefly/lib/test/attributes.rs +++ b/pyrefly/lib/test/attributes.rs @@ -1189,7 +1189,7 @@ class C: if orig_func is None: return super().__new__(cls) def f(): - with C(): # E: `NoneType` has no attribute `__enter__` # E: `NoneType` has no attribute `__exit__` + with C(): pass "#, ); diff --git a/pyrefly/lib/test/constructors.rs b/pyrefly/lib/test/constructors.rs index 22b268a654..29cc5c4cb3 100644 --- a/pyrefly/lib/test/constructors.rs +++ b/pyrefly/lib/test/constructors.rs @@ -989,6 +989,31 @@ C("5") # E: Argument `Literal['5']` is not assignable to parameter `x` with typ "#, ); +// Regression test for a problem in networkx: https://github.com/facebook/pyrefly/issues/3121 +testcase!( + test_return_type_inference_for_constructors, + r#" +from typing import assert_type + +class A: + def __new__(cls, x: int | None = None): + if x is None: + return cls.__new__(cls, 5) + else: + return object.__new__(cls) + + def __init__(cls): + return "x" + +class B(A): ... + +a = A() +assert_type(a, A) +b = B() +assert_type(b, B) +"#, +); + testcase!( test_redundant_dict_constructor_call_ok, r#"