From ccab61c86a1b1045dc2b3617c1e246f7212a58e4 Mon Sep 17 00:00:00 2001 From: Lewis Russell Date: Sun, 29 Mar 2026 13:10:31 +0100 Subject: [PATCH] fix(type): preserve unknown unions and narrow pcall(any) Preserve unknown as a real union member in type unions and doc-type parsing instead of collapsing unknown|T to T. Update call sites that previously used Unknown as an empty union accumulator to use Never or explicit first-value handling, and avoid merging missing closure param annotations as unknown. Keep the pcall(any) flow fixes so the payload stays unknown|string outside the branch, narrows to unknown on success, and narrows to string on failure. --- .../compilation/analyzer/doc/infer_type.rs | 7 -- .../src/compilation/analyzer/lua/closure.rs | 38 +++---- .../analyzer/unresolve/find_decl_function.rs | 12 +- .../analyzer/unresolve/resolve_closure.rs | 25 +++-- .../compilation/test/closure_return_test.rs | 45 ++++++++ .../src/compilation/test/member_infer_test.rs | 44 ++++++++ .../src/compilation/test/pcall_test.rs | 24 ++++ .../test/return_overload_flow_test.rs | 65 +++++++++++ .../src/db_index/member/lua_member_item.rs | 6 +- .../src/db_index/signature/return_rows.rs | 7 -- .../src/db_index/type/type_ops/union_type.rs | 2 - .../src/db_index/type/types.rs | 27 +++-- .../src/semantic/generic/call_constraint.rs | 2 +- .../instantiate_func_generic.rs | 14 ++- .../instantiate_special_generic.rs | 6 +- .../src/semantic/generic/mod.rs | 2 +- .../src/semantic/infer/infer_binary/mod.rs | 9 +- .../src/semantic/infer/infer_doc_type.rs | 7 -- .../src/semantic/infer/infer_index/mod.rs | 38 ++++--- .../src/semantic/infer/infer_name.rs | 4 +- .../src/semantic/infer/infer_table.rs | 9 +- .../narrow/condition_flow/correlated_flow.rs | 105 ++++++++++++++++-- .../semantic/infer/narrow/get_type_at_flow.rs | 2 +- .../semantic/infer/narrow/narrow_type/mod.rs | 2 +- 24 files changed, 393 insertions(+), 109 deletions(-) diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/infer_type.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/infer_type.rs index f53b3d9da..761490f3e 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/infer_type.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/infer_type.rs @@ -370,13 +370,6 @@ fn infer_binary_type(analyzer: &mut DocAnalyzer, binary_type: &LuaDocBinaryType) if let Some((left, right)) = binary_type.get_types() { let left_type = infer_type(analyzer, left); let right_type = infer_type(analyzer, right); - if left_type.is_unknown() { - return right_type; - } - if right_type.is_unknown() { - return left_type; - } - if let Some(op) = binary_type.get_op_token() { match op.get_op() { LuaTypeBinaryOperator::Union => match (left_type, right_type) { diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/closure.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/closure.rs index d04ab89ca..a53851c45 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/closure.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/closure.rs @@ -208,31 +208,33 @@ fn analyze_lambda_returns( pub fn analyze_return_point( db: &DbIndex, cache: &mut LuaInferCache, - return_points: &Vec, + return_points: &[LuaReturnPoint], ) -> Result, InferFailReason> { - let mut return_type = LuaType::Unknown; + let mut return_type = None; for point in return_points { - match point { - LuaReturnPoint::Expr(expr) => { - let expr_type = infer_expr(db, cache, expr.clone())?; - return_type = union_return_expr(db, return_type, expr_type); - } + let point_type = match point { + LuaReturnPoint::Expr(expr) => Some(infer_expr(db, cache, expr.clone())?), LuaReturnPoint::MuliExpr(exprs) => { - let mut multi_return = vec![]; + let mut multi_return = Vec::with_capacity(exprs.len()); for expr in exprs { - let expr_type = infer_expr(db, cache, expr.clone())?; - multi_return.push(expr_type); + multi_return.push(infer_expr(db, cache, expr.clone())?); } - let typ = LuaType::Variadic(VariadicType::Multi(multi_return).into()); - return_type = union_return_expr(db, return_type, typ); - } - LuaReturnPoint::Nil => { - return_type = union_return_expr(db, return_type, LuaType::Nil); + Some(LuaType::Variadic(VariadicType::Multi(multi_return).into())) } - _ => {} + LuaReturnPoint::Nil => Some(LuaType::Nil), + _ => None, + }; + + if let Some(point_type) = point_type { + return_type = Some(match return_type { + Some(return_type) => union_return_expr(db, return_type, point_type), + None => point_type, + }); } } + let return_type = return_type.unwrap_or(LuaType::Unknown); + Ok(vec![LuaDocReturnInfo { type_ref: return_type, description: None, @@ -242,10 +244,6 @@ pub fn analyze_return_point( } fn union_return_expr(db: &DbIndex, left: LuaType, right: LuaType) -> LuaType { - if left == LuaType::Unknown { - return right; - } - match (&left, &right) { (LuaType::Variadic(left_variadic), LuaType::Variadic(right_variadic)) => { match (&left_variadic.deref(), &right_variadic.deref()) { diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/find_decl_function.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/find_decl_function.rs index 8c8db16bc..c91b57e94 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/find_decl_function.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/find_decl_function.rs @@ -205,7 +205,7 @@ fn find_custom_type_function_member( let key = LuaMemberKey::from_index_key(db, cache, &index_key)?; if let Some(member_item) = db.get_member_index().get_member_item(&owner, &key) { let index_member_id = get_member_id(cache, &index_expr); - let mut result_type = LuaType::Unknown; + let mut result_type = LuaType::Never; for member_id in member_item.get_member_ids() { if index_member_id != member_id && let Some(type_cache) = db.get_type_index().get_type_cache(&member_id.into()) @@ -213,7 +213,7 @@ fn find_custom_type_function_member( result_type = TypeOps::Union.apply(db, &result_type, type_cache.as_type()); } } - if !result_type.is_unknown() { + if !result_type.is_never() { return Ok(result_type); } } @@ -540,7 +540,7 @@ fn find_member_by_index_table( .get_member_index() .get_members(&LuaMemberOwner::Element(table_range.clone())); if let Some(members) = members { - let mut result_type = LuaType::Unknown; + let mut result_type = LuaType::Never; for member in members { let member_key_type = match member.get_key() { LuaMemberKey::Name(s) => LuaType::StringConst(s.clone().into()), @@ -558,7 +558,7 @@ fn find_member_by_index_table( } } - if !result_type.is_unknown() { + if !result_type.is_never() { if matches!( key_type, LuaType::String | LuaType::Number | LuaType::Integer @@ -698,7 +698,7 @@ fn find_member_by_index_union( infer_guard: &InferGuardRef, deep_guard: &mut DeepLevel, ) -> FunctionTypeResult { - let mut member_type = LuaType::Unknown; + let mut member_type = LuaType::Never; for member in union.into_vec() { let result = find_function_type_by_operator( db, @@ -719,7 +719,7 @@ fn find_member_by_index_union( } } - if member_type.is_unknown() { + if member_type.is_never() { return Err(InferFailReason::FieldNotFound); } diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/resolve_closure.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/resolve_closure.rs index de1c05996..692668d1d 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/resolve_closure.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/resolve_closure.rs @@ -307,7 +307,8 @@ fn resolve_closure_member_type( .get(&closure_params.signature_id) .ok_or(InferFailReason::None)?; let mut final_params = signature.get_type_params().to_vec(); - let mut final_ret = LuaType::Unknown; + let mut final_ret = LuaType::Never; + let mut has_final_ret = false; let mut multi_function_type = Vec::new(); for typ in union_types.into_vec() { @@ -369,17 +370,21 @@ fn resolve_closure_member_type( break; } - let new_type = TypeOps::Union.apply( - db, - final_param.1.as_ref().unwrap_or(&LuaType::Unknown), - param.1.as_ref().unwrap_or(&LuaType::Unknown), - ); - final_params[idx] = (final_param.0.clone(), Some(new_type)); + let new_type = match (&final_param.1, ¶m.1) { + (Some(final_type), Some(param_type)) => { + Some(TypeOps::Union.apply(db, final_type, param_type)) + } + (Some(final_type), None) => Some(final_type.clone()), + (None, Some(param_type)) => Some(param_type.clone()), + (None, None) => None, + }; + final_params[idx] = (final_param.0.clone(), new_type); } else { final_params.push((param.0.clone(), param.1.clone())); } } + has_final_ret = true; final_ret = TypeOps::Union.apply(db, &final_ret, doc_func.get_ret()); } @@ -389,6 +394,12 @@ fn resolve_closure_member_type( param.1 = Some(variadic_type); } + let final_ret = if !has_final_ret { + LuaType::Unknown + } else { + final_ret + }; + resolve_doc_function( db, closure_params, diff --git a/crates/emmylua_code_analysis/src/compilation/test/closure_return_test.rs b/crates/emmylua_code_analysis/src/compilation/test/closure_return_test.rs index b8351c398..917d8cdac 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/closure_return_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/closure_return_test.rs @@ -80,4 +80,49 @@ mod test { "#, )); } + + #[test] + fn test_inferred_return_preserves_never() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@return { y: number } & { y: string } + local function impossible() end + + local function f() + return impossible().y + end + + result = f() + "#, + ); + + assert_eq!(ws.expr_ty("result"), ws.ty("never")); + } + + #[test] + fn test_member_doc_return_preserves_never() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@return { y: number } & { y: string } + local function impossible() end + + ---@class ClosureTest + ---@field e fun(): never + ---@field e fun(): never + local Test + + function Test.e() + return impossible().y + end + + result = Test.e() + "#, + ); + + assert_eq!(ws.expr_ty("result"), ws.ty("never")); + } } diff --git a/crates/emmylua_code_analysis/src/compilation/test/member_infer_test.rs b/crates/emmylua_code_analysis/src/compilation/test/member_infer_test.rs index b3d164aae..5902b6390 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/member_infer_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/member_infer_test.rs @@ -324,4 +324,48 @@ mod test { value_ty ); } + + #[test] + fn test_union_member_access_preserves_never() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@class A + ---@field y never + + ---@class B + ---@field y never + + ---@return A|B + local function make() end + + local value = make() + + result = value.y + "#, + ); + + assert_eq!(ws.expr_ty("result"), ws.ty("never")); + } + + #[test] + fn test_table_expr_index_preserves_never() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@return { y: number } & { y: string } + local function impossible() end + + local t = { + a = impossible().y, + } + + result = t["a"] + "#, + ); + + assert_eq!(ws.expr_ty("result"), ws.ty("never")); + } } diff --git a/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs b/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs index 7700022c1..e1a8cadf3 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs @@ -175,4 +175,28 @@ mod test { ); assert_eq!(ws.expr_ty("narrowed"), ws.ty("integer")); } + + #[test] + fn test_pcall_any_callable_splits_success_unknown_and_failure_string() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@type any + local x + + local ok, result = pcall(x) + outside = result + if ok then + success = result + else + failure = result + end + "#, + ); + + assert_eq!(ws.expr_ty("outside"), ws.ty("unknown|string")); + assert_eq!(ws.expr_ty("success"), ws.ty("unknown")); + assert_eq!(ws.expr_ty("failure"), ws.ty("string")); + } } diff --git a/crates/emmylua_code_analysis/src/compilation/test/return_overload_flow_test.rs b/crates/emmylua_code_analysis/src/compilation/test/return_overload_flow_test.rs index 0cf6b74a8..0f64ddaeb 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/return_overload_flow_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/return_overload_flow_test.rs @@ -77,6 +77,71 @@ mod test { assert_eq!(ws.expr_ty("d"), ws.ty("boolean")); } + #[test] + fn test_return_overload_narrow_with_overlapping_target_union() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@param ok boolean + ---@return boolean + ---@return string|number + ---@return_overload true, string + ---@return_overload false, string|number + local function pick(ok) + if ok then + return true, "value" + end + return false, 1 + end + + local cond ---@type boolean + local ok, result = pick(cond) + + if ok then + success_branch = result + else + failure_branch = result + end + "#, + ); + + assert_eq!(ws.expr_ty("success_branch"), ws.ty("string")); + assert_eq!(ws.expr_ty("failure_branch"), ws.ty("string|number")); + } + + #[test] + fn test_return_overload_narrow_with_overlapping_supertype_target() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@param ok boolean + ---@return boolean + ---@return number + ---@return_overload true, integer + ---@return_overload false, number + local function pick(ok) + if ok then + return true, 1 + end + return false, 1.5 + end + + local cond ---@type boolean + local ok, result = pick(cond) + + if ok then + success_branch = result + else + failure_branch = result + end + "#, + ); + + assert_eq!(ws.expr_ty("success_branch"), ws.ty("integer")); + } + #[test] fn test_return_overload_reassign_clears_multi_return_mapping() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/db_index/member/lua_member_item.rs b/crates/emmylua_code_analysis/src/db_index/member/lua_member_item.rs index 9709736ee..837b419d8 100644 --- a/crates/emmylua_code_analysis/src/db_index/member/lua_member_item.rs +++ b/crates/emmylua_code_analysis/src/db_index/member/lua_member_item.rs @@ -70,7 +70,7 @@ fn resolve_member_type( match resolve_state { MemberTypeResolveState::All => { - let mut typ = LuaType::Unknown; + let mut typ = LuaType::Never; for member in members { typ = TypeOps::Union.apply( db, @@ -84,7 +84,7 @@ fn resolve_member_type( Ok(typ) } MemberTypeResolveState::Meta => { - let mut typ = LuaType::Unknown; + let mut typ = LuaType::Never; for member in &members { let feature = member.get_feature(); if feature.is_meta_decl() { @@ -101,7 +101,7 @@ fn resolve_member_type( Ok(typ) } MemberTypeResolveState::FileDecl => { - let mut typ = LuaType::Unknown; + let mut typ = LuaType::Never; for member in &members { let feature = member.get_feature(); if feature.is_file_decl() { diff --git a/crates/emmylua_code_analysis/src/db_index/signature/return_rows.rs b/crates/emmylua_code_analysis/src/db_index/signature/return_rows.rs index 18e2217d6..2ae56aacc 100644 --- a/crates/emmylua_code_analysis/src/db_index/signature/return_rows.rs +++ b/crates/emmylua_code_analysis/src/db_index/signature/return_rows.rs @@ -164,13 +164,6 @@ fn merge_return_rows(left_row: &[LuaType], right_row: &[LuaType]) -> LuaType { } fn merge_return_type(left: LuaType, right: LuaType) -> LuaType { - if left == LuaType::Unknown { - return right; - } - if right == LuaType::Unknown { - return left; - } - match (&left, &right) { (LuaType::Variadic(_), _) | (_, LuaType::Variadic(_)) => { let left_row = return_type_to_row(left); diff --git a/crates/emmylua_code_analysis/src/db_index/type/type_ops/union_type.rs b/crates/emmylua_code_analysis/src/db_index/type/type_ops/union_type.rs index b1a2f6458..dc3522dd6 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/type_ops/union_type.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/type_ops/union_type.rs @@ -21,8 +21,6 @@ fn union_type_impl(match_source: &LuaType, source: LuaType, target: LuaType) -> (_, LuaType::Any) => LuaType::Any, (LuaType::Never, _) => target, (_, LuaType::Never) => source, - (LuaType::Unknown, _) => target, - (_, LuaType::Unknown) => source, // int | int const (LuaType::Integer, LuaType::IntegerConst(_) | LuaType::DocIntegerConst(_)) => { LuaType::Integer diff --git a/crates/emmylua_code_analysis/src/db_index/type/types.rs b/crates/emmylua_code_analysis/src/db_index/type/types.rs index a6982072d..c99e7a003 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/types.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/types.rs @@ -670,7 +670,7 @@ impl LuaTupleType { } pub fn cast_down_array_base(&self, db: &DbIndex) -> LuaType { - let mut ty = LuaType::Unknown; + let mut ty = LuaType::Never; for t in &self.types { match t { LuaType::IntegerConst(i) => { @@ -688,7 +688,11 @@ impl LuaTupleType { } } - ty + if self.types.is_empty() { + LuaType::Unknown + } else { + ty + } } pub fn is_infer_resolve(&self) -> bool { @@ -929,18 +933,16 @@ impl LuaObjectType { let mut ty = None; for (key, value_type) in self.index_access.iter() { if matches!(key, LuaType::Integer) { - if ty.is_none() { - ty = Some(LuaType::Unknown); - } - if let Some(t) = ty { - ty = Some(TypeOps::Union.apply(db, &t, value_type)); - } + ty = Some(match ty { + Some(t) => TypeOps::Union.apply(db, &t, value_type), + None => value_type.clone(), + }); } } return ty; } - let mut ty = LuaType::Unknown; + let mut ty = None; let mut count = 1; let mut fields = self.fields.iter().collect::>(); @@ -960,10 +962,13 @@ impl LuaObjectType { count += 1; - ty = TypeOps::Union.apply(db, &ty, value_type); + ty = Some(match ty { + Some(t) => TypeOps::Union.apply(db, &t, value_type), + None => value_type.clone(), + }); } - Some(ty) + Some(ty.unwrap_or(LuaType::Unknown)) } } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/call_constraint.rs b/crates/emmylua_code_analysis/src/semantic/generic/call_constraint.rs index fec8d6717..cb8f29059 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/call_constraint.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/call_constraint.rs @@ -228,7 +228,7 @@ fn get_constraint_type( if depth > 1 { return None; } - let mut result = LuaType::Unknown; + let mut result = LuaType::Never; for union_member_type in union_type.into_vec().iter() { let extend_type = get_constraint_type(semantic_model, union_member_type, depth + 1) .unwrap_or(union_member_type.clone()); diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs index d0a6b24c9..0fd323381 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs @@ -266,10 +266,18 @@ fn infer_generic_types_from_call( if let Some(return_pattern) = as_doc_function_type(context.db, func_param_type)?.map(|func| func.get_ret().clone()) - && let Some(inferred_return_type) = - infer_callable_return_from_remaining_args(context, &arg_type, &arg_exprs[i + 1..])? { - return_type_pattern_match_target_type(context, &return_pattern, &inferred_return_type)?; + if let Some(inferred_return_type) = + infer_callable_return_from_remaining_args(context, &arg_type, &arg_exprs[i + 1..])? + { + return_type_pattern_match_target_type( + context, + &return_pattern, + &inferred_return_type, + )?; + } else if arg_type.is_any() || arg_type.is_unknown() { + return_type_pattern_match_target_type(context, &return_pattern, &LuaType::Unknown)?; + } } match (func_param_type, &arg_type) { diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs index 8daf4ed5f..f295434db 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs @@ -277,7 +277,7 @@ fn instantiate_unpack_call(db: &DbIndex, operands: &[LuaType]) -> LuaType { for i in 1..10 { let member_key = LuaMemberKey::Integer(i); if let Some(member_info) = members.get(&member_key) { - let mut member_type = LuaType::Unknown; + let mut member_type = LuaType::Never; for sub_member_info in member_info { member_type = TypeOps::Union.apply(db, &member_type, &sub_member_info.typ); } @@ -305,6 +305,10 @@ fn instantiate_rawget_call(db: &DbIndex, owner: &LuaType, key: &LuaType) -> LuaT } fn instantiate_index_call(db: &DbIndex, owner: &LuaType, key: &LuaType) -> LuaType { + if owner.is_unknown() { + return LuaType::Unknown; + } + if let LuaType::Variadic(variadic) = owner { match variadic.deref() { VariadicType::Base(base) => { diff --git a/crates/emmylua_code_analysis/src/semantic/generic/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/mod.rs index 4b60b91ea..d66ae97bc 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/mod.rs @@ -103,7 +103,7 @@ pub fn get_tpl_ref_extend_type( if depth > 1 { return None; } - let mut result = LuaType::Unknown; + let mut result = LuaType::Never; for union_member_type in union_type.into_vec().iter() { let extend_type = get_tpl_ref_extend_type( db, diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_binary/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_binary/mod.rs index 340230045..53fc80337 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_binary/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_binary/mod.rs @@ -71,7 +71,7 @@ fn infer_union_binary_expr( return None; }; - let mut result = LuaType::Unknown; + let mut result = None; let types = u.into_vec(); for ty in types.iter() { // 只在实际调用时才 clone,而不是预先 clone @@ -82,10 +82,13 @@ fn infer_union_binary_expr( }; if let Ok(ty) = ty_result { - result = TypeOps::Union.apply(db, &result, &ty); + result = Some(match result { + Some(result) => TypeOps::Union.apply(db, &result, &ty), + None => ty, + }); } } - Some(result) + Some(result.unwrap_or(LuaType::Unknown)) } fn infer_binary_expr_type( diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_doc_type.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_doc_type.rs index bade70952..d5f176d72 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_doc_type.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_doc_type.rs @@ -298,13 +298,6 @@ fn infer_binary_type(ctx: DocTypeInferContext<'_>, binary_type: &LuaDocBinaryTyp if let Some((left, right)) = binary_type.get_types() { let left_type = infer_doc_type(ctx, &left); let right_type = infer_doc_type(ctx, &right); - if left_type.is_unknown() { - return right_type; - } - if right_type.is_unknown() { - return left_type; - } - if let Some(op) = binary_type.get_op_token() { match op.get_op() { LuaTypeBinaryOperator::Union => match (left_type, right_type) { diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_index/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_index/mod.rs index 23709282e..57a820997 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_index/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_index/mod.rs @@ -450,7 +450,7 @@ fn infer_tuple_member( }; } LuaType::Integer => { - let mut result = LuaType::Unknown; + let mut result = LuaType::Never; for typ in tuple_type.get_types() { result = TypeOps::Union.apply(db, &result, typ); } @@ -547,7 +547,9 @@ fn infer_union_member( index_expr: LuaIndexMemberExpr, infer_guard: &InferGuardRef, ) -> InferResult { - let mut member_types = Vec::new(); + let mut member_type = LuaType::Never; + let mut has_member = false; + let mut has_missing_member = false; let mut meet_string = false; for sub_type in union_type.into_vec() { if sub_type.is_string() { @@ -563,20 +565,26 @@ fn infer_union_member( index_expr.clone(), &infer_guard.fork(), ); - if let Ok(typ) = result { - if !typ.is_never() { - member_types.push(typ); + match result { + Ok(typ) => { + has_member = true; + member_type = TypeOps::Union.apply(db, &member_type, &typ); + } + Err(_) => { + has_missing_member = true; } - } else { - member_types.push(LuaType::Nil); } } - if member_types.iter().all(|t| t.is_nil()) { + if !has_member { return Err(InferFailReason::FieldNotFound); } - Ok(LuaType::from_vec(member_types)) + if has_missing_member { + member_type = TypeOps::Union.apply(db, &member_type, &LuaType::Nil); + } + + Ok(member_type) } fn infer_intersection_member( @@ -816,7 +824,8 @@ fn infer_member_by_index_table( .get_members(&LuaMemberOwner::Element(table_range.clone())); if let Some(mut members) = members { members.sort_by(|a, b| a.get_key().cmp(b.get_key())); - let mut result_type = LuaType::Unknown; + let mut result_type = LuaType::Never; + let mut has_match = false; for member in members { let member_key_type = match member.get_key() { LuaMemberKey::Name(s) => LuaType::StringConst(s.clone().into()), @@ -830,11 +839,12 @@ fn infer_member_by_index_table( .map(|it| it.as_type()) .unwrap_or(&LuaType::Unknown); + has_match = true; result_type = TypeOps::Union.apply(db, &result_type, member_type); } } - if !result_type.is_unknown() { + if has_match { if matches!( key_type, LuaType::String | LuaType::Number | LuaType::Integer @@ -963,12 +973,14 @@ fn infer_member_by_index_union( index_expr: LuaIndexMemberExpr, infer_guard: &InferGuardRef, ) -> InferResult { - let mut member_type = LuaType::Unknown; + let mut member_type = LuaType::Never; + let mut has_member = false; for member in union.into_vec() { let result = infer_member_by_operator(db, cache, &member, index_expr.clone(), &infer_guard.fork()); match result { Ok(typ) => { + has_member = true; member_type = TypeOps::Union.apply(db, &member_type, &typ); } Err(InferFailReason::FieldNotFound) => {} @@ -978,7 +990,7 @@ fn infer_member_by_index_union( } } - if member_type.is_unknown() { + if !has_member { return Err(InferFailReason::FieldNotFound); } diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_name.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_name.rs index dd652510e..55e58b010 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_name.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_name.rs @@ -361,7 +361,7 @@ pub fn infer_global_type(db: &DbIndex, name: &str) -> InferResult { }); // TODO: 或许应该联合所有定义的类型? - let mut valid_type = LuaType::Unknown; + let mut valid_type = LuaType::Never; let mut last_resolve_reason = InferFailReason::None; for decl_id in sorted_decl_ids { let decl_type_cache = db.get_type_index().get_type_cache(&decl_id.into()); @@ -394,7 +394,7 @@ pub fn infer_global_type(db: &DbIndex, name: &str) -> InferResult { } } - if !valid_type.is_unknown() { + if !valid_type.is_never() { return Ok(valid_type); } diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_table.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_table.rs index 1dde01ab5..190f3329d 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_table.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_table.rs @@ -273,17 +273,20 @@ fn infer_table_type_by_callee( /// 移除掉一些非`table`类型 fn union_remove_non_table_type(db: &DbIndex, union: &Arc) -> LuaType { - let mut result = LuaType::Unknown; + let mut result = None; for typ in union.into_set().into_iter() { match typ { LuaType::Signature(_) | LuaType::DocFunction(_) => {} _ if typ.is_string() || typ.is_number() || typ.is_boolean() => {} _ => { - result = TypeOps::Union.apply(db, &result, &typ); + result = Some(match result { + Some(result) => TypeOps::Union.apply(db, &result, &typ), + None => typ, + }); } } } - result + result.unwrap_or(LuaType::Unknown) } fn infer_table_field_type_by_parent( diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs index 5b9b88a4c..14c9b18b8 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs @@ -7,7 +7,7 @@ use crate::{ LuaInferCache, LuaSignature, LuaType, TypeOps, infer_expr, instantiate_func_generic, semantic::infer::{ VarRefId, - narrow::{get_single_antecedent, get_type_at_flow::get_type_at_flow}, + narrow::{get_single_antecedent, get_type_at_flow::get_type_at_flow, narrow_down_type}, }, }; @@ -50,8 +50,11 @@ impl CorrelatedConditionNarrowing { None } else { let matching_target_type = LuaType::from_vec(matching_target_types.clone()); - let narrowed_correlated_type = - TypeOps::Intersect.apply(db, &antecedent_type, &matching_target_type); + let narrowed_correlated_type = narrow_matching_correlated_type( + db, + antecedent_type.clone(), + &matching_target_type, + ); if narrowed_correlated_type.is_never() { None } else { @@ -230,7 +233,7 @@ fn collect_correlated_types_from_search_root( root_uncorrelated_target_types.push(root_type); } } else { - let mut known_call_target_types = root_correlated_candidate_types; + let mut known_call_target_types = root_correlated_candidate_types.clone(); known_call_target_types.extend(root_uncorrelated_target_types.iter().cloned()); if let Some(remaining_root_type) = get_type_at_flow(db, tree, cache, root, var_ref_id, search_root_flow_id) @@ -261,14 +264,14 @@ fn subtract_correlated_candidate_types( .into_vec() .into_iter() .filter(|member| { - !correlated_candidate_types.iter().any(|correlated_type| { - TypeOps::Union.apply(db, correlated_type, member) == *correlated_type - }) + !correlated_candidate_types + .iter() + .any(|correlated_type| correlated_type_contains(db, correlated_type, member)) }) .collect::>(), - source_type => (!correlated_candidate_types.iter().any(|correlated_type| { - TypeOps::Union.apply(db, correlated_type, &source_type) == *correlated_type - })) + source_type => (!correlated_candidate_types + .iter() + .any(|correlated_type| correlated_type_contains(db, correlated_type, &source_type))) .then_some(source_type) .into_iter() .collect(), @@ -277,6 +280,88 @@ fn subtract_correlated_candidate_types( (!remaining_types.is_empty()).then_some(LuaType::from_vec(remaining_types)) } +fn narrow_matching_correlated_type( + db: &DbIndex, + antecedent_type: LuaType, + matching_target_type: &LuaType, +) -> LuaType { + if let LuaType::Union(union) = matching_target_type { + let narrowed_types = union + .into_vec() + .into_iter() + .filter_map(|member| { + let narrowed = + narrow_matching_correlated_type(db, antecedent_type.clone(), &member); + (!narrowed.is_never()).then_some(narrowed) + }) + .collect::>(); + + return if narrowed_types.is_empty() { + LuaType::Never + } else { + LuaType::from_vec(narrowed_types) + }; + } + + if matching_target_type.is_unknown() + && let LuaType::Union(union) = &antecedent_type + { + let exact_unknown_types = union + .into_vec() + .into_iter() + .filter(|member| member.is_unknown()) + .collect::>(); + if !exact_unknown_types.is_empty() { + return LuaType::from_vec(exact_unknown_types); + } + } + + if let Some(narrowed_type) = narrow_down_type( + db, + antecedent_type.clone(), + matching_target_type.clone(), + None, + ) { + return narrowed_type; + } + + match antecedent_type { + LuaType::Union(union) => { + let narrowed_types = union + .into_vec() + .into_iter() + .filter_map(|member| { + let narrowed = + narrow_matching_correlated_type(db, member, matching_target_type); + (!narrowed.is_never()).then_some(narrowed) + }) + .collect::>(); + + if narrowed_types.is_empty() { + LuaType::Never + } else { + LuaType::from_vec(narrowed_types) + } + } + antecedent_type => TypeOps::Intersect.apply(db, &antecedent_type, matching_target_type), + } +} + +fn correlated_type_contains(db: &DbIndex, container: &LuaType, target: &LuaType) -> bool { + if target.is_unknown() && !container.is_any() { + return match container { + LuaType::Unknown => true, + LuaType::Union(union) => union + .into_vec() + .iter() + .any(|member| correlated_type_contains(db, member, target)), + _ => false, + }; + } + + TypeOps::Union.apply(db, container, target) == *container +} + #[allow(clippy::too_many_arguments)] fn collect_matching_correlated_types( db: &DbIndex, diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs index 416d942c9..0699c2a14 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs @@ -149,7 +149,7 @@ fn get_type_at_flow_internal( FlowNodeKind::BranchLabel | FlowNodeKind::NamedLabel(_) => { let multi_antecedents = get_multi_antecedents(tree, flow_node)?; - let mut branch_result_type = LuaType::Unknown; + let mut branch_result_type = LuaType::Never; for &flow_id in &multi_antecedents { let branch_type = get_type_at_flow_internal( db, diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/narrow_type/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/narrow_type/mod.rs index f6d62ea97..13275c559 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/narrow_type/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/narrow_type/mod.rs @@ -210,7 +210,7 @@ pub fn narrow_down_type( if source_types.is_empty() { return None; } - let mut result_type = LuaType::Unknown; + let mut result_type = LuaType::Never; for source_type in source_types { result_type = TypeOps::Union.apply(db, &result_type, &source_type); }