Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions pyrefly/lib/alt/narrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2006,6 +2006,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
narrowing_subject: &NarrowingSubject,
narrow_ops_for_fall_through: &(Box<NarrowOp>, TextRange),
subject_range: &TextRange,
display_subject_range: &Option<TextRange>,
errors: &ErrorCollector,
) {
let (op, narrow_range) = narrow_ops_for_fall_through;
Expand Down Expand Up @@ -2047,14 +2048,20 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
let subject_display = self.for_display(subject_info.into_ty());
let remaining_display = self.for_display(remaining_ty.clone());
let ctx = TypeDisplayContext::new(&[&subject_display, &remaining_display]);
let mut builder = errors.error_builder(
*subject_range,
ErrorKind::NonExhaustiveMatch,
let message = if let Some(display_subject_range) = display_subject_range {
format!(
"Match on `{}` of type `{}` is not exhaustive",
self.module().code_at(*display_subject_range),
ctx.display(&subject_display)
)
} else {
format!(
"Match on `{}` is not exhaustive",
ctx.display(&subject_display)
),
);
)
};
let mut builder =
errors.error_builder(*subject_range, ErrorKind::NonExhaustiveMatch, message);
if let Some(missing_cases) = self.format_missing_cases(&remaining_ty) {
builder = builder.with_detail(format!("Missing cases: {}", missing_cases));
}
Expand Down
2 changes: 2 additions & 0 deletions pyrefly/lib/alt/solve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2344,11 +2344,13 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
narrowing_subject,
narrow_ops_for_fall_through,
subject_range: range,
display_subject_range,
} => self.check_match_exhaustiveness(
subject_idx,
narrowing_subject,
narrow_ops_for_fall_through,
range,
display_subject_range,
errors,
),
BindingExpect::MatchCaseReachability {
Expand Down
1 change: 1 addition & 0 deletions pyrefly/lib/binding/binding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1152,6 +1152,7 @@ pub enum BindingExpect {
narrowing_subject: NarrowingSubject,
narrow_ops_for_fall_through: (Box<NarrowOp>, TextRange),
subject_range: TextRange,
display_subject_range: Option<TextRange>,
},
/// A match case whose pattern may not overlap with the current subject type.
MatchCaseReachability {
Expand Down
51 changes: 46 additions & 5 deletions pyrefly/lib/binding/pattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ use ruff_python_ast::Number;
use ruff_python_ast::Pattern;
use ruff_python_ast::PatternKeyword;
use ruff_python_ast::StmtMatch;
use ruff_python_ast::name::Name;
use ruff_text_size::Ranged;
use ruff_text_size::TextRange;

use crate::binding::binding::Binding;
use crate::binding::binding::BindingExpect;
Expand Down Expand Up @@ -48,18 +50,39 @@ enum MatchSubject {
None,
/// A single match subject (e.g., `match x:`).
Single(NarrowingSubject),
/// A local-only subject for matching non-name expressions (e.g., `match await f():`).
/// Python evaluates the subject once before matching, so we need a stable internal
/// subject for branch narrowing while diagnostics still point at the source expression.
Synthetic {
subject: NarrowingSubject,
display_subject_range: TextRange,
},
/// Per-element subjects from a tuple match (e.g., `match x, y:`).
Tuple(Vec<Option<NarrowingSubject>>),
}

impl MatchSubject {
/// Extract a single narrowing subject, if this is `Single`.
/// Extract a single narrowing subject, if available.
fn as_single(&self) -> Option<&NarrowingSubject> {
match self {
MatchSubject::Single(s) => Some(s),
MatchSubject::Single(s) | MatchSubject::Synthetic { subject: s, .. } => Some(s),
_ => Option::None,
}
}

fn is_synthetic(&self) -> bool {
matches!(self, MatchSubject::Synthetic { .. })
}

fn display_subject_range(&self, fallback: TextRange) -> TextRange {

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Returning an Option<TextRange> and leaving it to the caller to determine what they want to do with a None seems more useful. This will also simplify some code below.

match self {
MatchSubject::Synthetic {
display_subject_range,
..
} => *display_subject_range,
_ => fallback,
}
}
}

impl<'a> BindingsBuilder<'a> {
Expand Down Expand Up @@ -554,7 +577,10 @@ impl<'a> BindingsBuilder<'a> {
} else {
match expr_to_subjects(&x.subject).first() {
Some(s) => MatchSubject::Single(s.clone()),
None => MatchSubject::None,
None => MatchSubject::Synthetic {
subject: NarrowingSubject::Name(Name::new_static("$match_subject")),
display_subject_range: x.subject.range(),
},
}
};
let mut exhaustive = false;
Expand Down Expand Up @@ -667,7 +693,15 @@ impl<'a> BindingsBuilder<'a> {
if exhaustive {
self.finish_exhaustive_fork();
} else {
let narrow_entries = self.build_narrow_entries(&negated_prev_ops);
let narrow_entries = if match_subject.is_synthetic() {
negated_prev_ops
.0
.values()
.map(|(op, range)| (subject_idx, Box::new(op.clone()), *range))
.collect()
} else {
self.build_narrow_entries(&negated_prev_ops)
};
// Create BindingExpect only if we have a narrowing subject (for exhaustiveness warnings)
if let Some(narrowing_subject) = match_subject.as_single()
&& let Some((op, range)) = negated_prev_ops.0.get(narrowing_subject.name())
Expand All @@ -678,7 +712,14 @@ impl<'a> BindingsBuilder<'a> {
subject_idx,
narrowing_subject: narrowing_subject.clone(),
narrow_ops_for_fall_through: (Box::new(op.clone()), *range),
subject_range: x.subject.range(),
subject_range: match_subject.display_subject_range(x.subject.range()),

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're effectively matching on match_subject twice here but one of them is hidden. IMO, adding something like

let (subject_range, display_subject_range) = match match_subject {
    MatchSubject::Synthetic { display_subject_range, .. } => (display_subject_range, Some(display_subject_range)),
    _ => (x.subject.range(), None)
};

would make this more explicit and hence easier to follow. (This would also make the display_subject_range function unused.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, I pressed the wrong button. This could also turn into a let display_subject_range = match_subject.display_subject_range(); before the insert and

subject_range: display_subject_range.unwrap_or_else(|| x.subject.range()),
display_subject_range,

display_subject_range: match match_subject {
MatchSubject::Synthetic {
display_subject_range,
..
} => Some(display_subject_range),
_ => None,
},
},
);
}
Expand Down
93 changes: 91 additions & 2 deletions pyrefly/lib/test/pattern_match.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ testcase!(
test_pattern_crash,
r#"
# Used to crash, see https://github.com/facebook/pyrefly/issues/490
match None:
match None: # E: Missing cases: None
case {a: 1}: # E: # E: # E:
pass
"#,
Expand Down Expand Up @@ -258,6 +258,95 @@ def f0(x: A | B):
"#,
);

testcase!(
test_match_await_exhaustive_no_implicit_return,
r#"
from typing import NoReturn

class Ok[T]:
__match_args__ = ("value",)
value: T

class Err[E]:
__match_args__ = ("value",)
value: E

class NotFound:
pass

def handle_error(error: NotFound) -> NoReturn:
raise Exception()

async def get_result() -> Ok[list[int]] | Err[NotFound]:
raise Exception()

async def f() -> list[int]:
match await get_result():
case Ok(value):
return value
case Err(error):
handle_error(error)
"#,
);

testcase!(
test_non_exhaustive_match_call_subject_diagnostic,
r#"
from typing import final

@final
class Ok[T]:
__match_args__ = ("value",)
value: T

@final
class Err[E]:
__match_args__ = ("value",)
value: E

@final
class NotFound:
pass

def get_result() -> Ok[int] | Err[NotFound]:
raise Exception()

def f() -> None:
match get_result(): # E: get_result()
case Ok(value):
pass
"#,
);

testcase!(
test_non_exhaustive_match_await_subject_diagnostic,
r#"
from typing import final

@final
class Ok[T]:
__match_args__ = ("value",)
value: T

@final
class Err[E]:
__match_args__ = ("value",)
value: E

@final
class NotFound:
pass

async def get_result() -> Ok[int] | Err[NotFound]:
raise Exception()

async def f() -> None:
match await get_result(): # E: await get_result()
case Ok(value):
pass
"#,
);

testcase!(
test_match_sequence_pattern_narrows_tuple_out_of_union,
r#"
Expand Down Expand Up @@ -1078,7 +1167,7 @@ class Color(Enum):
def make_color() -> Color: ...

def f(y: Color) -> None:
match make_color():
match make_color(): # E: Missing cases: Color.GREEN
case Color.RED as y:
return
reveal_type(y) # E: revealed type: Color
Expand Down
Loading