diff --git a/pyrefly/lib/alt/class/typed_dict.rs b/pyrefly/lib/alt/class/typed_dict.rs index 046bafb535..6bb48713a4 100644 --- a/pyrefly/lib/alt/class/typed_dict.rs +++ b/pyrefly/lib/alt/class/typed_dict.rs @@ -466,6 +466,27 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { })) } + /// Get a (key, default: ValueType) -> ValueType overload. + fn get_overload_with_value_default( + &self, + metadata: &FuncMetadata, + self_param: &Param, + name: Option<&Name>, + ty: Type, + ) -> OverloadType { + OverloadType::Function(Function { + signature: Callable::list( + ParamList::new(vec![ + self_param.clone(), + self.key_param(name), + Param::PosOnly(Some(DEFAULT_PARAM.clone()), ty.clone(), Required::Required), + ]), + ty, + ), + metadata: metadata.clone(), + }) + } + /// Get a (key, default: T) -> ValueType | T overload. fn get_overload_with_default( &self, @@ -552,6 +573,13 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { ), metadata: metadata.clone(), })); + // (self, key: Literal["key"], default: ValueType) -> ValueType + literal_signatures.push(self.get_overload_with_value_default( + &metadata, + &self_param, + Some(name), + field.ty.clone(), + )); // (self, key: Literal["key"], default: T) -> ValueType | T literal_signatures.push(self.get_overload_with_default( cls, @@ -634,7 +662,15 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { metadata: metadata.clone(), })); - // 2) default: (self, key: Literal["field_name"], default: _T) -> FieldType | _T + // 2) default: (self, key: Literal["field_name"], default: FieldType) -> FieldType + overloads.push(self.get_overload_with_value_default( + metadata, + self_param, + name, + ty.clone(), + )); + + // 3) default: (self, key: Literal["field_name"], default: _T) -> FieldType | _T overloads.push(self.get_overload_with_default(cls, metadata, self_param, name, ty)); } diff --git a/pyrefly/lib/test/typed_dict.rs b/pyrefly/lib/test/typed_dict.rs index 03e037eeea..baa20da6da 100644 --- a/pyrefly/lib/test/typed_dict.rs +++ b/pyrefly/lib/test/typed_dict.rs @@ -931,6 +931,17 @@ def f(c: C): "#, ); +testcase!( + test_get_not_required_literal_default, + r#" +from typing import assert_type, Literal, NotRequired, TypedDict +class C(TypedDict): + x: NotRequired[Literal["a", "b"]] +def f(c: C): + assert_type(c.get("x", "b"), Literal["a", "b"]) + "#, +); + // Clearing a TypedDict is not allowed, since doing so would remove keys it's expected to have. testcase!( test_clear,