Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion pyrefly/lib/alt/answers_solver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ use crate::dispatch_anyidx;
use crate::error::collector::ErrorCollector;
use crate::error::context::ErrorInfo;
use crate::error::context::TypeCheckContext;
use crate::error::context::TypeCheckKind;
use crate::error::style::ErrorStyle;
use crate::export::exports::LookupExport;
use crate::module::module_info::ModuleInfo;
Expand Down Expand Up @@ -3032,7 +3033,26 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
match self.is_subset_eq_with_reason(got, want) {
Ok(()) => true,
Err(error) => {
self.solver().error(got, want, errors, loc, tcc, error);
let mut extra_lines = Vec::new();
if matches!(
tcc().kind,
TypeCheckKind::CallArgument(..)
| TypeCheckKind::CallVarArgs(..)
| TypeCheckKind::CallKwArgs(..)
| TypeCheckKind::CallUnpackKwArg(..)
) && let Some(suggestion) = self.suggest_enum_member_for_value(want, got)
{
extra_lines.push(format!("Did you mean `{suggestion}`?"));
}
self.solver().error_with_extra_lines(
got,
want,
errors,
loc,
tcc,
error,
extra_lines,
);
false
}
}
Expand Down
56 changes: 56 additions & 0 deletions pyrefly/lib/alt/class/enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,40 @@ pub const VALUE_PROP: Name = Name::new_static("value");
pub const GENERATE_NEXT_VALUE: Name = Name::new_static("_generate_next_value_");

impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
/// Suggest an enum member when a raw literal matches exactly one expected enum value.
pub fn suggest_enum_member_for_value(&self, want: &Type, got: &Type) -> Option<String> {
match want {
Type::ClassType(cls) => {
self.suggest_enum_member_for_class_value(cls.class_object(), got)
}
Type::SelfType(cls) => {
self.suggest_enum_member_for_class_value(cls.class_object(), got)
}
Type::Literal(lit) => match &lit.value {
Lit::Enum(lit_enum) => {
self.suggest_enum_member_for_class_value(lit_enum.class.class_object(), got)
}
_ => None,
},
Type::Union(box crate::types::types::Union { members, .. }) => {
let mut suggestion = None;
for member in members {
if let Some(candidate) = self.suggest_enum_member_for_value(member, got) {
if suggestion
.as_ref()
.is_some_and(|existing: &String| existing != &candidate)
{
return None;
}
suggestion = Some(candidate);
}
}
suggestion
}
_ => None,
}
}

pub fn get_enum_member(&self, cls: &Class, name: &Name) -> Option<Lit> {
self.get_field_from_current_class_only(cls, name)
.and_then(|field| self.as_enum_member(Arc::unwrap_or_clone(field), cls))
Expand All @@ -60,6 +94,28 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
.unwrap_or_default()
}

fn suggest_enum_member_for_class_value(&self, cls: &Class, got: &Type) -> Option<String> {
let is_django = self.get_metadata_for_class(cls).enum_metadata()?.is_django;
let mut suggestion = None;
for lit in self.get_enum_members(cls) {
let Lit::Enum(lit_enum) = lit else {
unreachable!("enum members must be represented as enum literals");
};
let value_ty = self.enum_literal_to_value_type((*lit_enum).clone(), is_django);
if self.is_subset_eq(got, &value_ty) && self.is_subset_eq(&value_ty, got) {
let candidate = format!("{}.{}", lit_enum.class.name(), lit_enum.member);
if suggestion
.as_ref()
.is_some_and(|existing: &String| existing != &candidate)
{
return None;
}
suggestion = Some(candidate);
}
}
suggestion
}

fn is_valid_enum_member(
&self,
name: &Name,
Expand Down
16 changes: 16 additions & 0 deletions pyrefly/lib/solver/solver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1384,6 +1384,19 @@ impl Solver {
loc: TextRange,
tcc: &dyn Fn() -> TypeCheckContext,
subset_error: SubsetError,
) {
self.error_with_extra_lines(got, want, errors, loc, tcc, subset_error, Vec::new());
}

pub fn error_with_extra_lines(
&self,
got: &Type,
want: &Type,
errors: &ErrorCollector,
loc: TextRange,
tcc: &dyn Fn() -> TypeCheckContext,
subset_error: SubsetError,
extra_lines: Vec<String>,
) {
let tcc = tcc();
let msg = tcc.kind.format_error(
Expand All @@ -1392,6 +1405,9 @@ impl Solver {
errors.module().name(),
);
let mut msg_lines = vec1![msg];
for line in extra_lines {
msg_lines.push(line);
}
if let Some(subset_error_msg) = subset_error.to_error_msg() {
msg_lines.push(subset_error_msg);
}
Expand Down
31 changes: 31 additions & 0 deletions pyrefly/lib/test/enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,37 @@ for e in E3:
"#,
);

// Regression test for https://github.com/facebook/pyrefly/issues/3128
testcase!(
test_str_enum_argument_suggestion,
r#"
from enum import StrEnum

class T(StrEnum):
A = "a"

def f(t: T) -> None:
pass

f("a") # E: Argument `Literal['a']` is not assignable to parameter `t` with type `T` in function `f`\n Did you mean `T.A`?
"#,
);

testcase!(
test_str_enum_argument_suggestion_through_union,
r#"
from enum import StrEnum

class T(StrEnum):
A = "a"

def f(t: T | None) -> None:
pass

f("a") # E: Argument `Literal['a']` is not assignable to parameter `t` with type `T | None` in function `f`\n Did you mean `T.A`?
"#,
);

testcase!(
test_value_annotation,
r#"
Expand Down
Loading