diff --git a/pyrefly/lib/alt/answers_solver.rs b/pyrefly/lib/alt/answers_solver.rs index 02a770d0b7..ba0deac009 100644 --- a/pyrefly/lib/alt/answers_solver.rs +++ b/pyrefly/lib/alt/answers_solver.rs @@ -72,6 +72,7 @@ use crate::dispatch_anyidx; use crate::error::collector::ErrorCollector; use crate::error::context::ErrorInfo; use crate::error::context::TypeCheckContext; +use crate::error::context::TypeCheckKind; use crate::error::style::ErrorStyle; use crate::export::exports::LookupExport; use crate::module::module_info::ModuleInfo; @@ -3032,7 +3033,26 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { match self.is_subset_eq_with_reason(got, want) { Ok(()) => true, Err(error) => { - self.solver().error(got, want, errors, loc, tcc, error); + let mut extra_lines = Vec::new(); + if matches!( + tcc().kind, + TypeCheckKind::CallArgument(..) + | TypeCheckKind::CallVarArgs(..) + | TypeCheckKind::CallKwArgs(..) + | TypeCheckKind::CallUnpackKwArg(..) + ) && let Some(suggestion) = self.suggest_enum_member_for_value(want, got) + { + extra_lines.push(format!("Did you mean `{suggestion}`?")); + } + self.solver().error_with_extra_lines( + got, + want, + errors, + loc, + tcc, + error, + extra_lines, + ); false } } diff --git a/pyrefly/lib/alt/class/enums.rs b/pyrefly/lib/alt/class/enums.rs index 64c0110f79..9d23625c46 100644 --- a/pyrefly/lib/alt/class/enums.rs +++ b/pyrefly/lib/alt/class/enums.rs @@ -44,6 +44,40 @@ pub const VALUE_PROP: Name = Name::new_static("value"); pub const GENERATE_NEXT_VALUE: Name = Name::new_static("_generate_next_value_"); impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { + /// Suggest an enum member when a raw literal matches exactly one expected enum value. + pub fn suggest_enum_member_for_value(&self, want: &Type, got: &Type) -> Option { + match want { + Type::ClassType(cls) => { + self.suggest_enum_member_for_class_value(cls.class_object(), got) + } + Type::SelfType(cls) => { + self.suggest_enum_member_for_class_value(cls.class_object(), got) + } + Type::Literal(lit) => match &lit.value { + Lit::Enum(lit_enum) => { + self.suggest_enum_member_for_class_value(lit_enum.class.class_object(), got) + } + _ => None, + }, + Type::Union(box crate::types::types::Union { members, .. }) => { + let mut suggestion = None; + for member in members { + if let Some(candidate) = self.suggest_enum_member_for_value(member, got) { + if suggestion + .as_ref() + .is_some_and(|existing: &String| existing != &candidate) + { + return None; + } + suggestion = Some(candidate); + } + } + suggestion + } + _ => None, + } + } + pub fn get_enum_member(&self, cls: &Class, name: &Name) -> Option { self.get_field_from_current_class_only(cls, name) .and_then(|field| self.as_enum_member(Arc::unwrap_or_clone(field), cls)) @@ -60,6 +94,28 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { .unwrap_or_default() } + fn suggest_enum_member_for_class_value(&self, cls: &Class, got: &Type) -> Option { + let is_django = self.get_metadata_for_class(cls).enum_metadata()?.is_django; + let mut suggestion = None; + for lit in self.get_enum_members(cls) { + let Lit::Enum(lit_enum) = lit else { + unreachable!("enum members must be represented as enum literals"); + }; + let value_ty = self.enum_literal_to_value_type((*lit_enum).clone(), is_django); + if self.is_subset_eq(got, &value_ty) && self.is_subset_eq(&value_ty, got) { + let candidate = format!("{}.{}", lit_enum.class.name(), lit_enum.member); + if suggestion + .as_ref() + .is_some_and(|existing: &String| existing != &candidate) + { + return None; + } + suggestion = Some(candidate); + } + } + suggestion + } + fn is_valid_enum_member( &self, name: &Name, diff --git a/pyrefly/lib/solver/solver.rs b/pyrefly/lib/solver/solver.rs index c9498d7602..1333f2ea1b 100644 --- a/pyrefly/lib/solver/solver.rs +++ b/pyrefly/lib/solver/solver.rs @@ -1384,6 +1384,19 @@ impl Solver { loc: TextRange, tcc: &dyn Fn() -> TypeCheckContext, subset_error: SubsetError, + ) { + self.error_with_extra_lines(got, want, errors, loc, tcc, subset_error, Vec::new()); + } + + pub fn error_with_extra_lines( + &self, + got: &Type, + want: &Type, + errors: &ErrorCollector, + loc: TextRange, + tcc: &dyn Fn() -> TypeCheckContext, + subset_error: SubsetError, + extra_lines: Vec, ) { let tcc = tcc(); let msg = tcc.kind.format_error( @@ -1392,6 +1405,9 @@ impl Solver { errors.module().name(), ); let mut msg_lines = vec1![msg]; + for line in extra_lines { + msg_lines.push(line); + } if let Some(subset_error_msg) = subset_error.to_error_msg() { msg_lines.push(subset_error_msg); } diff --git a/pyrefly/lib/test/enums.rs b/pyrefly/lib/test/enums.rs index 671b468696..599020661d 100644 --- a/pyrefly/lib/test/enums.rs +++ b/pyrefly/lib/test/enums.rs @@ -173,6 +173,37 @@ for e in E3: "#, ); +// Regression test for https://github.com/facebook/pyrefly/issues/3128 +testcase!( + test_str_enum_argument_suggestion, + r#" +from enum import StrEnum + +class T(StrEnum): + A = "a" + +def f(t: T) -> None: + pass + +f("a") # E: Argument `Literal['a']` is not assignable to parameter `t` with type `T` in function `f`\n Did you mean `T.A`? +"#, +); + +testcase!( + test_str_enum_argument_suggestion_through_union, + r#" +from enum import StrEnum + +class T(StrEnum): + A = "a" + +def f(t: T | None) -> None: + pass + +f("a") # E: Argument `Literal['a']` is not assignable to parameter `t` with type `T | None` in function `f`\n Did you mean `T.A`? +"#, +); + testcase!( test_value_annotation, r#"