diff --git a/pyrefly/lib/alt/narrow.rs b/pyrefly/lib/alt/narrow.rs index 0a7263f0b4..10a3038c21 100644 --- a/pyrefly/lib/alt/narrow.rs +++ b/pyrefly/lib/alt/narrow.rs @@ -2006,6 +2006,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { narrowing_subject: &NarrowingSubject, narrow_ops_for_fall_through: &(Box, TextRange), subject_range: &TextRange, + display_subject_range: &Option, errors: &ErrorCollector, ) { let (op, narrow_range) = narrow_ops_for_fall_through; @@ -2047,14 +2048,20 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { let subject_display = self.for_display(subject_info.into_ty()); let remaining_display = self.for_display(remaining_ty.clone()); let ctx = TypeDisplayContext::new(&[&subject_display, &remaining_display]); - let mut builder = errors.error_builder( - *subject_range, - ErrorKind::NonExhaustiveMatch, + let message = if let Some(display_subject_range) = display_subject_range { + format!( + "Match on `{}` of type `{}` is not exhaustive", + self.module().code_at(*display_subject_range), + ctx.display(&subject_display) + ) + } else { format!( "Match on `{}` is not exhaustive", ctx.display(&subject_display) - ), - ); + ) + }; + let mut builder = + errors.error_builder(*subject_range, ErrorKind::NonExhaustiveMatch, message); if let Some(missing_cases) = self.format_missing_cases(&remaining_ty) { builder = builder.with_detail(format!("Missing cases: {}", missing_cases)); } diff --git a/pyrefly/lib/alt/solve.rs b/pyrefly/lib/alt/solve.rs index bb519fd536..9aa9b3eb48 100644 --- a/pyrefly/lib/alt/solve.rs +++ b/pyrefly/lib/alt/solve.rs @@ -2344,11 +2344,13 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { narrowing_subject, narrow_ops_for_fall_through, subject_range: range, + display_subject_range, } => self.check_match_exhaustiveness( subject_idx, narrowing_subject, narrow_ops_for_fall_through, range, + display_subject_range, errors, ), BindingExpect::MatchCaseReachability { diff --git a/pyrefly/lib/binding/binding.rs b/pyrefly/lib/binding/binding.rs index f179639d42..d4ae8df7e7 100644 --- a/pyrefly/lib/binding/binding.rs +++ b/pyrefly/lib/binding/binding.rs @@ -1152,6 +1152,7 @@ pub enum BindingExpect { narrowing_subject: NarrowingSubject, narrow_ops_for_fall_through: (Box, TextRange), subject_range: TextRange, + display_subject_range: Option, }, /// A match case whose pattern may not overlap with the current subject type. MatchCaseReachability { diff --git a/pyrefly/lib/binding/pattern.rs b/pyrefly/lib/binding/pattern.rs index 218b211ca2..ec312425e6 100644 --- a/pyrefly/lib/binding/pattern.rs +++ b/pyrefly/lib/binding/pattern.rs @@ -18,7 +18,9 @@ use ruff_python_ast::Number; use ruff_python_ast::Pattern; use ruff_python_ast::PatternKeyword; use ruff_python_ast::StmtMatch; +use ruff_python_ast::name::Name; use ruff_text_size::Ranged; +use ruff_text_size::TextRange; use crate::binding::binding::Binding; use crate::binding::binding::BindingExpect; @@ -48,18 +50,39 @@ enum MatchSubject { None, /// A single match subject (e.g., `match x:`). Single(NarrowingSubject), + /// A local-only subject for matching non-name expressions (e.g., `match await f():`). + /// Python evaluates the subject once before matching, so we need a stable internal + /// subject for branch narrowing while diagnostics still point at the source expression. + Synthetic { + subject: NarrowingSubject, + display_subject_range: TextRange, + }, /// Per-element subjects from a tuple match (e.g., `match x, y:`). Tuple(Vec>), } impl MatchSubject { - /// Extract a single narrowing subject, if this is `Single`. + /// Extract a single narrowing subject, if available. fn as_single(&self) -> Option<&NarrowingSubject> { match self { - MatchSubject::Single(s) => Some(s), + MatchSubject::Single(s) | MatchSubject::Synthetic { subject: s, .. } => Some(s), _ => Option::None, } } + + fn is_synthetic(&self) -> bool { + matches!(self, MatchSubject::Synthetic { .. }) + } + + fn display_subject_range(&self, fallback: TextRange) -> TextRange { + match self { + MatchSubject::Synthetic { + display_subject_range, + .. + } => *display_subject_range, + _ => fallback, + } + } } impl<'a> BindingsBuilder<'a> { @@ -554,7 +577,10 @@ impl<'a> BindingsBuilder<'a> { } else { match expr_to_subjects(&x.subject).first() { Some(s) => MatchSubject::Single(s.clone()), - None => MatchSubject::None, + None => MatchSubject::Synthetic { + subject: NarrowingSubject::Name(Name::new_static("$match_subject")), + display_subject_range: x.subject.range(), + }, } }; let mut exhaustive = false; @@ -667,7 +693,15 @@ impl<'a> BindingsBuilder<'a> { if exhaustive { self.finish_exhaustive_fork(); } else { - let narrow_entries = self.build_narrow_entries(&negated_prev_ops); + let narrow_entries = if match_subject.is_synthetic() { + negated_prev_ops + .0 + .values() + .map(|(op, range)| (subject_idx, Box::new(op.clone()), *range)) + .collect() + } else { + self.build_narrow_entries(&negated_prev_ops) + }; // Create BindingExpect only if we have a narrowing subject (for exhaustiveness warnings) if let Some(narrowing_subject) = match_subject.as_single() && let Some((op, range)) = negated_prev_ops.0.get(narrowing_subject.name()) @@ -678,7 +712,14 @@ impl<'a> BindingsBuilder<'a> { subject_idx, narrowing_subject: narrowing_subject.clone(), narrow_ops_for_fall_through: (Box::new(op.clone()), *range), - subject_range: x.subject.range(), + subject_range: match_subject.display_subject_range(x.subject.range()), + display_subject_range: match match_subject { + MatchSubject::Synthetic { + display_subject_range, + .. + } => Some(display_subject_range), + _ => None, + }, }, ); } diff --git a/pyrefly/lib/test/pattern_match.rs b/pyrefly/lib/test/pattern_match.rs index 64b27827b4..63e1194d20 100644 --- a/pyrefly/lib/test/pattern_match.rs +++ b/pyrefly/lib/test/pattern_match.rs @@ -36,7 +36,7 @@ testcase!( test_pattern_crash, r#" # Used to crash, see https://github.com/facebook/pyrefly/issues/490 -match None: +match None: # E: Missing cases: None case {a: 1}: # E: # E: # E: pass "#, @@ -258,6 +258,95 @@ def f0(x: A | B): "#, ); +testcase!( + test_match_await_exhaustive_no_implicit_return, + r#" +from typing import NoReturn + +class Ok[T]: + __match_args__ = ("value",) + value: T + +class Err[E]: + __match_args__ = ("value",) + value: E + +class NotFound: + pass + +def handle_error(error: NotFound) -> NoReturn: + raise Exception() + +async def get_result() -> Ok[list[int]] | Err[NotFound]: + raise Exception() + +async def f() -> list[int]: + match await get_result(): + case Ok(value): + return value + case Err(error): + handle_error(error) +"#, +); + +testcase!( + test_non_exhaustive_match_call_subject_diagnostic, + r#" +from typing import final + +@final +class Ok[T]: + __match_args__ = ("value",) + value: T + +@final +class Err[E]: + __match_args__ = ("value",) + value: E + +@final +class NotFound: + pass + +def get_result() -> Ok[int] | Err[NotFound]: + raise Exception() + +def f() -> None: + match get_result(): # E: get_result() + case Ok(value): + pass +"#, +); + +testcase!( + test_non_exhaustive_match_await_subject_diagnostic, + r#" +from typing import final + +@final +class Ok[T]: + __match_args__ = ("value",) + value: T + +@final +class Err[E]: + __match_args__ = ("value",) + value: E + +@final +class NotFound: + pass + +async def get_result() -> Ok[int] | Err[NotFound]: + raise Exception() + +async def f() -> None: + match await get_result(): # E: await get_result() + case Ok(value): + pass +"#, +); + testcase!( test_match_sequence_pattern_narrows_tuple_out_of_union, r#" @@ -1078,7 +1167,7 @@ class Color(Enum): def make_color() -> Color: ... def f(y: Color) -> None: - match make_color(): + match make_color(): # E: Missing cases: Color.GREEN case Color.RED as y: return reveal_type(y) # E: revealed type: Color