From 8b03352bdd08a55075d3174a9a4ec46059c923c2 Mon Sep 17 00:00:00 2001 From: Lewis Russell Date: Fri, 20 Mar 2026 14:01:39 +0000 Subject: [PATCH 1/5] fix(generic): stabilize higher-order callable overload inference When inferring higher-order generic calls from remaining arguments, the old path kept partially instantiated callable candidates and treated unresolved callback signatures too loosely. That could produce unknown returns, miss callable union and intersection members, and make named unresolved callbacks behave differently from inline closures. Fix this by collecting callable overload candidates from doc functions, signatures, aliases, unions, and intersections, instantiating each candidate from the remaining argument types, and resolving the best candidate with resolve_signature_by_args. Drop candidates whose argument matching fails or whose instantiated return still contains unresolved local template refs. Before overload resolution, prefer candidates with informative returns over ones that only narrow to unknown, so opaque callback returns fall back to concrete callable overloads instead of winning generic ones. Also match unresolved callback signatures by inferred signature id rather than only by closure source position so named callback values and inline closures follow the same generic matching path. Add regression coverage for higher-order callback overload ranking with concrete, unknown, and unresolved callback returns, for preserving callback parameter shape across integer and string branches, and for pcall callable union and intersection return inference. --- .../test/callable_return_infer_test.rs | 115 ++++++++ .../src/compilation/test/pcall_test.rs | 106 ++++++- .../instantiate_func_generic.rs | 274 ++++++++++++++++-- .../generic/tpl_pattern/lambda_tpl_pattern.rs | 20 +- .../src/semantic/infer/infer_call/mod.rs | 15 + .../src/semantic/overload_resolve/mod.rs | 2 +- 6 files changed, 489 insertions(+), 43 deletions(-) diff --git a/crates/emmylua_code_analysis/src/compilation/test/callable_return_infer_test.rs b/crates/emmylua_code_analysis/src/compilation/test/callable_return_infer_test.rs index 5771d9442..fd60b4e62 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/callable_return_infer_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/callable_return_infer_test.rs @@ -85,4 +85,119 @@ mod test { assert_eq!(ws.expr_ty("result"), ws.ty("string")); } + + #[test] + fn test_apply_return_infer_prefers_informative_callback_overloads_and_keeps_param_shape() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic A, R + ---@param f fun(x: A): R + ---@param x A + ---@return R + local function apply(f, x) + return f(x) + end + + ---@overload fun(cb: fun(): T): T + ---@overload fun(cb: function): boolean + local function run(cb) end + + ---@overload fun(cb: fun(x: integer): T): integer + ---@overload fun(cb: fun(x: string): T): string + ---@overload fun(cb: function): boolean + local function classify(cb) end + + ---@overload fun(cb: fun(): T): { value: T } + ---@overload fun(cb: function): boolean + local function wrap(cb) end + + local source ---@type table + + ---@return integer + local function cb_concrete() + return 1 + end + + ---@type fun(): unknown + local cb_unknown + + ---@type fun(x: integer): unknown + local cb_param_unknown + + ---@type fun(x: string): unknown + local cb_param_unknown_string + + ---@param x integer + local function cb_param_unresolved(x) + return source.missing + end + + ---@param x string + local function cb_param_unresolved_string(x) + return source.missing + end + + local function cb_named_unresolved() + return source.missing + end + + run_concrete = apply(run, cb_concrete) + + run_unknown = apply(run, cb_unknown) + + run_unresolved = apply(run, function() + return source.missing + end) + + run_named_unresolved = apply(run, cb_named_unresolved) + + wrap_named_unresolved = apply(wrap, cb_named_unresolved) + + classify_unknown = apply(classify, cb_param_unknown) + + classify_unresolved = apply(classify, cb_param_unresolved) + + classify_string_unknown = apply(classify, cb_param_unknown_string) + + classify_string_unresolved = apply(classify, cb_param_unresolved_string) + "#, + ); + + // The callback return is concrete, so `T` is inferred as `integer` and the generic + // overload is more informative than the `function -> boolean` fallback. + assert_eq!(ws.expr_ty("run_concrete"), ws.ty("integer")); + + // `fun(): unknown` matches the generic parameter shape, but its return stays opaque, so + // the concrete `function -> boolean` fallback is the best available result. + assert_eq!(ws.expr_ty("run_unknown"), ws.ty("boolean")); + + // An unresolved closure return should be treated the same as an explicit `unknown` + // return, so this should resolve to the same fallback `boolean`. + assert_eq!(ws.expr_ty("run_unresolved"), ws.ty("boolean")); + + // The named-callback path should detect the unresolved closure signature too, so this + // should stay aligned with the inline unresolved callback case above. + assert_eq!(ws.expr_ty("run_named_unresolved"), ws.ty("boolean")); + + // If the instantiated return still embeds the unresolved template, the generic overload + // should be discarded and the concrete fallback should win. + assert_eq!(ws.expr_ty("wrap_named_unresolved"), ws.ty("boolean")); + + // The callback's parameter type is known, so the generic `fun(x: integer): T` overload + // should still win even though the callback return is only `unknown`. + assert_eq!(ws.expr_ty("classify_unknown"), ws.ty("integer")); + + // The callback return is unresolved, but its parameter is still `integer`, so overload + // ranking should keep using that known shape and pick the generic integer branch. + assert_eq!(ws.expr_ty("classify_unresolved"), ws.ty("integer")); + + // The callback's parameter type is `string`, so overload selection should not fall back + // to the first generic branch when the callback return is only `unknown`. + assert_eq!(ws.expr_ty("classify_string_unknown"), ws.ty("string")); + + // The same `string`-parameter branch should still win when the callback return is + // unresolved and carried through a named callback value. + assert_eq!(ws.expr_ty("classify_string_unresolved"), ws.ty("string")); + } } diff --git a/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs b/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs index e1a8cadf3..0f23c3352 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs @@ -1,6 +1,6 @@ #[cfg(test)] mod test { - use crate::{DiagnosticCode, VirtualWorkspace}; + use crate::{DiagnosticCode, LuaType, VirtualWorkspace}; const STACKED_PCALL_ALIAS_GUARDS: usize = 180; @@ -199,4 +199,108 @@ mod test { assert_eq!(ws.expr_ty("success"), ws.ty("unknown")); assert_eq!(ws.expr_ty("failure"), ws.ty("string")); } + + #[test] + fn test_return_overload_unions_callable_union_members_with_same_domain() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@alias FnA fun(x: integer): integer + ---@alias FnB fun(x: integer): boolean + + ---@type FnA | FnB + local run + + _, a = pcall(run, 1) + "#, + ); + + assert_eq!( + ws.expr_ty("a"), + LuaType::from_vec(vec![LuaType::Integer, LuaType::Boolean, LuaType::String]) + ); + } + + #[test] + fn test_return_overload_callable_union_is_not_order_dependent() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@alias FnA fun(x: integer): integer + ---@alias FnB fun(x: integer): boolean + + ---@type FnB | FnA + local run + + _, a = pcall(run, 1) + "#, + ); + + assert_eq!( + ws.expr_ty("a"), + LuaType::from_vec(vec![LuaType::Integer, LuaType::Boolean, LuaType::String]) + ); + } + + #[test] + fn test_return_overload_unions_callable_intersection_members_with_same_domain() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@alias FnA fun(x: integer): integer + ---@alias FnB fun(x: integer): boolean + + ---@type FnA & FnB + local run + + _, a = pcall(run, 1) + "#, + ); + + assert_eq!( + ws.expr_ty("a"), + LuaType::from_vec(vec![LuaType::Integer, LuaType::Boolean, LuaType::String]) + ); + } + + #[test] + fn test_return_overload_narrows_callable_union_member_by_arg_shape() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@alias FnA fun(x: integer): integer + ---@alias FnB fun(x: string): integer + + ---@type FnA | FnB + local run + + _, a = pcall(run, 1) + "#, + ); + + assert_eq!(ws.expr_ty("a"), ws.ty("integer|string")); + } + + #[test] + fn test_return_overload_narrows_callable_intersection_member_by_arg_shape() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@alias FnA fun(x: integer): integer + ---@alias FnB fun(x: string): boolean + + ---@type FnA & FnB + local run + + _, a = pcall(run, 1) + "#, + ); + + assert_eq!(ws.expr_ty("a"), ws.ty("integer|string")); + } } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs index 39f09fcd3..c4b9a8a6d 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs @@ -7,7 +7,7 @@ use std::{ops::Deref, sync::Arc}; use crate::semantic::infer::infer_expr_list_types; use crate::{ DocTypeInferContext, FileId, GenericTpl, GenericTplId, LuaFunctionType, LuaGenericType, - TypeVisitTrait, + TypeVisitTrait, check_type_compact, db_index::{DbIndex, LuaType}, infer_doc_type, semantic::{ @@ -22,6 +22,7 @@ use crate::{ }, infer::InferFailReason, infer_expr, + overload_resolve::resolve_signature_by_args, }, }; use crate::{ @@ -147,6 +148,157 @@ fn infer_return_from_callable( } } +fn fallback_return_from_overloads(db: &DbIndex, overloads: &[Arc]) -> LuaType { + LuaType::from_vec( + overloads + .iter() + .map(|callable| { + let mut callable_tpls = HashSet::new(); + callable.visit_type(&mut |ty| { + if let LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) = ty { + callable_tpls.insert(generic_tpl.get_tpl_id()); + } + }); + + let mut callable_substitutor = TypeSubstitutor::new(); + callable_substitutor.add_need_infer_tpls(callable_tpls); + infer_return_from_callable(db, callable, &callable_substitutor) + }) + .collect(), + ) +} + +fn infer_callable_return_from_arg_types( + context: &mut TplContext, + callable_type: &LuaType, + call_arg_types: Option<&[LuaType]>, + fallback_on_no_match: bool, +) -> Result, InferFailReason> { + let overloads = match callable_type { + LuaType::Ref(type_id) | LuaType::Def(type_id) => { + if let Some(origin_type) = context + .db + .get_type_index() + .get_type_decl(type_id) + .ok_or(InferFailReason::None)? + .get_alias_origin(context.db, None) + { + return infer_callable_return_from_arg_types( + context, + &origin_type, + call_arg_types, + fallback_on_no_match, + ); + } + return Ok(None); + } + LuaType::Generic(generic) => { + let substitutor = TypeSubstitutor::from_type_array(generic.get_params().to_vec()); + if let Some(origin_type) = context + .db + .get_type_index() + .get_type_decl(&generic.get_base_type_id()) + .ok_or(InferFailReason::None)? + .get_alias_origin(context.db, Some(&substitutor)) + { + return infer_callable_return_from_arg_types( + context, + &origin_type, + call_arg_types, + fallback_on_no_match, + ); + } + return Ok(None); + } + LuaType::Union(union) => { + let mut member_returns = Vec::new(); + for member in union.into_vec() { + if let Some(member_return) = + infer_callable_return_from_arg_types(context, &member, call_arg_types, false)? + { + member_returns.push(member_return); + } + } + return Ok(if member_returns.is_empty() { + None + } else { + Some(LuaType::from_vec(member_returns)) + }); + } + LuaType::Intersection(intersection) => { + let mut member_returns = Vec::new(); + for member in intersection.get_types() { + if let Some(member_return) = + infer_callable_return_from_arg_types(context, member, call_arg_types, false)? + { + member_returns.push(member_return); + } + } + return Ok(if member_returns.is_empty() { + None + } else { + Some(LuaType::from_vec(member_returns)) + }); + } + LuaType::DocFunction(doc_func) => vec![doc_func.clone()], + LuaType::Signature(sig_id) => { + let signature = context + .db + .get_signature_index() + .get(sig_id) + .ok_or(InferFailReason::None)?; + let mut overloads = signature.overloads.iter().cloned().collect::>(); + overloads.push(signature.to_doc_func_type()); + overloads + } + _ => return Ok(None), + }; + + let Some(call_arg_types) = call_arg_types else { + return Ok(Some(fallback_return_from_overloads(context.db, &overloads))); + }; + + let instantiated_overloads = overloads + .iter() + .filter_map(|callable| { + instantiate_callable_from_arg_types(context, callable, call_arg_types) + }) + .collect::>(); + if instantiated_overloads.is_empty() { + return Ok(if fallback_on_no_match { + Some(fallback_return_from_overloads(context.db, &overloads)) + } else { + None + }); + } + + // Prefer candidates that produce an informative instantiated return over ones that only + // narrow to `unknown`. This keeps explicit `fun(...): unknown` callbacks aligned with + // unresolved closure returns while still preserving parameter-shape matching. + let preferred_overloads = instantiated_overloads + .iter() + .filter(|callable| !callable.get_ret().is_unknown()) + .cloned() + .collect::>(); + let overloads_to_resolve = if preferred_overloads.is_empty() { + &instantiated_overloads + } else { + &preferred_overloads + }; + + Ok(Some( + resolve_signature_by_args( + context.db, + overloads_to_resolve, + call_arg_types, + false, + None, + ) + .map(|callable| callable.get_ret().clone()) + .unwrap_or_else(|_| fallback_return_from_overloads(context.db, &overloads)), + )) +} + pub fn infer_callable_return_from_remaining_args( context: &mut TplContext, callable_type: &LuaType, @@ -156,9 +308,76 @@ pub fn infer_callable_return_from_remaining_args( return Ok(None); } - let Some(callable) = as_doc_function_type(context.db, callable_type)? else { + let call_arg_types = + match infer_expr_list_types(context.db, context.cache, arg_exprs, None, infer_expr) { + Ok(types) => types.into_iter().map(|(ty, _)| ty).collect::>(), + Err(_) => { + return infer_callable_return_from_arg_types(context, callable_type, None, true); + } + }; + if call_arg_types.is_empty() { return Ok(None); - }; + } + + infer_callable_return_from_arg_types(context, callable_type, Some(&call_arg_types), true) +} + +fn callable_matches_arg_types( + db: &DbIndex, + callable: &LuaFunctionType, + call_arg_types: &[LuaType], +) -> bool { + let params = callable.get_params(); + let param_len = params.len(); + if param_len < call_arg_types.len() && !callable_last_param_is_variadic(callable) { + return false; + } + + for (arg_index, call_arg_type) in call_arg_types.iter().enumerate() { + let mut param_index = arg_index; + if callable.is_colon_define() { + if param_index == 0 { + continue; + } + param_index -= 1; + } + + let param_type = if let Some((_, ty)) = params.get(param_index) { + ty.clone().unwrap_or(LuaType::Any) + } else if let Some((name, ty)) = params.last() { + if name == "..." { + ty.clone().unwrap_or(LuaType::Any) + } else { + return false; + } + } else { + return false; + }; + + if !param_type.is_any() && check_type_compact(db, ¶m_type, call_arg_type).is_err() { + return false; + } + } + + true +} + +fn callable_last_param_is_variadic(callable: &LuaFunctionType) -> bool { + if let Some((name, _)) = callable.get_params().last() { + name == "..." + } else { + false + } +} + +fn instantiate_callable_from_arg_types( + context: &mut TplContext, + callable: &Arc, + call_arg_types: &[LuaType], +) -> Option> { + if !callable_matches_arg_types(context.db, callable, call_arg_types) { + return None; + } let mut callable_tpls = HashSet::new(); callable.visit_type(&mut |ty| { @@ -167,20 +386,7 @@ pub fn infer_callable_return_from_remaining_args( } }); if callable_tpls.is_empty() { - return Ok(Some(callable.get_ret().clone())); - } - - let mut callable_substitutor = TypeSubstitutor::new(); - callable_substitutor.add_need_infer_tpls(callable_tpls); - let fallback_return = infer_return_from_callable(context.db, &callable, &callable_substitutor); - - let call_arg_types = - match infer_expr_list_types(context.db, context.cache, arg_exprs, None, infer_expr) { - Ok(types) => types.into_iter().map(|(ty, _)| ty).collect::>(), - Err(_) => return Ok(Some(fallback_return)), - }; - if call_arg_types.is_empty() { - return Ok(None); + return Some(callable.clone()); } let callable_param_types = callable @@ -188,28 +394,36 @@ pub fn infer_callable_return_from_remaining_args( .iter() .map(|(_, ty)| ty.clone().unwrap_or(LuaType::Unknown)) .collect::>(); - + let mut callable_substitutor = TypeSubstitutor::new(); + callable_substitutor.add_need_infer_tpls(callable_tpls.clone()); let mut callable_context = TplContext { db: context.db, cache: context.cache, substitutor: &mut callable_substitutor, call_expr: context.call_expr.clone(), }; - if tpl_pattern_match_args( - &mut callable_context, - &callable_param_types, - &call_arg_types, - ) - .is_err() + if tpl_pattern_match_args(&mut callable_context, &callable_param_types, call_arg_types).is_err() { - return Ok(Some(fallback_return)); + return None; + } + + let instantiated = match instantiate_doc_function(context.db, callable, &callable_substitutor) { + LuaType::DocFunction(func) => func, + _ => callable.clone(), + }; + let mut has_unresolved_return_tpl = false; + instantiated.get_ret().visit_type(&mut |ty| { + if let LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) = ty + && callable_tpls.contains(&generic_tpl.get_tpl_id()) + { + has_unresolved_return_tpl = true; + } + }); + if has_unresolved_return_tpl { + return None; } - Ok(Some(infer_return_from_callable( - context.db, - &callable, - &callable_substitutor, - ))) + Some(instantiated) } fn infer_generic_types_from_call( diff --git a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/lambda_tpl_pattern.rs b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/lambda_tpl_pattern.rs index 4b666a0a3..7b21b34cb 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/lambda_tpl_pattern.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/lambda_tpl_pattern.rs @@ -1,7 +1,5 @@ -use emmylua_parser::LuaAstNode; - use crate::{ - InferFailReason, LuaFunctionType, LuaSignatureId, TplContext, + InferFailReason, LuaFunctionType, LuaSignatureId, LuaType, TplContext, infer_expr, semantic::generic::tpl_pattern::TplPatternMatchResult, }; @@ -12,14 +10,14 @@ pub fn check_lambda_tpl_pattern( ) -> TplPatternMatchResult { let call_expr = context.call_expr.clone().ok_or(InferFailReason::None)?; let call_arg_list = call_expr.get_args_list().ok_or(InferFailReason::None)?; - let closure_position = signature_id.get_position(); - let closure_expr = call_arg_list - .get_args() - .find(|arg| arg.get_position() == closure_position); - - if closure_expr.is_none() { - return Err(InferFailReason::UnResolveSignatureReturn(signature_id)); + for arg in call_arg_list.get_args() { + if let Ok(LuaType::Signature(arg_signature_id)) = + infer_expr(context.db, context.cache, arg.clone()) + && arg_signature_id == signature_id + { + return Ok(()); + } } - Ok(()) + Err(InferFailReason::UnResolveSignatureReturn(signature_id)) } diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs index 3fd266d29..b2218b6b5 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs @@ -841,4 +841,19 @@ mod tests { assert_eq!(ws.expr_ty("ok"), ws.ty("boolean")); assert_eq!(ws.expr_ty("payload"), ws.ty("string")); } + + #[test] + fn test_union_call_ignores_unresolved_alias_member() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@type MissingAlias | fun(): integer + local run + + result = run() + "#, + ); + + assert_eq!(ws.expr_ty("result"), ws.ty("integer")); + } } diff --git a/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs b/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs index bb8344e41..0b1b69dc4 100644 --- a/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs @@ -16,7 +16,7 @@ use super::{ infer::{InferCallFuncResult, InferFailReason}, }; -use resolve_signature_by_args::resolve_signature_by_args; +pub(crate) use resolve_signature_by_args::resolve_signature_by_args; pub fn resolve_signature( db: &DbIndex, From 38d9008a63d1713d13e836f75223a2125c5f5e80 Mon Sep 17 00:00:00 2001 From: Lewis Russell Date: Thu, 2 Apr 2026 14:54:46 +0100 Subject: [PATCH 2/5] refactor(generic): share callable overload collection --- .../instantiate_func_generic.rs | 230 +++++------------- .../semantic/generic/instantiate_type/mod.rs | 58 ++++- .../src/semantic/generic/mod.rs | 1 + .../src/semantic/infer/infer_call/mod.rs | 146 ++++------- .../src/semantic/overload_resolve/mod.rs | 2 +- .../resolve_signature_by_args.rs | 149 +++++++----- 6 files changed, 258 insertions(+), 328 deletions(-) diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs index c4b9a8a6d..2d9544242 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs @@ -7,7 +7,7 @@ use std::{ops::Deref, sync::Arc}; use crate::semantic::infer::infer_expr_list_types; use crate::{ DocTypeInferContext, FileId, GenericTpl, GenericTplId, LuaFunctionType, LuaGenericType, - TypeVisitTrait, check_type_compact, + TypeVisitTrait, db_index::{DbIndex, LuaType}, infer_doc_type, semantic::{ @@ -22,7 +22,7 @@ use crate::{ }, infer::InferFailReason, infer_expr, - overload_resolve::resolve_signature_by_args, + overload_resolve::{callable_accepts_args, resolve_signature_by_args}, }, }; use crate::{ @@ -30,7 +30,7 @@ use crate::{ tpl_pattern_match_args, }; -use super::TypeSubstitutor; +use super::{TypeSubstitutor, collect_callable_overload_groups}; pub fn instantiate_func_generic( db: &DbIndex, @@ -174,129 +174,77 @@ fn infer_callable_return_from_arg_types( call_arg_types: Option<&[LuaType]>, fallback_on_no_match: bool, ) -> Result, InferFailReason> { - let overloads = match callable_type { - LuaType::Ref(type_id) | LuaType::Def(type_id) => { - if let Some(origin_type) = context - .db - .get_type_index() - .get_type_decl(type_id) - .ok_or(InferFailReason::None)? - .get_alias_origin(context.db, None) - { - return infer_callable_return_from_arg_types( - context, - &origin_type, - call_arg_types, - fallback_on_no_match, - ); - } - return Ok(None); - } - LuaType::Generic(generic) => { - let substitutor = TypeSubstitutor::from_type_array(generic.get_params().to_vec()); - if let Some(origin_type) = context - .db - .get_type_index() - .get_type_decl(&generic.get_base_type_id()) - .ok_or(InferFailReason::None)? - .get_alias_origin(context.db, Some(&substitutor)) - { - return infer_callable_return_from_arg_types( - context, - &origin_type, - call_arg_types, - fallback_on_no_match, - ); - } - return Ok(None); - } - LuaType::Union(union) => { - let mut member_returns = Vec::new(); - for member in union.into_vec() { - if let Some(member_return) = - infer_callable_return_from_arg_types(context, &member, call_arg_types, false)? - { - member_returns.push(member_return); - } - } - return Ok(if member_returns.is_empty() { - None - } else { - Some(LuaType::from_vec(member_returns)) - }); - } - LuaType::Intersection(intersection) => { - let mut member_returns = Vec::new(); - for member in intersection.get_types() { - if let Some(member_return) = - infer_callable_return_from_arg_types(context, member, call_arg_types, false)? - { - member_returns.push(member_return); - } - } - return Ok(if member_returns.is_empty() { - None - } else { - Some(LuaType::from_vec(member_returns)) - }); - } - LuaType::DocFunction(doc_func) => vec![doc_func.clone()], - LuaType::Signature(sig_id) => { - let signature = context - .db - .get_signature_index() - .get(sig_id) - .ok_or(InferFailReason::None)?; - let mut overloads = signature.overloads.iter().cloned().collect::>(); - overloads.push(signature.to_doc_func_type()); - overloads - } - _ => return Ok(None), - }; + let mut overload_groups = Vec::new(); + collect_callable_overload_groups(context.db, callable_type, &mut overload_groups)?; + if overload_groups.is_empty() { + return Ok(None); + } let Some(call_arg_types) = call_arg_types else { - return Ok(Some(fallback_return_from_overloads(context.db, &overloads))); + return Ok(Some(fallback_return_from_overloads( + context.db, + &overload_groups + .iter() + .flatten() + .cloned() + .collect::>(), + ))); }; - let instantiated_overloads = overloads - .iter() - .filter_map(|callable| { - instantiate_callable_from_arg_types(context, callable, call_arg_types) - }) - .collect::>(); - if instantiated_overloads.is_empty() { + let mut member_returns = Vec::new(); + for overloads in &overload_groups { + let instantiated_overloads = overloads + .iter() + .filter_map(|callable| { + instantiate_callable_from_arg_types(context, callable, call_arg_types) + }) + .collect::>(); + if instantiated_overloads.is_empty() { + continue; + } + + // Prefer candidates that produce an informative instantiated return over ones that only + // narrow to `unknown`. This keeps explicit `fun(...): unknown` callbacks aligned with + // unresolved closure returns while still preserving parameter-shape matching. + let preferred_overloads = instantiated_overloads + .iter() + .filter(|callable| !callable.get_ret().is_unknown()) + .cloned() + .collect::>(); + let overloads_to_resolve = if preferred_overloads.is_empty() { + &instantiated_overloads + } else { + &preferred_overloads + }; + + member_returns.push( + resolve_signature_by_args( + context.db, + overloads_to_resolve, + call_arg_types, + false, + None, + ) + .map(|callable| callable.get_ret().clone()) + .unwrap_or_else(|_| fallback_return_from_overloads(context.db, overloads)), + ); + } + if member_returns.is_empty() { return Ok(if fallback_on_no_match { - Some(fallback_return_from_overloads(context.db, &overloads)) + Some(fallback_return_from_overloads( + context.db, + &overload_groups + .iter() + .flatten() + .cloned() + .collect::>(), + )) } else { None }); } - // Prefer candidates that produce an informative instantiated return over ones that only - // narrow to `unknown`. This keeps explicit `fun(...): unknown` callbacks aligned with - // unresolved closure returns while still preserving parameter-shape matching. - let preferred_overloads = instantiated_overloads - .iter() - .filter(|callable| !callable.get_ret().is_unknown()) - .cloned() - .collect::>(); - let overloads_to_resolve = if preferred_overloads.is_empty() { - &instantiated_overloads - } else { - &preferred_overloads - }; - - Ok(Some( - resolve_signature_by_args( - context.db, - overloads_to_resolve, - call_arg_types, - false, - None, - ) - .map(|callable| callable.get_ret().clone()) - .unwrap_or_else(|_| fallback_return_from_overloads(context.db, &overloads)), - )) + Ok(Some(LuaType::from_vec(member_returns))) } pub fn infer_callable_return_from_remaining_args( @@ -322,60 +270,12 @@ pub fn infer_callable_return_from_remaining_args( infer_callable_return_from_arg_types(context, callable_type, Some(&call_arg_types), true) } -fn callable_matches_arg_types( - db: &DbIndex, - callable: &LuaFunctionType, - call_arg_types: &[LuaType], -) -> bool { - let params = callable.get_params(); - let param_len = params.len(); - if param_len < call_arg_types.len() && !callable_last_param_is_variadic(callable) { - return false; - } - - for (arg_index, call_arg_type) in call_arg_types.iter().enumerate() { - let mut param_index = arg_index; - if callable.is_colon_define() { - if param_index == 0 { - continue; - } - param_index -= 1; - } - - let param_type = if let Some((_, ty)) = params.get(param_index) { - ty.clone().unwrap_or(LuaType::Any) - } else if let Some((name, ty)) = params.last() { - if name == "..." { - ty.clone().unwrap_or(LuaType::Any) - } else { - return false; - } - } else { - return false; - }; - - if !param_type.is_any() && check_type_compact(db, ¶m_type, call_arg_type).is_err() { - return false; - } - } - - true -} - -fn callable_last_param_is_variadic(callable: &LuaFunctionType) -> bool { - if let Some((name, _)) = callable.get_params().last() { - name == "..." - } else { - false - } -} - fn instantiate_callable_from_arg_types( context: &mut TplContext, callable: &Arc, call_arg_types: &[LuaType], ) -> Option> { - if !callable_matches_arg_types(context.db, callable, call_arg_types) { + if !callable_accepts_args(context.db, callable, call_arg_types, false, None) { return None; } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs index 9b1b1e2fe..e4a7e8f84 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs @@ -2,7 +2,7 @@ mod instantiate_func_generic; mod instantiate_special_generic; use hashbrown::{HashMap, HashSet}; -use std::ops::Deref; +use std::{ops::Deref, sync::Arc}; use crate::{ DbIndex, GenericTpl, GenericTplId, LuaAliasCallKind, LuaArrayType, LuaConditionalType, @@ -12,7 +12,10 @@ use crate::{ LuaFunctionType, LuaGenericType, LuaIntersectionType, LuaObjectType, LuaTupleType, LuaType, LuaUnionType, VariadicType, }, - semantic::type_check::{TypeCheckCheckLevel, check_type_compact_with_level}, + semantic::{ + infer::InferFailReason, + type_check::{TypeCheckCheckLevel, check_type_compact_with_level}, + }, }; use super::type_substitutor::{SubstitutorValue, TypeSubstitutor}; @@ -22,6 +25,57 @@ pub use instantiate_func_generic::{build_self_type, infer_self_type, instantiate pub use instantiate_special_generic::get_keyof_members; pub use instantiate_special_generic::instantiate_alias_call; +pub(crate) fn collect_callable_overload_groups( + db: &DbIndex, + callable_type: &LuaType, + groups: &mut Vec>>, +) -> Result<(), InferFailReason> { + match callable_type { + LuaType::Ref(type_id) | LuaType::Def(type_id) => { + let Some(type_decl) = db.get_type_index().get_type_decl(type_id) else { + return Ok(()); + }; + if let Some(origin_type) = type_decl.get_alias_origin(db, None) { + collect_callable_overload_groups(db, &origin_type, groups)?; + } + } + LuaType::Generic(generic) => { + let substitutor = TypeSubstitutor::from_type_array(generic.get_params().to_vec()); + let Some(type_decl) = db + .get_type_index() + .get_type_decl(&generic.get_base_type_id()) + else { + return Ok(()); + }; + if let Some(origin_type) = type_decl.get_alias_origin(db, Some(&substitutor)) { + collect_callable_overload_groups(db, &origin_type, groups)?; + } + } + LuaType::Union(union) => { + for member in union.into_vec() { + collect_callable_overload_groups(db, &member, groups)?; + } + } + LuaType::Intersection(intersection) => { + for member in intersection.get_types() { + collect_callable_overload_groups(db, member, groups)?; + } + } + LuaType::DocFunction(doc_func) => groups.push(vec![doc_func.clone()]), + LuaType::Signature(sig_id) => { + let Some(signature) = db.get_signature_index().get(sig_id) else { + return Ok(()); + }; + let mut overloads = signature.overloads.to_vec(); + overloads.push(signature.to_doc_func_type()); + groups.push(overloads); + } + _ => {} + } + + Ok(()) +} + pub fn instantiate_type_generic( db: &DbIndex, ty: &LuaType, diff --git a/crates/emmylua_code_analysis/src/semantic/generic/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/mod.rs index d66ae97bc..f41c898bd 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/mod.rs @@ -10,6 +10,7 @@ pub use call_constraint::{ }; use emmylua_parser::LuaAstNode; use emmylua_parser::LuaExpr; +pub(crate) use instantiate_type::collect_callable_overload_groups; pub use instantiate_type::*; use rowan::NodeOrToken; pub use tpl_context::TplContext; diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs index b2218b6b5..c820890ca 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs @@ -15,7 +15,10 @@ use crate::{ use crate::{ InferGuardRef, semantic::{ - generic::{TypeSubstitutor, get_tpl_ref_extend_type, instantiate_doc_function}, + generic::{ + TypeSubstitutor, collect_callable_overload_groups, get_tpl_ref_extend_type, + instantiate_doc_function, + }, infer::narrow::get_type_at_call_expr_inline_cast, }, }; @@ -198,14 +201,10 @@ fn infer_generic_doc_function_union( call_expr: LuaCallExpr, args_count: Option, ) -> InferCallFuncResult { - let overloads = union - .into_vec() - .iter() - .filter_map(|typ| match typ { - LuaType::DocFunction(f) => Some(f.clone()), - _ => None, - }) - .collect::>(); + let mut overloads = Vec::new(); + for typ in union.into_vec() { + overloads.extend(collect_callable_overloads(db, &typ)?); + } resolve_signature(db, cache, overloads, call_expr.clone(), false, args_count) } @@ -225,40 +224,16 @@ fn infer_signature_doc_function( return Err(InferFailReason::UnResolveSignatureReturn(signature_id)); } let is_generic = signature_is_generic(db, cache, &signature, &call_expr).unwrap_or(false); - let overloads = &signature.overloads; - if overloads.is_empty() { - let mut fake_doc_function = LuaFunctionType::new( - signature.async_state, - signature.is_colon_define, - signature.is_vararg, - signature.get_type_params(), - signature.get_return_type(), - ); - if is_generic { - fake_doc_function = instantiate_func_generic(db, cache, &fake_doc_function, call_expr)?; - } + let overloads = collect_callable_overloads(db, &LuaType::Signature(signature_id))?; - Ok(fake_doc_function.into()) - } else { - let mut new_overloads = signature.overloads.clone(); - let fake_doc_function = Arc::new(LuaFunctionType::new( - signature.async_state, - signature.is_colon_define, - signature.is_vararg, - signature.get_type_params(), - signature.get_return_type(), - )); - new_overloads.push(fake_doc_function); - - resolve_signature( - db, - cache, - new_overloads, - call_expr.clone(), - is_generic, - args_count, - ) - } + resolve_signature( + db, + cache, + overloads, + call_expr.clone(), + is_generic, + args_count, + ) } fn infer_type_doc_function( @@ -486,73 +461,38 @@ fn infer_union( args_count: Option, ) -> InferCallFuncResult { // 此时一般是 signature + doc_function 的联合体 - let mut all_overloads = Vec::new(); - let mut base_signatures = Vec::new(); - + let mut overload_groups = Vec::new(); for ty in union.into_vec() { - match ty { - LuaType::Signature(signature_id) => { - if let Some(signature) = db.get_signature_index().get(&signature_id) { - // 处理 overloads - let overloads = if signature.is_generic() { - signature - .overloads - .iter() - .map(|func| { - Ok(Arc::new(instantiate_func_generic( - db, - cache, - func, - call_expr.clone(), - )?)) - }) - .collect::, _>>()? - } else { - signature.overloads.clone() - }; - all_overloads.extend(overloads); - - // 处理 signature 本身的函数类型 - let mut fake_doc_function = LuaFunctionType::new( - signature.async_state, - signature.is_colon_define, - signature.is_vararg, - signature.get_type_params(), - signature.get_return_type(), - ); - if signature.is_generic() { - fake_doc_function = instantiate_func_generic( - db, - cache, - &fake_doc_function, - call_expr.clone(), - )?; - } - base_signatures.push(Arc::new(fake_doc_function)); - } - } - LuaType::DocFunction(func) => { - let func_to_push = if func.contain_tpl() { - Arc::new(instantiate_func_generic( - db, - cache, - &func, - call_expr.clone(), - )?) - } else { - func.clone() - }; - base_signatures.push(func_to_push); - } - _ => {} + collect_callable_overload_groups(db, &ty, &mut overload_groups)?; + } + + let mut overloads = Vec::new(); + let mut base_signatures = Vec::new(); + for mut group in overload_groups { + if group.len() > 1 { + let base = group.pop().expect("signature group contains base function"); + overloads.extend(group); + base_signatures.push(base); + } else { + base_signatures.extend(group); } } - all_overloads.extend(base_signatures); - if all_overloads.is_empty() { + overloads.extend(base_signatures); + if overloads.is_empty() { return Err(InferFailReason::None); } - resolve_signature(db, cache, all_overloads, call_expr, false, args_count) + let is_generic = overloads.iter().any(|func| func.contain_tpl()); + resolve_signature(db, cache, overloads, call_expr, is_generic, args_count) +} + +fn collect_callable_overloads( + db: &DbIndex, + callable_type: &LuaType, +) -> Result>, InferFailReason> { + let mut overload_groups = Vec::new(); + collect_callable_overload_groups(db, callable_type, &mut overload_groups)?; + Ok(overload_groups.into_iter().flatten().collect()) } fn infer_intersection( diff --git a/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs b/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs index 0b1b69dc4..11d3d3595 100644 --- a/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs @@ -16,7 +16,7 @@ use super::{ infer::{InferCallFuncResult, InferFailReason}, }; -pub(crate) use resolve_signature_by_args::resolve_signature_by_args; +pub(crate) use resolve_signature_by_args::{callable_accepts_args, resolve_signature_by_args}; pub fn resolve_signature( db: &DbIndex, diff --git a/crates/emmylua_code_analysis/src/semantic/overload_resolve/resolve_signature_by_args.rs b/crates/emmylua_code_analysis/src/semantic/overload_resolve/resolve_signature_by_args.rs index 00a26369e..2fda0c1a4 100644 --- a/crates/emmylua_code_analysis/src/semantic/overload_resolve/resolve_signature_by_args.rs +++ b/crates/emmylua_code_analysis/src/semantic/overload_resolve/resolve_signature_by_args.rs @@ -6,6 +6,33 @@ use crate::{ semantic::infer::InferCallFuncResult, }; +pub(crate) fn callable_accepts_args( + db: &DbIndex, + func: &LuaFunctionType, + expr_types: &[LuaType], + is_colon_call: bool, + arg_count: Option, +) -> bool { + let arg_count = arg_count.unwrap_or(expr_types.len()); + if func.get_params().len() < arg_count && !is_func_last_param_variadic(func) { + return false; + } + + for (arg_index, expr_type) in expr_types.iter().enumerate() { + let param_type = match get_call_arg_param(func, arg_index, is_colon_call) { + CallArgParam::Skip => continue, + CallArgParam::Present { param_type, .. } => param_type, + CallArgParam::Missing => return false, + }; + + if !param_type.is_any() && check_type_compact(db, ¶m_type, expr_type).is_err() { + return false; + } + } + + true +} + pub fn resolve_signature_by_args( db: &DbIndex, overloads: &[Arc], @@ -43,41 +70,22 @@ pub fn resolve_signature_by_args( None => continue, Some(func) => func, }; - let param_len = func.get_params().len(); - if param_len < arg_count && !is_func_last_param_variadic(func) { + if func.get_params().len() < arg_count && !is_func_last_param_variadic(func) { *opt_func = None; continue; } - let colon_define = func.is_colon_define(); - let mut param_index = arg_index; - match (colon_define, is_colon_call) { - (true, false) => { - if param_index == 0 { - continue; - } - param_index -= 1; - } - (false, true) => { - param_index += 1; - } - _ => {} - } - let param_type = if param_index < param_len { - let param_info = func.get_params().get(param_index); - param_info - .map(|it| it.1.clone().unwrap_or(LuaType::Any)) - .unwrap_or(LuaType::Any) - } else if let Some(last_param_info) = func.get_params().last() { - if last_param_info.0 == "..." { - last_param_info.1.clone().unwrap_or(LuaType::Any) - } else { + let (param_index, param_type) = match get_call_arg_param(func, arg_index, is_colon_call) + { + CallArgParam::Skip => continue, + CallArgParam::Present { + param_index, + param_type, + } => (param_index, param_type), + CallArgParam::Missing => { *opt_func = None; continue; } - } else { - *opt_func = None; - continue; }; let match_result = if param_type.is_any() { @@ -144,35 +152,15 @@ pub fn resolve_signature_by_args( None => continue, Some(func) => func, }; - let param_len = func.get_params().len(); - let colon_define = func.is_colon_define(); - let mut param_index = param_index; - match (colon_define, is_colon_call) { - (true, false) => { - if param_index == 0 { - continue; - } - param_index -= 1; - } - (false, true) => { - param_index += 1; - } - _ => {} - } - let param_type = if param_index < param_len { - let param_info = func.get_params().get(param_index); - param_info - .map(|it| it.1.clone().unwrap_or(LuaType::Any)) - .unwrap_or(LuaType::Any) - } else if let Some(last_param_info) = func.get_params().last() { - if last_param_info.0 == "..." { - last_param_info.1.clone().unwrap_or(LuaType::Any) - } else { - return Ok(func.clone()); - } - } else { - return Ok(func.clone()); - }; + let (param_index, param_type) = + match get_call_arg_param(func, param_index, is_colon_call) { + CallArgParam::Skip => continue, + CallArgParam::Present { + param_index, + param_type, + } => (param_index, param_type), + CallArgParam::Missing => return Ok(func.clone()), + }; let match_result = if param_type.is_any() { ParamMatchResult::Any @@ -208,6 +196,44 @@ pub fn resolve_signature_by_args( Ok(best_match_result) } +fn get_call_arg_param( + func: &LuaFunctionType, + arg_index: usize, + is_colon_call: bool, +) -> CallArgParam { + let mut param_index = arg_index; + match (func.is_colon_define(), is_colon_call) { + (true, false) => { + if param_index == 0 { + return CallArgParam::Skip; + } + param_index -= 1; + } + (false, true) => { + param_index += 1; + } + _ => {} + } + + if let Some((_, ty)) = func.get_params().get(param_index) { + return CallArgParam::Present { + param_index, + param_type: ty.clone().unwrap_or(LuaType::Any), + }; + } + + if let Some((name, ty)) = func.get_params().last() + && name == "..." + { + return CallArgParam::Present { + param_index, + param_type: ty.clone().unwrap_or(LuaType::Any), + }; + } + + CallArgParam::Missing +} + fn is_func_last_param_variadic(func: &LuaFunctionType) -> bool { if let Some(last_param) = func.get_params().last() { last_param.0 == "..." @@ -222,3 +248,12 @@ enum ParamMatchResult { Any, Type, } + +enum CallArgParam { + Skip, + Present { + param_index: usize, + param_type: LuaType, + }, + Missing, +} From f81dcfbdfc29a3f30ca27358a1b19528178e4687 Mon Sep 17 00:00:00 2001 From: Lewis Russell Date: Thu, 2 Apr 2026 15:48:55 +0100 Subject: [PATCH 3/5] fix(generic): guard recursive callable alias expansion --- .../semantic/generic/instantiate_type/mod.rs | 49 ++++++++++++++----- .../src/semantic/infer/infer_call/mod.rs | 16 ++++++ 2 files changed, 54 insertions(+), 11 deletions(-) diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs index e4a7e8f84..0104fad18 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs @@ -29,36 +29,63 @@ pub(crate) fn collect_callable_overload_groups( db: &DbIndex, callable_type: &LuaType, groups: &mut Vec>>, +) -> Result<(), InferFailReason> { + let mut visiting_aliases = HashSet::new(); + collect_callable_overload_groups_inner(db, callable_type, groups, &mut visiting_aliases) +} + +fn collect_callable_overload_groups_inner( + db: &DbIndex, + callable_type: &LuaType, + groups: &mut Vec>>, + visiting_aliases: &mut HashSet, ) -> Result<(), InferFailReason> { match callable_type { LuaType::Ref(type_id) | LuaType::Def(type_id) => { let Some(type_decl) = db.get_type_index().get_type_decl(type_id) else { return Ok(()); }; - if let Some(origin_type) = type_decl.get_alias_origin(db, None) { - collect_callable_overload_groups(db, &origin_type, groups)?; + if !visiting_aliases.insert(type_id.clone()) { + return Ok(()); } + + let result = if let Some(origin_type) = type_decl.get_alias_origin(db, None) { + collect_callable_overload_groups_inner(db, &origin_type, groups, visiting_aliases) + } else { + Ok(()) + }; + visiting_aliases.remove(type_id); + result?; } LuaType::Generic(generic) => { + let type_id = generic.get_base_type_id(); + if !visiting_aliases.insert(type_id.clone()) { + return Ok(()); + } let substitutor = TypeSubstitutor::from_type_array(generic.get_params().to_vec()); - let Some(type_decl) = db - .get_type_index() - .get_type_decl(&generic.get_base_type_id()) - else { + let Some(type_decl) = db.get_type_index().get_type_decl(&type_id) else { + visiting_aliases.remove(&type_id); return Ok(()); }; - if let Some(origin_type) = type_decl.get_alias_origin(db, Some(&substitutor)) { - collect_callable_overload_groups(db, &origin_type, groups)?; - } + + let result = if let Some(origin_type) = + type_decl.get_alias_origin(db, Some(&substitutor)) + { + collect_callable_overload_groups_inner(db, &origin_type, groups, visiting_aliases) + } else { + Ok(()) + }; + visiting_aliases.remove(&type_id); + result?; } LuaType::Union(union) => { for member in union.into_vec() { - collect_callable_overload_groups(db, &member, groups)?; + collect_callable_overload_groups_inner(db, &member, groups, visiting_aliases)?; } } LuaType::Intersection(intersection) => { for member in intersection.get_types() { - collect_callable_overload_groups(db, member, groups)?; + collect_callable_overload_groups_inner(db, member, groups, visiting_aliases)?; } } LuaType::DocFunction(doc_func) => groups.push(vec![doc_func.clone()]), diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs index c820890ca..bd5d0beae 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs @@ -796,4 +796,20 @@ mod tests { assert_eq!(ws.expr_ty("result"), ws.ty("integer")); } + + #[test] + fn test_union_call_breaks_recursive_alias_cycle() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias A A | fun(): integer + ---@type A + local run + + result = run() + "#, + ); + + assert_eq!(ws.expr_ty("result"), ws.ty("integer")); + } } From c058e3de32bd3892a8c160d2e56c2b95c0a2351b Mon Sep 17 00:00:00 2001 From: Lewis Russell Date: Thu, 2 Apr 2026 16:15:17 +0100 Subject: [PATCH 4/5] fix(generic): keep callable overload inference shape-aware --- .../test/callable_return_infer_test.rs | 94 +++++++++++++++++++ .../instantiate_func_generic.rs | 24 +++-- .../src/semantic/generic/mod.rs | 1 + .../src/semantic/generic/tpl_pattern/mod.rs | 20 ++++ .../src/semantic/infer/infer_call/mod.rs | 79 +++++++++++++--- 5 files changed, 199 insertions(+), 19 deletions(-) diff --git a/crates/emmylua_code_analysis/src/compilation/test/callable_return_infer_test.rs b/crates/emmylua_code_analysis/src/compilation/test/callable_return_infer_test.rs index fd60b4e62..6dd26cc9b 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/callable_return_infer_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/callable_return_infer_test.rs @@ -200,4 +200,98 @@ mod test { // unresolved and carried through a named callback value. assert_eq!(ws.expr_ty("classify_string_unresolved"), ws.ty("string")); } + + #[test] + fn test_apply_return_infer_leaves_result_unknown_when_no_callable_member_matches_arg_shape() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic A, R + ---@param f fun(x: A): R + ---@param x A + ---@return R + local function apply(f, x) + return f(x) + end + + ---@alias FnInt fun(x: integer): integer + ---@alias FnString fun(x: string): string + + ---@type FnInt | FnString + local run + + ---@type boolean + local b + + result = apply(run, b) + "#, + ); + + let result_ty = ws.expr_ty("result"); + assert_eq!(result_ty, ws.ty("unknown")); + } + + #[test] + fn test_apply_return_infer_keeps_only_arity_compatible_fallbacks() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic A, B, R + ---@param f fun(x: A, y: B): R + ---@param x A + ---@param y B + ---@return R + local function apply2(f, x, y) + return f(x, y) + end + + ---@overload fun(x: integer): integer + ---@param x integer + ---@param y string + ---@return string + local function run(x, y) end + + local source ---@type table + + result = apply2(run, 1, source.missing) + "#, + ); + + let result_ty = ws.expr_ty("result"); + assert_eq!(ws.humanize_type(result_ty), "string"); + } + + #[test] + fn test_union_call_ignores_non_matching_generic_callable_member() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@type (fun(x: T): T) | fun(x: integer): integer + local run + + result = run(1) + "#, + ); + + let result_ty = ws.expr_ty("result"); + assert_eq!(ws.humanize_type(result_ty), "integer"); + } + + #[test] + fn test_union_call_ignores_non_matching_generic_alias_member() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias GenericStr fun(x: T): T + + ---@type GenericStr | fun(x: integer): integer + local run + + result = run(1) + "#, + ); + + let result_ty = ws.expr_ty("result"); + assert_eq!(ws.humanize_type(result_ty), "integer"); + } } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs index 2d9544242..528203d92 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs @@ -27,7 +27,7 @@ use crate::{ }; use crate::{ LuaMemberOwner, LuaSemanticDeclId, SemanticDeclLevel, infer_node_semantic_decl, - tpl_pattern_match_args, + tpl_pattern_match_args_skip_unknown, }; use super::{TypeSubstitutor, collect_callable_overload_groups}; @@ -259,15 +259,22 @@ pub fn infer_callable_return_from_remaining_args( let call_arg_types = match infer_expr_list_types(context.db, context.cache, arg_exprs, None, infer_expr) { Ok(types) => types.into_iter().map(|(ty, _)| ty).collect::>(), - Err(_) => { - return infer_callable_return_from_arg_types(context, callable_type, None, true); - } + Err(_) => arg_exprs + .iter() + .map(|arg_expr| { + infer_expr(context.db, context.cache, arg_expr.clone()) + .unwrap_or(LuaType::Unknown) + }) + .collect::>(), }; if call_arg_types.is_empty() { return Ok(None); } - infer_callable_return_from_arg_types(context, callable_type, Some(&call_arg_types), true) + // Preserve any known remaining-arg shape, including arity, even when some later arguments + // collapse to `unknown`. This avoids unioning returns from overloads that are impossible + // for the current call. + infer_callable_return_from_arg_types(context, callable_type, Some(&call_arg_types), false) } fn instantiate_callable_from_arg_types( @@ -302,7 +309,12 @@ fn instantiate_callable_from_arg_types( substitutor: &mut callable_substitutor, call_expr: context.call_expr.clone(), }; - if tpl_pattern_match_args(&mut callable_context, &callable_param_types, call_arg_types).is_err() + if tpl_pattern_match_args_skip_unknown( + &mut callable_context, + &callable_param_types, + call_arg_types, + ) + .is_err() { return None; } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/mod.rs index f41c898bd..16749b061 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/mod.rs @@ -15,6 +15,7 @@ pub use instantiate_type::*; use rowan::NodeOrToken; pub use tpl_context::TplContext; pub use tpl_pattern::tpl_pattern_match_args; +pub use tpl_pattern::tpl_pattern_match_args_skip_unknown; pub use type_substitutor::TypeSubstitutor; use crate::DbIndex; diff --git a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs index 88e695057..3339a2dce 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs @@ -32,6 +32,23 @@ pub fn tpl_pattern_match_args( context: &mut TplContext, func_param_types: &[LuaType], call_arg_types: &[LuaType], +) -> TplPatternMatchResult { + tpl_pattern_match_args_inner(context, func_param_types, call_arg_types, false) +} + +pub fn tpl_pattern_match_args_skip_unknown( + context: &mut TplContext, + func_param_types: &[LuaType], + call_arg_types: &[LuaType], +) -> TplPatternMatchResult { + tpl_pattern_match_args_inner(context, func_param_types, call_arg_types, true) +} + +fn tpl_pattern_match_args_inner( + context: &mut TplContext, + func_param_types: &[LuaType], + call_arg_types: &[LuaType], + skip_unknown_tpl: bool, ) -> TplPatternMatchResult { for i in 0..func_param_types.len() { if i >= call_arg_types.len() { @@ -54,6 +71,9 @@ pub fn tpl_pattern_match_args( )?; break; } + _ if skip_unknown_tpl + && func_param_type.contain_tpl() + && (call_arg_type.is_any() || call_arg_type.is_unknown()) => {} _ => { tpl_pattern_match(context, func_param_type, call_arg_type)?; } diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs index bd5d0beae..147af1681 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs @@ -1,12 +1,14 @@ use std::sync::Arc; use emmylua_parser::{LuaAstNode, LuaCallExpr, LuaExpr, LuaSyntaxKind}; +use hashbrown::HashSet; use rowan::TextRange; use super::{ super::{InferGuard, LuaInferCache, instantiate_type_generic, resolve_signature}, InferFailReason, InferResult, }; +use crate::semantic::overload_resolve::callable_accepts_args; use crate::{ CacheEntry, DbIndex, InFiled, LuaFunctionType, LuaGenericType, LuaInstanceType, LuaIntersectionType, LuaOperatorMetaMethod, LuaOperatorOwner, LuaSignature, LuaSignatureId, @@ -102,12 +104,16 @@ pub fn infer_call_expr_func( ), LuaType::Union(union) => { // 此时我们将其视为泛型实例化联合体 - if union - .into_vec() + let union_types = union.into_vec(); + if union_types .iter() .all(|t| matches!(t, LuaType::DocFunction(_))) { - infer_generic_doc_function_union(db, cache, union, call_expr.clone(), args_count) + let mut overloads = Vec::new(); + for typ in union_types { + overloads.extend(collect_callable_overloads(db, &typ)?); + } + resolve_filtered_overloads(db, cache, overloads, call_expr.clone(), args_count) } else { infer_union(db, cache, union, call_expr.clone(), args_count) } @@ -194,19 +200,67 @@ fn infer_doc_function( Ok(func.clone().into()) } -fn infer_generic_doc_function_union( +fn filter_callable_overloads_by_call_args( db: &DbIndex, cache: &mut LuaInferCache, - union: &LuaUnionType, + overloads: Vec>, + call_expr: &LuaCallExpr, + args_count: Option, +) -> Result>, InferFailReason> { + let args = call_expr.get_args_list().ok_or(InferFailReason::None)?; + let expr_types = super::infer_expr_list_types( + db, + cache, + &args.get_args().collect::>(), + args_count, + |db, cache, expr| Ok(infer_expr(db, cache, expr).unwrap_or(LuaType::Unknown)), + )? + .into_iter() + .map(|(ty, _)| ty) + .collect::>(); + let is_colon_call = call_expr.is_colon_call(); + + Ok(overloads + .into_iter() + .filter(|func| { + let mut callable_tpls = HashSet::new(); + func.visit_type(&mut |ty| match ty { + LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) => { + callable_tpls.insert(generic_tpl.get_tpl_id()); + } + LuaType::StrTplRef(str_tpl) => { + callable_tpls.insert(str_tpl.get_tpl_id()); + } + _ => {} + }); + + if callable_tpls.is_empty() { + return true; + } + + let mut substitutor = TypeSubstitutor::new(); + substitutor.add_need_infer_tpls(callable_tpls); + let match_func = match instantiate_doc_function(db, func, &substitutor) { + LuaType::DocFunction(doc_func) => doc_func, + _ => func.clone(), + }; + + callable_accepts_args(db, &match_func, &expr_types, is_colon_call, args_count) + }) + .collect()) +} + +fn resolve_filtered_overloads( + db: &DbIndex, + cache: &mut LuaInferCache, + overloads: Vec>, call_expr: LuaCallExpr, args_count: Option, ) -> InferCallFuncResult { - let mut overloads = Vec::new(); - for typ in union.into_vec() { - overloads.extend(collect_callable_overloads(db, &typ)?); - } - - resolve_signature(db, cache, overloads, call_expr.clone(), false, args_count) + let contains_tpl = overloads.iter().any(|func| func.contain_tpl()); + let overloads = + filter_callable_overloads_by_call_args(db, cache, overloads, &call_expr, args_count)?; + resolve_signature(db, cache, overloads, call_expr, contains_tpl, args_count) } fn infer_signature_doc_function( @@ -482,8 +536,7 @@ fn infer_union( if overloads.is_empty() { return Err(InferFailReason::None); } - let is_generic = overloads.iter().any(|func| func.contain_tpl()); - resolve_signature(db, cache, overloads, call_expr, is_generic, args_count) + resolve_filtered_overloads(db, cache, overloads, call_expr, args_count) } fn collect_callable_overloads( From 3b763c63268607dd29fd9bd9880b0e049a5d1d91 Mon Sep 17 00:00:00 2001 From: Lewis Russell Date: Thu, 2 Apr 2026 17:42:51 +0100 Subject: [PATCH 5/5] fix(code-analysis): keep overload returns wide on unknown args --- .../test/callable_return_infer_test.rs | 30 +++++++++++++++++++ .../instantiate_func_generic.rs | 17 +++++++++-- 2 files changed, 44 insertions(+), 3 deletions(-) diff --git a/crates/emmylua_code_analysis/src/compilation/test/callable_return_infer_test.rs b/crates/emmylua_code_analysis/src/compilation/test/callable_return_infer_test.rs index 6dd26cc9b..2cffe25bc 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/callable_return_infer_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/callable_return_infer_test.rs @@ -261,6 +261,36 @@ mod test { assert_eq!(ws.humanize_type(result_ty), "string"); } + #[test] + fn test_apply_return_infer_keeps_same_arity_overload_returns_when_tail_is_unknown() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic A, B, R + ---@param f fun(x: A, y: B): R + ---@param x A + ---@param y B + ---@return R + local function apply2(f, x, y) + return f(x, y) + end + + ---@overload fun(x: integer, y: number): number + ---@param x integer + ---@param y string + ---@return string + local function run(x, y) end + + local source ---@type table + + result = apply2(run, 1, source.missing) + "#, + ); + + let result_ty = ws.expr_ty("result"); + assert_eq!(result_ty, ws.ty("number|string")); + } + #[test] fn test_union_call_ignores_non_matching_generic_callable_member() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs index 528203d92..a7e59fe97 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs @@ -216,8 +216,19 @@ fn infer_callable_return_from_arg_types( } else { &preferred_overloads }; + let unresolved_arg_match = overloads_to_resolve.len() > 1 + && call_arg_types + .iter() + .any(|arg_type| arg_type.is_any() || arg_type.is_unknown()); - member_returns.push( + member_returns.push(if unresolved_arg_match { + LuaType::from_vec( + overloads_to_resolve + .iter() + .map(|callable| callable.get_ret().clone()) + .collect(), + ) + } else { resolve_signature_by_args( context.db, overloads_to_resolve, @@ -226,8 +237,8 @@ fn infer_callable_return_from_arg_types( None, ) .map(|callable| callable.get_ret().clone()) - .unwrap_or_else(|_| fallback_return_from_overloads(context.db, overloads)), - ); + .unwrap_or_else(|_| fallback_return_from_overloads(context.db, overloads)) + }); } if member_returns.is_empty() { return Ok(if fallback_on_no_match {