From cbba57f6fac2a44e01374e9d42ad794c43527720 Mon Sep 17 00:00:00 2001 From: Amer Elsheikh Date: Tue, 16 Jun 2026 14:59:26 +0000 Subject: [PATCH] Fix unannotated attribute type inference bug (issue #1159) Previously, Pyrefly incorrectly locked onto the first textual assignment to an unannotated attribute inside a constructor, leading to false-positive bad-assignment errors if the attribute was assigned different types across different control flow branches or initialization steps. This change modifies Pyrefly's class scope compiler to collect all assignments to an attribute inside constructors (and other recognized setup methods), and then unions their types in the solver. Type annotations, if present, still strictly override and enforce the declared type. --- pyrefly/lib/alt/class/class_field.rs | 51 +++++-- pyrefly/lib/binding/binding.rs | 26 +++- pyrefly/lib/binding/function.rs | 2 +- pyrefly/lib/binding/scope.rs | 64 ++++++-- pyrefly/lib/lsp/wasm/inlay_hints.rs | 7 +- pyrefly/lib/query.rs | 20 ++- pyrefly/lib/test/attributes.rs | 212 +++++++++++++++++++++++++++ 7 files changed, 341 insertions(+), 41 deletions(-) diff --git a/pyrefly/lib/alt/class/class_field.rs b/pyrefly/lib/alt/class/class_field.rs index 1194c8dcdb..a973860697 100644 --- a/pyrefly/lib/alt/class/class_field.rs +++ b/pyrefly/lib/alt/class/class_field.rs @@ -1562,10 +1562,10 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { ) } ClassFieldDefinition::DefinedInMethod { - value, - method, + values, + method: _, annotation: annot, - .. + receiver_kind, } => { let direct_annotation = annot.map(|a| self.get_idx(a).annotation.clone()); // Check if there's an inherited property or descriptor field from a parent class. @@ -1597,20 +1597,39 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { } } - let initialization = match method.instance_or_class { + let initialization = match receiver_kind { MethodSelfKind::Class => ClassFieldInitialization::ClassMethod, MethodSelfKind::Instance => ClassFieldInitialization::Method, }; - let (mut value_ty, annotation, is_inherited) = self.analyze_class_field_value( - value, - class, - name, - direct_annotation.as_ref(), - true, - range, - errors, - ); - if matches!(method.instance_or_class, MethodSelfKind::Instance) { + + // Solve the type of each collected assignment and union them. + let mut union_types = Vec::new(); + let mut overall_annotation = None; + let mut overall_is_inherited = IsInherited::No; + + for value in values { + let (value_ty, annotation, is_inherited) = self.analyze_class_field_value( + value, + class, + name, + direct_annotation.as_ref(), + true, + range, + errors, + ); + union_types.push(value_ty); + if overall_annotation.is_none() { + overall_annotation = annotation; + } + if matches!(is_inherited, IsInherited::Maybe) { + overall_is_inherited = IsInherited::Maybe; + } + } + + // Union all the resolved assignment types. + let mut value_ty = unions(union_types, self.heap); + + if matches!(receiver_kind, MethodSelfKind::Instance) { value_ty = self .check_and_sanitize_type_parameters(class, value_ty, name, range, errors); } @@ -1618,8 +1637,8 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { initialization, false, value_ty, - annotation, - is_inherited, + overall_annotation, + overall_is_inherited, direct_annotation, ) } diff --git a/pyrefly/lib/binding/binding.rs b/pyrefly/lib/binding/binding.rs index 03fdf96c76..4485270453 100644 --- a/pyrefly/lib/binding/binding.rs +++ b/pyrefly/lib/binding/binding.rs @@ -129,7 +129,7 @@ assert_bytes!(BindingClassChecks, 4); assert_bytes!(BindingClassDisjointBase, 4); assert_bytes!(BindingAbstractClassCheck, 4); assert_bytes!(BindingClassSubscriptSymmetry, 4); -assert_words!(BindingClassField, 11); +assert_words!(BindingClassField, 13); assert_bytes!(BindingClassSynthesizedFields, 4); assert_bytes!(BindingLegacyTypeParam, 16); assert_words!(BindingYield, 4); @@ -3028,9 +3028,16 @@ pub enum ClassFieldDefinition { /// Implicitly defined in a method, without any explicit reference /// in the class body. DefinedInMethod { - value: Box, + values: Vec, annotation: Option>, + // Refers to the first method in which the attribute was assigned. + // We prioritize recognized constructors over normal methods; if there are multiple + // constructors, this refers to the first constructor processed. method: MethodThatSetsAttr, + // The combined receiver kind of this class field (upgraded to Class if any constructor + // assigns to it via `cls.`). We track this separately from `method` to avoid mutating + // the method descriptor's own metadata (e.g. keeping `__init__` as an instance method). + receiver_kind: MethodSelfKind, }, } @@ -3079,12 +3086,15 @@ impl DisplayWith for ClassFieldDefinition { ctx.display(*definition), ) } - Self::DefinedInMethod { value, .. } => { - write!( - f, - "ClassFieldDefinition::DefinedInMethod({}, ..)", - value.display_with(ctx), - ) + Self::DefinedInMethod { values, .. } => { + write!(f, "ClassFieldDefinition::DefinedInMethod([")?; + for (i, v) in values.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", v.display_with(ctx))?; + } + write!(f, "], ..)") } } } diff --git a/pyrefly/lib/binding/function.rs b/pyrefly/lib/binding/function.rs index 962ba035bb..ee6e8d0caf 100644 --- a/pyrefly/lib/binding/function.rs +++ b/pyrefly/lib/binding/function.rs @@ -181,7 +181,7 @@ impl<'a> SelfAttrNames<'a> { ( n, InstanceAttribute( - ExprOrBinding::Binding(Binding::Any(AnyStyle::Implicit)), + vec![ExprOrBinding::Binding(Binding::Any(AnyStyle::Implicit))], None, r, MethodSelfKind::Instance, diff --git a/pyrefly/lib/binding/scope.rs b/pyrefly/lib/binding/scope.rs index 807eeb39a4..38c3dd36b3 100644 --- a/pyrefly/lib/binding/scope.rs +++ b/pyrefly/lib/binding/scope.rs @@ -1018,10 +1018,10 @@ impl ScopeClass { /// Produces triples (hashed_attr_name, MethodThatSetsAttr, attribute) for all assignments /// to `self.` in methods. /// - /// We iterate recognized methods first, which - assuming that the first result is the one - /// used in our class logic, which is the case - ensures both that we don't produce - /// unnecessary errors about attributes implicitly defined in unrecognized methods - /// and that the types inferred from recognized methods take precedence. + /// We iterate recognized methods first, which ensures constructor prioritization is + /// established before unrecognized helper methods are processed. This ensures both + /// that we don't produce unnecessary errors about attributes implicitly defined in + /// unrecognized methods, and that constructors take precedence. pub fn method_defined_attributes( self, ) -> impl Iterator, MethodThatSetsAttr, InstanceAttribute)> { @@ -1090,7 +1090,7 @@ pub struct YieldsAndReturns { #[derive(Clone, Debug)] pub struct InstanceAttribute( - pub ExprOrBinding, + pub Vec, pub Option>, pub TextRange, pub MethodSelfKind, @@ -1967,6 +1967,8 @@ impl Scopes { /// (like constructors) that we recognize as always being called. /// /// Returns `true` if the attribute was a self attribute. + /// Record a self attribute assignment (e.g., `self.x = value`) inside the current method scope. + /// We accumulate all assignments to the same attribute within the method so they can later be unioned. pub fn record_self_attr_assign( &mut self, x: &ExprAttribute, @@ -1978,13 +1980,21 @@ impl Scopes { && let Some(self_name) = &method_scope.self_name && matches!(&*x.value, Expr::Name(name) if name.id == self_name.id) { - if !method_scope.instance_attributes.contains_key(&x.attr.id) { + if let Some(attr) = method_scope.instance_attributes.get_mut(&x.attr.id) { + // Accumulate subsequent assignments in the method. + attr.0.push(value); + // Keep the first type annotation encountered in the method. + if attr.1.is_none() { + attr.1 = annotation; + } + } else { + // First time seeing this attribute in this method: record it. method_scope.instance_attributes.insert( x.attr.id.clone(), InstanceAttribute( - value, + vec![value], annotation, - x.attr.range(), + x.attr.range(), // Keep the range of the first assignment as the definition location. method_scope.receiver_kind, ), ); @@ -2864,16 +2874,48 @@ impl Scopes { field_definitions.insert_hashed(name.owned(), (definition, static_info.range)); } }); + // Merge assignments from different methods. + // `method_attrs` yields attributes from recognized constructor methods first (e.g. __init__), + // followed by other helper methods. method_attrs.into_iter().for_each( - |(name, method, InstanceAttribute(value, annotation, range, _))| { - if !field_definitions.contains_key_hashed(name.as_ref()) { + |(name, method, InstanceAttribute(values, annotation, range, receiver_kind))| { + if let Some(( + ClassFieldDefinition::DefinedInMethod { + values: existing_values, + annotation: existing_annot, + method: existing_method, + receiver_kind: existing_receiver, + }, + _, + )) = field_definitions.get_mut(name.key()) + { + if existing_method.recognized_attribute_defining_method + && !method.recognized_attribute_defining_method + { + // Prioritization: Existing is from a recognized constructor, new is from an + // unrecognized helper method. The constructor wins, so ignore the new assignment. + } else { + // Merge: Either both are constructors (e.g. __new__ and __init__), or both are + // helper methods. We combine all their assignments. + existing_values.extend(values); + if existing_annot.is_none() { + *existing_annot = annotation; + } + // If any constructor is a class method (e.g. __new__), the attribute is visible + // on the class object. Upgrade the receiver kind to Class. + if matches!(receiver_kind, MethodSelfKind::Class) { + *existing_receiver = MethodSelfKind::Class; + } + } + } else if !field_definitions.contains_key_hashed(name.as_ref()) { field_definitions.insert_hashed( name, ( ClassFieldDefinition::DefinedInMethod { - value: Box::new(value), + values, annotation, method, + receiver_kind, }, range, ), diff --git a/pyrefly/lib/lsp/wasm/inlay_hints.rs b/pyrefly/lib/lsp/wasm/inlay_hints.rs index e128941454..705c0236ed 100644 --- a/pyrefly/lib/lsp/wasm/inlay_hints.rs +++ b/pyrefly/lib/lsp/wasm/inlay_hints.rs @@ -244,7 +244,7 @@ impl<'a> Transaction<'a> { for field_idx in bindings.keys::() { let field = bindings.get(field_idx); if let ClassFieldDefinition::DefinedInMethod { - value, + values, annotation: None, .. } = &field.definition @@ -253,7 +253,10 @@ impl<'a> Transaction<'a> { continue; }; let ty = answers.solver().for_display(class_field.ty()); - let expr = match value.as_ref() { + let expr = match values + .first() + .expect("DefinedInMethod must have at least one value") + { ExprOrBinding::Expr(e) => Some(e), ExprOrBinding::Binding(_) => None, }; diff --git a/pyrefly/lib/query.rs b/pyrefly/lib/query.rs index 4f95dfbd7a..a8b3a71e3b 100644 --- a/pyrefly/lib/query.rs +++ b/pyrefly/lib/query.rs @@ -1706,9 +1706,6 @@ impl Query { value, annotation, alias_of: _, - } - | ClassFieldDefinition::DefinedInMethod { - value, annotation, .. } => { annotation .and_then(|idx| answers.get_idx(idx)) @@ -1724,6 +1721,23 @@ impl Query { // Final fallback: ClassField.ty() .or_else(|| answers.get_idx(class_field_idx).map(|cf| cf.ty())) } + ClassFieldDefinition::DefinedInMethod { + values, annotation, .. + } => { + annotation + .and_then(|idx| answers.get_idx(idx)) + .and_then(|a| a.annotation.ty.clone()) + // Fall back to expression type trace + .or_else(|| { + if let Some(ExprOrBinding::Expr(expr)) = values.first() { + answers.get_type_trace(expr.range()) + } else { + None + } + }) + // Final fallback: ClassField.ty() + .or_else(|| answers.get_idx(class_field_idx).map(|cf| cf.ty())) + } _ => answers.get_idx(class_field_idx).map(|cf| cf.ty()), }; let field_ty = field_ty?; diff --git a/pyrefly/lib/test/attributes.rs b/pyrefly/lib/test/attributes.rs index fd5cc8d94b..7668a31eb5 100644 --- a/pyrefly/lib/test/attributes.rs +++ b/pyrefly/lib/test/attributes.rs @@ -2664,3 +2664,215 @@ class C[T: (A, B)]: return self.f()[""][0] "#, ); + +testcase!( + test_multiple_assignments_in_same_method, + r#" +from typing import assert_type +class A: + def __init__(self): + self.val = "string" + self.val = 1 +def f(a: A): + assert_type(a.val, str | int) + +class B: + def __init__(self, x: bool): + if x: + self.val = "string" + else: + self.val = 1 +def f2(b: B): + assert_type(b.val, str | int) + "#, +); + +testcase!( + test_assignments_across_multiple_recognized_methods, + r#" +from typing import assert_type +class A: + def __init__(self): + self.val = "string" + def __post_init__(self): + self.val = 1 +def f(a: A): + assert_type(a.val, str | int) + "#, +); + +testcase!( + test_class_level_attribute_priority, + r#" +from typing import assert_type +class A: + val: int = 1 + def __init__(self): + self.val = "string" # E: `Literal['string']` is not assignable to attribute `val` with type `int` +def f(a: A): + assert_type(a.val, int) + +class B: + val = 1 + def __init__(self): + self.val = "string" # E: `Literal['string']` is not assignable to attribute `val` with type `int` +def f2(b: B): + assert_type(b.val, int) + "#, +); + +testcase!( + test_recognized_vs_unrecognized_methods, + r#" +from typing import assert_type, reveal_type +class A: + def __init__(self): + self.val = 1 + def do_work(self): + self.val = "string" # E: `Literal['string']` is not assignable to attribute `val` with type `int` +def f(a: A): + assert_type(a.val, int) + +class B: + def __init__(self): + self.val = None # W: This expression is implicitly inferred to be `Any | None`. Please provide an explicit type annotation. + def do_work(self): + self.val = "string" +def f2(b: B): + reveal_type(b.val) # E: revealed type: Unknown | None + "#, +); + +testcase!( + test_explicit_annotation_override, + r#" +from typing import assert_type + +class A: + def __init__(self): + self.val = "string" # E: `Literal['string']` is not assignable to attribute `val` with type `int` + self.val: int = 1 +def f(a: A): + assert_type(a.val, int) + +class B: + def __init__(self): + self.val: int = 1 + self.val = "string" # E: `Literal['string']` is not assignable to attribute `val` with type `int` +def f2(b: B): + assert_type(b.val, int) + +class C: + def __init__(self, cond: bool): + self.val: int = 1 + if cond: + self.val = "string" # E: `Literal['string']` is not assignable to attribute `val` with type `int` +def f3(c: C): + assert_type(c.val, int) + +class D: + def __init__(self, cond: bool): + self.val = "string" # E: `Literal['string']` is not assignable to attribute `val` with type `int` + if cond: + self.val: int = 1 +def f4(d: D): + assert_type(d.val, int) + +class E: + def __init__(self, cond: bool): + self.val = None # E: `None` is not assignable to attribute `val` with type `int` + if cond: + self.val: int = 1 +def f5(e: E): + assert_type(e.val, int) + "#, +); + +testcase!( + test_none_union_conditional_initialization, + r#" +from typing import reveal_type + +class A: + def __init__(self, x: None | int): + self._x = None # W: This expression is implicitly inferred to be `Any | None`. Please provide an explicit type annotation. + if x: + self._x = x +def f(a: A): + reveal_type(a._x) # E: revealed type: int | Unknown | None + +class B: + def __init__(self, x: None | int): + if x: + self._x = x + else: + self._x = x +def f2(b: B): + reveal_type(b._x) # E: revealed type: int | None + +class C: + def __init__(self, x: None | int): + if not x: + self._y = x + else: + self._y = x +def f3(c: C): + reveal_type(c._y) # E: revealed type: int | None + "#, +); + +testcase!( + test_class_method_receiver_union, + r#" +from typing import reveal_type, assert_type, Literal + +class A: + def __new__(cls): + cls.val = 1 + return super().__new__(cls) + def __init__(self): + self.val = "string" +def f(a: A): + assert_type(a.val, Literal[1, 'string']) + reveal_type(A.val) # E: revealed type: Literal['string', 1] + +class B: + def __init__(self): + self.val = "string" + def __new__(cls): + cls.val = 1 + return super().__new__(cls) +def g(b: B): + assert_type(b.val, Literal['string', 1]) + reveal_type(B.val) # E: revealed type: Literal['string', 1] + "#, +); + +testcase!( + test_class_body_with_method_definition, + r#" +from typing import reveal_type, assert_type + +class A: + val = 0 + def __new__(cls): + cls.val = 1 + return super().__new__(cls) + def __init__(self): + self.val = "string" # E: `Literal['string']` is not assignable to attribute `val` with type `int` +def f(a: A): + reveal_type(a.val) # E: revealed type: int + reveal_type(A.val) # E: revealed type: int + +class B: + val: int | str = 0 + def __new__(cls): + cls.val = 1 + return super().__new__(cls) + def __init__(self): + self.val = "string" +def g(b: B): + reveal_type(b.val) # E: revealed type: int | str + reveal_type(B.val) # E: revealed type: int | str + "#, +);