diff --git a/pyrefly/lib/alt/class/class_field.rs b/pyrefly/lib/alt/class/class_field.rs index 1194c8dcdb..a973860697 100644 --- a/pyrefly/lib/alt/class/class_field.rs +++ b/pyrefly/lib/alt/class/class_field.rs @@ -1562,10 +1562,10 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { ) } ClassFieldDefinition::DefinedInMethod { - value, - method, + values, + method: _, annotation: annot, - .. + receiver_kind, } => { let direct_annotation = annot.map(|a| self.get_idx(a).annotation.clone()); // Check if there's an inherited property or descriptor field from a parent class. @@ -1597,20 +1597,39 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { } } - let initialization = match method.instance_or_class { + let initialization = match receiver_kind { MethodSelfKind::Class => ClassFieldInitialization::ClassMethod, MethodSelfKind::Instance => ClassFieldInitialization::Method, }; - let (mut value_ty, annotation, is_inherited) = self.analyze_class_field_value( - value, - class, - name, - direct_annotation.as_ref(), - true, - range, - errors, - ); - if matches!(method.instance_or_class, MethodSelfKind::Instance) { + + // Solve the type of each collected assignment and union them. + let mut union_types = Vec::new(); + let mut overall_annotation = None; + let mut overall_is_inherited = IsInherited::No; + + for value in values { + let (value_ty, annotation, is_inherited) = self.analyze_class_field_value( + value, + class, + name, + direct_annotation.as_ref(), + true, + range, + errors, + ); + union_types.push(value_ty); + if overall_annotation.is_none() { + overall_annotation = annotation; + } + if matches!(is_inherited, IsInherited::Maybe) { + overall_is_inherited = IsInherited::Maybe; + } + } + + // Union all the resolved assignment types. + let mut value_ty = unions(union_types, self.heap); + + if matches!(receiver_kind, MethodSelfKind::Instance) { value_ty = self .check_and_sanitize_type_parameters(class, value_ty, name, range, errors); } @@ -1618,8 +1637,8 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { initialization, false, value_ty, - annotation, - is_inherited, + overall_annotation, + overall_is_inherited, direct_annotation, ) } diff --git a/pyrefly/lib/binding/binding.rs b/pyrefly/lib/binding/binding.rs index 03fdf96c76..4485270453 100644 --- a/pyrefly/lib/binding/binding.rs +++ b/pyrefly/lib/binding/binding.rs @@ -129,7 +129,7 @@ assert_bytes!(BindingClassChecks, 4); assert_bytes!(BindingClassDisjointBase, 4); assert_bytes!(BindingAbstractClassCheck, 4); assert_bytes!(BindingClassSubscriptSymmetry, 4); -assert_words!(BindingClassField, 11); +assert_words!(BindingClassField, 13); assert_bytes!(BindingClassSynthesizedFields, 4); assert_bytes!(BindingLegacyTypeParam, 16); assert_words!(BindingYield, 4); @@ -3028,9 +3028,16 @@ pub enum ClassFieldDefinition { /// Implicitly defined in a method, without any explicit reference /// in the class body. DefinedInMethod { - value: Box, + values: Vec, annotation: Option>, + // Refers to the first method in which the attribute was assigned. + // We prioritize recognized constructors over normal methods; if there are multiple + // constructors, this refers to the first constructor processed. method: MethodThatSetsAttr, + // The combined receiver kind of this class field (upgraded to Class if any constructor + // assigns to it via `cls.`). We track this separately from `method` to avoid mutating + // the method descriptor's own metadata (e.g. keeping `__init__` as an instance method). + receiver_kind: MethodSelfKind, }, } @@ -3079,12 +3086,15 @@ impl DisplayWith for ClassFieldDefinition { ctx.display(*definition), ) } - Self::DefinedInMethod { value, .. } => { - write!( - f, - "ClassFieldDefinition::DefinedInMethod({}, ..)", - value.display_with(ctx), - ) + Self::DefinedInMethod { values, .. } => { + write!(f, "ClassFieldDefinition::DefinedInMethod([")?; + for (i, v) in values.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", v.display_with(ctx))?; + } + write!(f, "], ..)") } } } diff --git a/pyrefly/lib/binding/function.rs b/pyrefly/lib/binding/function.rs index 962ba035bb..ee6e8d0caf 100644 --- a/pyrefly/lib/binding/function.rs +++ b/pyrefly/lib/binding/function.rs @@ -181,7 +181,7 @@ impl<'a> SelfAttrNames<'a> { ( n, InstanceAttribute( - ExprOrBinding::Binding(Binding::Any(AnyStyle::Implicit)), + vec![ExprOrBinding::Binding(Binding::Any(AnyStyle::Implicit))], None, r, MethodSelfKind::Instance, diff --git a/pyrefly/lib/binding/scope.rs b/pyrefly/lib/binding/scope.rs index 807eeb39a4..38c3dd36b3 100644 --- a/pyrefly/lib/binding/scope.rs +++ b/pyrefly/lib/binding/scope.rs @@ -1018,10 +1018,10 @@ impl ScopeClass { /// Produces triples (hashed_attr_name, MethodThatSetsAttr, attribute) for all assignments /// to `self.` in methods. /// - /// We iterate recognized methods first, which - assuming that the first result is the one - /// used in our class logic, which is the case - ensures both that we don't produce - /// unnecessary errors about attributes implicitly defined in unrecognized methods - /// and that the types inferred from recognized methods take precedence. + /// We iterate recognized methods first, which ensures constructor prioritization is + /// established before unrecognized helper methods are processed. This ensures both + /// that we don't produce unnecessary errors about attributes implicitly defined in + /// unrecognized methods, and that constructors take precedence. pub fn method_defined_attributes( self, ) -> impl Iterator, MethodThatSetsAttr, InstanceAttribute)> { @@ -1090,7 +1090,7 @@ pub struct YieldsAndReturns { #[derive(Clone, Debug)] pub struct InstanceAttribute( - pub ExprOrBinding, + pub Vec, pub Option>, pub TextRange, pub MethodSelfKind, @@ -1967,6 +1967,8 @@ impl Scopes { /// (like constructors) that we recognize as always being called. /// /// Returns `true` if the attribute was a self attribute. + /// Record a self attribute assignment (e.g., `self.x = value`) inside the current method scope. + /// We accumulate all assignments to the same attribute within the method so they can later be unioned. pub fn record_self_attr_assign( &mut self, x: &ExprAttribute, @@ -1978,13 +1980,21 @@ impl Scopes { && let Some(self_name) = &method_scope.self_name && matches!(&*x.value, Expr::Name(name) if name.id == self_name.id) { - if !method_scope.instance_attributes.contains_key(&x.attr.id) { + if let Some(attr) = method_scope.instance_attributes.get_mut(&x.attr.id) { + // Accumulate subsequent assignments in the method. + attr.0.push(value); + // Keep the first type annotation encountered in the method. + if attr.1.is_none() { + attr.1 = annotation; + } + } else { + // First time seeing this attribute in this method: record it. method_scope.instance_attributes.insert( x.attr.id.clone(), InstanceAttribute( - value, + vec![value], annotation, - x.attr.range(), + x.attr.range(), // Keep the range of the first assignment as the definition location. method_scope.receiver_kind, ), ); @@ -2864,16 +2874,48 @@ impl Scopes { field_definitions.insert_hashed(name.owned(), (definition, static_info.range)); } }); + // Merge assignments from different methods. + // `method_attrs` yields attributes from recognized constructor methods first (e.g. __init__), + // followed by other helper methods. method_attrs.into_iter().for_each( - |(name, method, InstanceAttribute(value, annotation, range, _))| { - if !field_definitions.contains_key_hashed(name.as_ref()) { + |(name, method, InstanceAttribute(values, annotation, range, receiver_kind))| { + if let Some(( + ClassFieldDefinition::DefinedInMethod { + values: existing_values, + annotation: existing_annot, + method: existing_method, + receiver_kind: existing_receiver, + }, + _, + )) = field_definitions.get_mut(name.key()) + { + if existing_method.recognized_attribute_defining_method + && !method.recognized_attribute_defining_method + { + // Prioritization: Existing is from a recognized constructor, new is from an + // unrecognized helper method. The constructor wins, so ignore the new assignment. + } else { + // Merge: Either both are constructors (e.g. __new__ and __init__), or both are + // helper methods. We combine all their assignments. + existing_values.extend(values); + if existing_annot.is_none() { + *existing_annot = annotation; + } + // If any constructor is a class method (e.g. __new__), the attribute is visible + // on the class object. Upgrade the receiver kind to Class. + if matches!(receiver_kind, MethodSelfKind::Class) { + *existing_receiver = MethodSelfKind::Class; + } + } + } else if !field_definitions.contains_key_hashed(name.as_ref()) { field_definitions.insert_hashed( name, ( ClassFieldDefinition::DefinedInMethod { - value: Box::new(value), + values, annotation, method, + receiver_kind, }, range, ), diff --git a/pyrefly/lib/lsp/wasm/inlay_hints.rs b/pyrefly/lib/lsp/wasm/inlay_hints.rs index e128941454..705c0236ed 100644 --- a/pyrefly/lib/lsp/wasm/inlay_hints.rs +++ b/pyrefly/lib/lsp/wasm/inlay_hints.rs @@ -244,7 +244,7 @@ impl<'a> Transaction<'a> { for field_idx in bindings.keys::() { let field = bindings.get(field_idx); if let ClassFieldDefinition::DefinedInMethod { - value, + values, annotation: None, .. } = &field.definition @@ -253,7 +253,10 @@ impl<'a> Transaction<'a> { continue; }; let ty = answers.solver().for_display(class_field.ty()); - let expr = match value.as_ref() { + let expr = match values + .first() + .expect("DefinedInMethod must have at least one value") + { ExprOrBinding::Expr(e) => Some(e), ExprOrBinding::Binding(_) => None, }; diff --git a/pyrefly/lib/query.rs b/pyrefly/lib/query.rs index 4f95dfbd7a..a8b3a71e3b 100644 --- a/pyrefly/lib/query.rs +++ b/pyrefly/lib/query.rs @@ -1706,9 +1706,6 @@ impl Query { value, annotation, alias_of: _, - } - | ClassFieldDefinition::DefinedInMethod { - value, annotation, .. } => { annotation .and_then(|idx| answers.get_idx(idx)) @@ -1724,6 +1721,23 @@ impl Query { // Final fallback: ClassField.ty() .or_else(|| answers.get_idx(class_field_idx).map(|cf| cf.ty())) } + ClassFieldDefinition::DefinedInMethod { + values, annotation, .. + } => { + annotation + .and_then(|idx| answers.get_idx(idx)) + .and_then(|a| a.annotation.ty.clone()) + // Fall back to expression type trace + .or_else(|| { + if let Some(ExprOrBinding::Expr(expr)) = values.first() { + answers.get_type_trace(expr.range()) + } else { + None + } + }) + // Final fallback: ClassField.ty() + .or_else(|| answers.get_idx(class_field_idx).map(|cf| cf.ty())) + } _ => answers.get_idx(class_field_idx).map(|cf| cf.ty()), }; let field_ty = field_ty?; diff --git a/pyrefly/lib/test/attributes.rs b/pyrefly/lib/test/attributes.rs index fd5cc8d94b..7668a31eb5 100644 --- a/pyrefly/lib/test/attributes.rs +++ b/pyrefly/lib/test/attributes.rs @@ -2664,3 +2664,215 @@ class C[T: (A, B)]: return self.f()[""][0] "#, ); + +testcase!( + test_multiple_assignments_in_same_method, + r#" +from typing import assert_type +class A: + def __init__(self): + self.val = "string" + self.val = 1 +def f(a: A): + assert_type(a.val, str | int) + +class B: + def __init__(self, x: bool): + if x: + self.val = "string" + else: + self.val = 1 +def f2(b: B): + assert_type(b.val, str | int) + "#, +); + +testcase!( + test_assignments_across_multiple_recognized_methods, + r#" +from typing import assert_type +class A: + def __init__(self): + self.val = "string" + def __post_init__(self): + self.val = 1 +def f(a: A): + assert_type(a.val, str | int) + "#, +); + +testcase!( + test_class_level_attribute_priority, + r#" +from typing import assert_type +class A: + val: int = 1 + def __init__(self): + self.val = "string" # E: `Literal['string']` is not assignable to attribute `val` with type `int` +def f(a: A): + assert_type(a.val, int) + +class B: + val = 1 + def __init__(self): + self.val = "string" # E: `Literal['string']` is not assignable to attribute `val` with type `int` +def f2(b: B): + assert_type(b.val, int) + "#, +); + +testcase!( + test_recognized_vs_unrecognized_methods, + r#" +from typing import assert_type, reveal_type +class A: + def __init__(self): + self.val = 1 + def do_work(self): + self.val = "string" # E: `Literal['string']` is not assignable to attribute `val` with type `int` +def f(a: A): + assert_type(a.val, int) + +class B: + def __init__(self): + self.val = None # W: This expression is implicitly inferred to be `Any | None`. Please provide an explicit type annotation. + def do_work(self): + self.val = "string" +def f2(b: B): + reveal_type(b.val) # E: revealed type: Unknown | None + "#, +); + +testcase!( + test_explicit_annotation_override, + r#" +from typing import assert_type + +class A: + def __init__(self): + self.val = "string" # E: `Literal['string']` is not assignable to attribute `val` with type `int` + self.val: int = 1 +def f(a: A): + assert_type(a.val, int) + +class B: + def __init__(self): + self.val: int = 1 + self.val = "string" # E: `Literal['string']` is not assignable to attribute `val` with type `int` +def f2(b: B): + assert_type(b.val, int) + +class C: + def __init__(self, cond: bool): + self.val: int = 1 + if cond: + self.val = "string" # E: `Literal['string']` is not assignable to attribute `val` with type `int` +def f3(c: C): + assert_type(c.val, int) + +class D: + def __init__(self, cond: bool): + self.val = "string" # E: `Literal['string']` is not assignable to attribute `val` with type `int` + if cond: + self.val: int = 1 +def f4(d: D): + assert_type(d.val, int) + +class E: + def __init__(self, cond: bool): + self.val = None # E: `None` is not assignable to attribute `val` with type `int` + if cond: + self.val: int = 1 +def f5(e: E): + assert_type(e.val, int) + "#, +); + +testcase!( + test_none_union_conditional_initialization, + r#" +from typing import reveal_type + +class A: + def __init__(self, x: None | int): + self._x = None # W: This expression is implicitly inferred to be `Any | None`. Please provide an explicit type annotation. + if x: + self._x = x +def f(a: A): + reveal_type(a._x) # E: revealed type: int | Unknown | None + +class B: + def __init__(self, x: None | int): + if x: + self._x = x + else: + self._x = x +def f2(b: B): + reveal_type(b._x) # E: revealed type: int | None + +class C: + def __init__(self, x: None | int): + if not x: + self._y = x + else: + self._y = x +def f3(c: C): + reveal_type(c._y) # E: revealed type: int | None + "#, +); + +testcase!( + test_class_method_receiver_union, + r#" +from typing import reveal_type, assert_type, Literal + +class A: + def __new__(cls): + cls.val = 1 + return super().__new__(cls) + def __init__(self): + self.val = "string" +def f(a: A): + assert_type(a.val, Literal[1, 'string']) + reveal_type(A.val) # E: revealed type: Literal['string', 1] + +class B: + def __init__(self): + self.val = "string" + def __new__(cls): + cls.val = 1 + return super().__new__(cls) +def g(b: B): + assert_type(b.val, Literal['string', 1]) + reveal_type(B.val) # E: revealed type: Literal['string', 1] + "#, +); + +testcase!( + test_class_body_with_method_definition, + r#" +from typing import reveal_type, assert_type + +class A: + val = 0 + def __new__(cls): + cls.val = 1 + return super().__new__(cls) + def __init__(self): + self.val = "string" # E: `Literal['string']` is not assignable to attribute `val` with type `int` +def f(a: A): + reveal_type(a.val) # E: revealed type: int + reveal_type(A.val) # E: revealed type: int + +class B: + val: int | str = 0 + def __new__(cls): + cls.val = 1 + return super().__new__(cls) + def __init__(self): + self.val = "string" +def g(b: B): + reveal_type(b.val) # E: revealed type: int | str + reveal_type(B.val) # E: revealed type: int | str + "#, +);