diff --git a/crates/emmylua_code_analysis/src/compilation/test/callable_return_infer_test.rs b/crates/emmylua_code_analysis/src/compilation/test/callable_return_infer_test.rs index 5771d9442..2cffe25bc 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/callable_return_infer_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/callable_return_infer_test.rs @@ -85,4 +85,243 @@ mod test { assert_eq!(ws.expr_ty("result"), ws.ty("string")); } + + #[test] + fn test_apply_return_infer_prefers_informative_callback_overloads_and_keeps_param_shape() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic A, R + ---@param f fun(x: A): R + ---@param x A + ---@return R + local function apply(f, x) + return f(x) + end + + ---@overload fun(cb: fun(): T): T + ---@overload fun(cb: function): boolean + local function run(cb) end + + ---@overload fun(cb: fun(x: integer): T): integer + ---@overload fun(cb: fun(x: string): T): string + ---@overload fun(cb: function): boolean + local function classify(cb) end + + ---@overload fun(cb: fun(): T): { value: T } + ---@overload fun(cb: function): boolean + local function wrap(cb) end + + local source ---@type table + + ---@return integer + local function cb_concrete() + return 1 + end + + ---@type fun(): unknown + local cb_unknown + + ---@type fun(x: integer): unknown + local cb_param_unknown + + ---@type fun(x: string): unknown + local cb_param_unknown_string + + ---@param x integer + local function cb_param_unresolved(x) + return source.missing + end + + ---@param x string + local function cb_param_unresolved_string(x) + return source.missing + end + + local function cb_named_unresolved() + return source.missing + end + + run_concrete = apply(run, cb_concrete) + + run_unknown = apply(run, cb_unknown) + + run_unresolved = apply(run, function() + return source.missing + end) + + run_named_unresolved = apply(run, cb_named_unresolved) + + wrap_named_unresolved = apply(wrap, cb_named_unresolved) + + classify_unknown = apply(classify, cb_param_unknown) + + classify_unresolved = apply(classify, cb_param_unresolved) + + classify_string_unknown = apply(classify, cb_param_unknown_string) + + classify_string_unresolved = apply(classify, cb_param_unresolved_string) + "#, + ); + + // The callback return is concrete, so `T` is inferred as `integer` and the generic + // overload is more informative than the `function -> boolean` fallback. + assert_eq!(ws.expr_ty("run_concrete"), ws.ty("integer")); + + // `fun(): unknown` matches the generic parameter shape, but its return stays opaque, so + // the concrete `function -> boolean` fallback is the best available result. + assert_eq!(ws.expr_ty("run_unknown"), ws.ty("boolean")); + + // An unresolved closure return should be treated the same as an explicit `unknown` + // return, so this should resolve to the same fallback `boolean`. + assert_eq!(ws.expr_ty("run_unresolved"), ws.ty("boolean")); + + // The named-callback path should detect the unresolved closure signature too, so this + // should stay aligned with the inline unresolved callback case above. + assert_eq!(ws.expr_ty("run_named_unresolved"), ws.ty("boolean")); + + // If the instantiated return still embeds the unresolved template, the generic overload + // should be discarded and the concrete fallback should win. + assert_eq!(ws.expr_ty("wrap_named_unresolved"), ws.ty("boolean")); + + // The callback's parameter type is known, so the generic `fun(x: integer): T` overload + // should still win even though the callback return is only `unknown`. + assert_eq!(ws.expr_ty("classify_unknown"), ws.ty("integer")); + + // The callback return is unresolved, but its parameter is still `integer`, so overload + // ranking should keep using that known shape and pick the generic integer branch. + assert_eq!(ws.expr_ty("classify_unresolved"), ws.ty("integer")); + + // The callback's parameter type is `string`, so overload selection should not fall back + // to the first generic branch when the callback return is only `unknown`. + assert_eq!(ws.expr_ty("classify_string_unknown"), ws.ty("string")); + + // The same `string`-parameter branch should still win when the callback return is + // unresolved and carried through a named callback value. + assert_eq!(ws.expr_ty("classify_string_unresolved"), ws.ty("string")); + } + + #[test] + fn test_apply_return_infer_leaves_result_unknown_when_no_callable_member_matches_arg_shape() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic A, R + ---@param f fun(x: A): R + ---@param x A + ---@return R + local function apply(f, x) + return f(x) + end + + ---@alias FnInt fun(x: integer): integer + ---@alias FnString fun(x: string): string + + ---@type FnInt | FnString + local run + + ---@type boolean + local b + + result = apply(run, b) + "#, + ); + + let result_ty = ws.expr_ty("result"); + assert_eq!(result_ty, ws.ty("unknown")); + } + + #[test] + fn test_apply_return_infer_keeps_only_arity_compatible_fallbacks() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic A, B, R + ---@param f fun(x: A, y: B): R + ---@param x A + ---@param y B + ---@return R + local function apply2(f, x, y) + return f(x, y) + end + + ---@overload fun(x: integer): integer + ---@param x integer + ---@param y string + ---@return string + local function run(x, y) end + + local source ---@type table + + result = apply2(run, 1, source.missing) + "#, + ); + + let result_ty = ws.expr_ty("result"); + assert_eq!(ws.humanize_type(result_ty), "string"); + } + + #[test] + fn test_apply_return_infer_keeps_same_arity_overload_returns_when_tail_is_unknown() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic A, B, R + ---@param f fun(x: A, y: B): R + ---@param x A + ---@param y B + ---@return R + local function apply2(f, x, y) + return f(x, y) + end + + ---@overload fun(x: integer, y: number): number + ---@param x integer + ---@param y string + ---@return string + local function run(x, y) end + + local source ---@type table + + result = apply2(run, 1, source.missing) + "#, + ); + + let result_ty = ws.expr_ty("result"); + assert_eq!(result_ty, ws.ty("number|string")); + } + + #[test] + fn test_union_call_ignores_non_matching_generic_callable_member() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@type (fun(x: T): T) | fun(x: integer): integer + local run + + result = run(1) + "#, + ); + + let result_ty = ws.expr_ty("result"); + assert_eq!(ws.humanize_type(result_ty), "integer"); + } + + #[test] + fn test_union_call_ignores_non_matching_generic_alias_member() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias GenericStr fun(x: T): T + + ---@type GenericStr | fun(x: integer): integer + local run + + result = run(1) + "#, + ); + + let result_ty = ws.expr_ty("result"); + assert_eq!(ws.humanize_type(result_ty), "integer"); + } } diff --git a/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs b/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs index e1a8cadf3..0f23c3352 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs @@ -1,6 +1,6 @@ #[cfg(test)] mod test { - use crate::{DiagnosticCode, VirtualWorkspace}; + use crate::{DiagnosticCode, LuaType, VirtualWorkspace}; const STACKED_PCALL_ALIAS_GUARDS: usize = 180; @@ -199,4 +199,108 @@ mod test { assert_eq!(ws.expr_ty("success"), ws.ty("unknown")); assert_eq!(ws.expr_ty("failure"), ws.ty("string")); } + + #[test] + fn test_return_overload_unions_callable_union_members_with_same_domain() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@alias FnA fun(x: integer): integer + ---@alias FnB fun(x: integer): boolean + + ---@type FnA | FnB + local run + + _, a = pcall(run, 1) + "#, + ); + + assert_eq!( + ws.expr_ty("a"), + LuaType::from_vec(vec![LuaType::Integer, LuaType::Boolean, LuaType::String]) + ); + } + + #[test] + fn test_return_overload_callable_union_is_not_order_dependent() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@alias FnA fun(x: integer): integer + ---@alias FnB fun(x: integer): boolean + + ---@type FnB | FnA + local run + + _, a = pcall(run, 1) + "#, + ); + + assert_eq!( + ws.expr_ty("a"), + LuaType::from_vec(vec![LuaType::Integer, LuaType::Boolean, LuaType::String]) + ); + } + + #[test] + fn test_return_overload_unions_callable_intersection_members_with_same_domain() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@alias FnA fun(x: integer): integer + ---@alias FnB fun(x: integer): boolean + + ---@type FnA & FnB + local run + + _, a = pcall(run, 1) + "#, + ); + + assert_eq!( + ws.expr_ty("a"), + LuaType::from_vec(vec![LuaType::Integer, LuaType::Boolean, LuaType::String]) + ); + } + + #[test] + fn test_return_overload_narrows_callable_union_member_by_arg_shape() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@alias FnA fun(x: integer): integer + ---@alias FnB fun(x: string): integer + + ---@type FnA | FnB + local run + + _, a = pcall(run, 1) + "#, + ); + + assert_eq!(ws.expr_ty("a"), ws.ty("integer|string")); + } + + #[test] + fn test_return_overload_narrows_callable_intersection_member_by_arg_shape() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@alias FnA fun(x: integer): integer + ---@alias FnB fun(x: string): boolean + + ---@type FnA & FnB + local run + + _, a = pcall(run, 1) + "#, + ); + + assert_eq!(ws.expr_ty("a"), ws.ty("integer|string")); + } } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs index 39f09fcd3..a7e59fe97 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs @@ -22,14 +22,15 @@ use crate::{ }, infer::InferFailReason, infer_expr, + overload_resolve::{callable_accepts_args, resolve_signature_by_args}, }, }; use crate::{ LuaMemberOwner, LuaSemanticDeclId, SemanticDeclLevel, infer_node_semantic_decl, - tpl_pattern_match_args, + tpl_pattern_match_args_skip_unknown, }; -use super::TypeSubstitutor; +use super::{TypeSubstitutor, collect_callable_overload_groups}; pub fn instantiate_func_generic( db: &DbIndex, @@ -147,69 +148,205 @@ fn infer_return_from_callable( } } -pub fn infer_callable_return_from_remaining_args( +fn fallback_return_from_overloads(db: &DbIndex, overloads: &[Arc]) -> LuaType { + LuaType::from_vec( + overloads + .iter() + .map(|callable| { + let mut callable_tpls = HashSet::new(); + callable.visit_type(&mut |ty| { + if let LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) = ty { + callable_tpls.insert(generic_tpl.get_tpl_id()); + } + }); + + let mut callable_substitutor = TypeSubstitutor::new(); + callable_substitutor.add_need_infer_tpls(callable_tpls); + infer_return_from_callable(db, callable, &callable_substitutor) + }) + .collect(), + ) +} + +fn infer_callable_return_from_arg_types( context: &mut TplContext, callable_type: &LuaType, - arg_exprs: &[LuaExpr], + call_arg_types: Option<&[LuaType]>, + fallback_on_no_match: bool, ) -> Result, InferFailReason> { - if arg_exprs.is_empty() { + let mut overload_groups = Vec::new(); + collect_callable_overload_groups(context.db, callable_type, &mut overload_groups)?; + if overload_groups.is_empty() { return Ok(None); } - let Some(callable) = as_doc_function_type(context.db, callable_type)? else { - return Ok(None); + let Some(call_arg_types) = call_arg_types else { + return Ok(Some(fallback_return_from_overloads( + context.db, + &overload_groups + .iter() + .flatten() + .cloned() + .collect::>(), + ))); }; - let mut callable_tpls = HashSet::new(); - callable.visit_type(&mut |ty| { - if let LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) = ty { - callable_tpls.insert(generic_tpl.get_tpl_id()); + let mut member_returns = Vec::new(); + for overloads in &overload_groups { + let instantiated_overloads = overloads + .iter() + .filter_map(|callable| { + instantiate_callable_from_arg_types(context, callable, call_arg_types) + }) + .collect::>(); + if instantiated_overloads.is_empty() { + continue; } - }); - if callable_tpls.is_empty() { - return Ok(Some(callable.get_ret().clone())); + + // Prefer candidates that produce an informative instantiated return over ones that only + // narrow to `unknown`. This keeps explicit `fun(...): unknown` callbacks aligned with + // unresolved closure returns while still preserving parameter-shape matching. + let preferred_overloads = instantiated_overloads + .iter() + .filter(|callable| !callable.get_ret().is_unknown()) + .cloned() + .collect::>(); + let overloads_to_resolve = if preferred_overloads.is_empty() { + &instantiated_overloads + } else { + &preferred_overloads + }; + let unresolved_arg_match = overloads_to_resolve.len() > 1 + && call_arg_types + .iter() + .any(|arg_type| arg_type.is_any() || arg_type.is_unknown()); + + member_returns.push(if unresolved_arg_match { + LuaType::from_vec( + overloads_to_resolve + .iter() + .map(|callable| callable.get_ret().clone()) + .collect(), + ) + } else { + resolve_signature_by_args( + context.db, + overloads_to_resolve, + call_arg_types, + false, + None, + ) + .map(|callable| callable.get_ret().clone()) + .unwrap_or_else(|_| fallback_return_from_overloads(context.db, overloads)) + }); + } + if member_returns.is_empty() { + return Ok(if fallback_on_no_match { + Some(fallback_return_from_overloads( + context.db, + &overload_groups + .iter() + .flatten() + .cloned() + .collect::>(), + )) + } else { + None + }); } - let mut callable_substitutor = TypeSubstitutor::new(); - callable_substitutor.add_need_infer_tpls(callable_tpls); - let fallback_return = infer_return_from_callable(context.db, &callable, &callable_substitutor); + Ok(Some(LuaType::from_vec(member_returns))) +} + +pub fn infer_callable_return_from_remaining_args( + context: &mut TplContext, + callable_type: &LuaType, + arg_exprs: &[LuaExpr], +) -> Result, InferFailReason> { + if arg_exprs.is_empty() { + return Ok(None); + } let call_arg_types = match infer_expr_list_types(context.db, context.cache, arg_exprs, None, infer_expr) { Ok(types) => types.into_iter().map(|(ty, _)| ty).collect::>(), - Err(_) => return Ok(Some(fallback_return)), + Err(_) => arg_exprs + .iter() + .map(|arg_expr| { + infer_expr(context.db, context.cache, arg_expr.clone()) + .unwrap_or(LuaType::Unknown) + }) + .collect::>(), }; if call_arg_types.is_empty() { return Ok(None); } + // Preserve any known remaining-arg shape, including arity, even when some later arguments + // collapse to `unknown`. This avoids unioning returns from overloads that are impossible + // for the current call. + infer_callable_return_from_arg_types(context, callable_type, Some(&call_arg_types), false) +} + +fn instantiate_callable_from_arg_types( + context: &mut TplContext, + callable: &Arc, + call_arg_types: &[LuaType], +) -> Option> { + if !callable_accepts_args(context.db, callable, call_arg_types, false, None) { + return None; + } + + let mut callable_tpls = HashSet::new(); + callable.visit_type(&mut |ty| { + if let LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) = ty { + callable_tpls.insert(generic_tpl.get_tpl_id()); + } + }); + if callable_tpls.is_empty() { + return Some(callable.clone()); + } + let callable_param_types = callable .get_params() .iter() .map(|(_, ty)| ty.clone().unwrap_or(LuaType::Unknown)) .collect::>(); - + let mut callable_substitutor = TypeSubstitutor::new(); + callable_substitutor.add_need_infer_tpls(callable_tpls.clone()); let mut callable_context = TplContext { db: context.db, cache: context.cache, substitutor: &mut callable_substitutor, call_expr: context.call_expr.clone(), }; - if tpl_pattern_match_args( + if tpl_pattern_match_args_skip_unknown( &mut callable_context, &callable_param_types, - &call_arg_types, + call_arg_types, ) .is_err() { - return Ok(Some(fallback_return)); + return None; + } + + let instantiated = match instantiate_doc_function(context.db, callable, &callable_substitutor) { + LuaType::DocFunction(func) => func, + _ => callable.clone(), + }; + let mut has_unresolved_return_tpl = false; + instantiated.get_ret().visit_type(&mut |ty| { + if let LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) = ty + && callable_tpls.contains(&generic_tpl.get_tpl_id()) + { + has_unresolved_return_tpl = true; + } + }); + if has_unresolved_return_tpl { + return None; } - Ok(Some(infer_return_from_callable( - context.db, - &callable, - &callable_substitutor, - ))) + Some(instantiated) } fn infer_generic_types_from_call( diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs index 9b1b1e2fe..0104fad18 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs @@ -2,7 +2,7 @@ mod instantiate_func_generic; mod instantiate_special_generic; use hashbrown::{HashMap, HashSet}; -use std::ops::Deref; +use std::{ops::Deref, sync::Arc}; use crate::{ DbIndex, GenericTpl, GenericTplId, LuaAliasCallKind, LuaArrayType, LuaConditionalType, @@ -12,7 +12,10 @@ use crate::{ LuaFunctionType, LuaGenericType, LuaIntersectionType, LuaObjectType, LuaTupleType, LuaType, LuaUnionType, VariadicType, }, - semantic::type_check::{TypeCheckCheckLevel, check_type_compact_with_level}, + semantic::{ + infer::InferFailReason, + type_check::{TypeCheckCheckLevel, check_type_compact_with_level}, + }, }; use super::type_substitutor::{SubstitutorValue, TypeSubstitutor}; @@ -22,6 +25,84 @@ pub use instantiate_func_generic::{build_self_type, infer_self_type, instantiate pub use instantiate_special_generic::get_keyof_members; pub use instantiate_special_generic::instantiate_alias_call; +pub(crate) fn collect_callable_overload_groups( + db: &DbIndex, + callable_type: &LuaType, + groups: &mut Vec>>, +) -> Result<(), InferFailReason> { + let mut visiting_aliases = HashSet::new(); + collect_callable_overload_groups_inner(db, callable_type, groups, &mut visiting_aliases) +} + +fn collect_callable_overload_groups_inner( + db: &DbIndex, + callable_type: &LuaType, + groups: &mut Vec>>, + visiting_aliases: &mut HashSet, +) -> Result<(), InferFailReason> { + match callable_type { + LuaType::Ref(type_id) | LuaType::Def(type_id) => { + let Some(type_decl) = db.get_type_index().get_type_decl(type_id) else { + return Ok(()); + }; + if !visiting_aliases.insert(type_id.clone()) { + return Ok(()); + } + + let result = if let Some(origin_type) = type_decl.get_alias_origin(db, None) { + collect_callable_overload_groups_inner(db, &origin_type, groups, visiting_aliases) + } else { + Ok(()) + }; + visiting_aliases.remove(type_id); + result?; + } + LuaType::Generic(generic) => { + let type_id = generic.get_base_type_id(); + if !visiting_aliases.insert(type_id.clone()) { + return Ok(()); + } + let substitutor = TypeSubstitutor::from_type_array(generic.get_params().to_vec()); + let Some(type_decl) = db.get_type_index().get_type_decl(&type_id) else { + visiting_aliases.remove(&type_id); + return Ok(()); + }; + + let result = if let Some(origin_type) = + type_decl.get_alias_origin(db, Some(&substitutor)) + { + collect_callable_overload_groups_inner(db, &origin_type, groups, visiting_aliases) + } else { + Ok(()) + }; + visiting_aliases.remove(&type_id); + result?; + } + LuaType::Union(union) => { + for member in union.into_vec() { + collect_callable_overload_groups_inner(db, &member, groups, visiting_aliases)?; + } + } + LuaType::Intersection(intersection) => { + for member in intersection.get_types() { + collect_callable_overload_groups_inner(db, member, groups, visiting_aliases)?; + } + } + LuaType::DocFunction(doc_func) => groups.push(vec![doc_func.clone()]), + LuaType::Signature(sig_id) => { + let Some(signature) = db.get_signature_index().get(sig_id) else { + return Ok(()); + }; + let mut overloads = signature.overloads.to_vec(); + overloads.push(signature.to_doc_func_type()); + groups.push(overloads); + } + _ => {} + } + + Ok(()) +} + pub fn instantiate_type_generic( db: &DbIndex, ty: &LuaType, diff --git a/crates/emmylua_code_analysis/src/semantic/generic/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/mod.rs index d66ae97bc..16749b061 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/mod.rs @@ -10,10 +10,12 @@ pub use call_constraint::{ }; use emmylua_parser::LuaAstNode; use emmylua_parser::LuaExpr; +pub(crate) use instantiate_type::collect_callable_overload_groups; pub use instantiate_type::*; use rowan::NodeOrToken; pub use tpl_context::TplContext; pub use tpl_pattern::tpl_pattern_match_args; +pub use tpl_pattern::tpl_pattern_match_args_skip_unknown; pub use type_substitutor::TypeSubstitutor; use crate::DbIndex; diff --git a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/lambda_tpl_pattern.rs b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/lambda_tpl_pattern.rs index 4b666a0a3..7b21b34cb 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/lambda_tpl_pattern.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/lambda_tpl_pattern.rs @@ -1,7 +1,5 @@ -use emmylua_parser::LuaAstNode; - use crate::{ - InferFailReason, LuaFunctionType, LuaSignatureId, TplContext, + InferFailReason, LuaFunctionType, LuaSignatureId, LuaType, TplContext, infer_expr, semantic::generic::tpl_pattern::TplPatternMatchResult, }; @@ -12,14 +10,14 @@ pub fn check_lambda_tpl_pattern( ) -> TplPatternMatchResult { let call_expr = context.call_expr.clone().ok_or(InferFailReason::None)?; let call_arg_list = call_expr.get_args_list().ok_or(InferFailReason::None)?; - let closure_position = signature_id.get_position(); - let closure_expr = call_arg_list - .get_args() - .find(|arg| arg.get_position() == closure_position); - - if closure_expr.is_none() { - return Err(InferFailReason::UnResolveSignatureReturn(signature_id)); + for arg in call_arg_list.get_args() { + if let Ok(LuaType::Signature(arg_signature_id)) = + infer_expr(context.db, context.cache, arg.clone()) + && arg_signature_id == signature_id + { + return Ok(()); + } } - Ok(()) + Err(InferFailReason::UnResolveSignatureReturn(signature_id)) } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs index 88e695057..3339a2dce 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs @@ -32,6 +32,23 @@ pub fn tpl_pattern_match_args( context: &mut TplContext, func_param_types: &[LuaType], call_arg_types: &[LuaType], +) -> TplPatternMatchResult { + tpl_pattern_match_args_inner(context, func_param_types, call_arg_types, false) +} + +pub fn tpl_pattern_match_args_skip_unknown( + context: &mut TplContext, + func_param_types: &[LuaType], + call_arg_types: &[LuaType], +) -> TplPatternMatchResult { + tpl_pattern_match_args_inner(context, func_param_types, call_arg_types, true) +} + +fn tpl_pattern_match_args_inner( + context: &mut TplContext, + func_param_types: &[LuaType], + call_arg_types: &[LuaType], + skip_unknown_tpl: bool, ) -> TplPatternMatchResult { for i in 0..func_param_types.len() { if i >= call_arg_types.len() { @@ -54,6 +71,9 @@ pub fn tpl_pattern_match_args( )?; break; } + _ if skip_unknown_tpl + && func_param_type.contain_tpl() + && (call_arg_type.is_any() || call_arg_type.is_unknown()) => {} _ => { tpl_pattern_match(context, func_param_type, call_arg_type)?; } diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs index 3fd266d29..147af1681 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs @@ -1,12 +1,14 @@ use std::sync::Arc; use emmylua_parser::{LuaAstNode, LuaCallExpr, LuaExpr, LuaSyntaxKind}; +use hashbrown::HashSet; use rowan::TextRange; use super::{ super::{InferGuard, LuaInferCache, instantiate_type_generic, resolve_signature}, InferFailReason, InferResult, }; +use crate::semantic::overload_resolve::callable_accepts_args; use crate::{ CacheEntry, DbIndex, InFiled, LuaFunctionType, LuaGenericType, LuaInstanceType, LuaIntersectionType, LuaOperatorMetaMethod, LuaOperatorOwner, LuaSignature, LuaSignatureId, @@ -15,7 +17,10 @@ use crate::{ use crate::{ InferGuardRef, semantic::{ - generic::{TypeSubstitutor, get_tpl_ref_extend_type, instantiate_doc_function}, + generic::{ + TypeSubstitutor, collect_callable_overload_groups, get_tpl_ref_extend_type, + instantiate_doc_function, + }, infer::narrow::get_type_at_call_expr_inline_cast, }, }; @@ -99,12 +104,16 @@ pub fn infer_call_expr_func( ), LuaType::Union(union) => { // 此时我们将其视为泛型实例化联合体 - if union - .into_vec() + let union_types = union.into_vec(); + if union_types .iter() .all(|t| matches!(t, LuaType::DocFunction(_))) { - infer_generic_doc_function_union(db, cache, union, call_expr.clone(), args_count) + let mut overloads = Vec::new(); + for typ in union_types { + overloads.extend(collect_callable_overloads(db, &typ)?); + } + resolve_filtered_overloads(db, cache, overloads, call_expr.clone(), args_count) } else { infer_union(db, cache, union, call_expr.clone(), args_count) } @@ -191,23 +200,67 @@ fn infer_doc_function( Ok(func.clone().into()) } -fn infer_generic_doc_function_union( +fn filter_callable_overloads_by_call_args( db: &DbIndex, cache: &mut LuaInferCache, - union: &LuaUnionType, - call_expr: LuaCallExpr, + overloads: Vec>, + call_expr: &LuaCallExpr, args_count: Option, -) -> InferCallFuncResult { - let overloads = union - .into_vec() - .iter() - .filter_map(|typ| match typ { - LuaType::DocFunction(f) => Some(f.clone()), - _ => None, +) -> Result>, InferFailReason> { + let args = call_expr.get_args_list().ok_or(InferFailReason::None)?; + let expr_types = super::infer_expr_list_types( + db, + cache, + &args.get_args().collect::>(), + args_count, + |db, cache, expr| Ok(infer_expr(db, cache, expr).unwrap_or(LuaType::Unknown)), + )? + .into_iter() + .map(|(ty, _)| ty) + .collect::>(); + let is_colon_call = call_expr.is_colon_call(); + + Ok(overloads + .into_iter() + .filter(|func| { + let mut callable_tpls = HashSet::new(); + func.visit_type(&mut |ty| match ty { + LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) => { + callable_tpls.insert(generic_tpl.get_tpl_id()); + } + LuaType::StrTplRef(str_tpl) => { + callable_tpls.insert(str_tpl.get_tpl_id()); + } + _ => {} + }); + + if callable_tpls.is_empty() { + return true; + } + + let mut substitutor = TypeSubstitutor::new(); + substitutor.add_need_infer_tpls(callable_tpls); + let match_func = match instantiate_doc_function(db, func, &substitutor) { + LuaType::DocFunction(doc_func) => doc_func, + _ => func.clone(), + }; + + callable_accepts_args(db, &match_func, &expr_types, is_colon_call, args_count) }) - .collect::>(); + .collect()) +} - resolve_signature(db, cache, overloads, call_expr.clone(), false, args_count) +fn resolve_filtered_overloads( + db: &DbIndex, + cache: &mut LuaInferCache, + overloads: Vec>, + call_expr: LuaCallExpr, + args_count: Option, +) -> InferCallFuncResult { + let contains_tpl = overloads.iter().any(|func| func.contain_tpl()); + let overloads = + filter_callable_overloads_by_call_args(db, cache, overloads, &call_expr, args_count)?; + resolve_signature(db, cache, overloads, call_expr, contains_tpl, args_count) } fn infer_signature_doc_function( @@ -225,40 +278,16 @@ fn infer_signature_doc_function( return Err(InferFailReason::UnResolveSignatureReturn(signature_id)); } let is_generic = signature_is_generic(db, cache, &signature, &call_expr).unwrap_or(false); - let overloads = &signature.overloads; - if overloads.is_empty() { - let mut fake_doc_function = LuaFunctionType::new( - signature.async_state, - signature.is_colon_define, - signature.is_vararg, - signature.get_type_params(), - signature.get_return_type(), - ); - if is_generic { - fake_doc_function = instantiate_func_generic(db, cache, &fake_doc_function, call_expr)?; - } + let overloads = collect_callable_overloads(db, &LuaType::Signature(signature_id))?; - Ok(fake_doc_function.into()) - } else { - let mut new_overloads = signature.overloads.clone(); - let fake_doc_function = Arc::new(LuaFunctionType::new( - signature.async_state, - signature.is_colon_define, - signature.is_vararg, - signature.get_type_params(), - signature.get_return_type(), - )); - new_overloads.push(fake_doc_function); - - resolve_signature( - db, - cache, - new_overloads, - call_expr.clone(), - is_generic, - args_count, - ) - } + resolve_signature( + db, + cache, + overloads, + call_expr.clone(), + is_generic, + args_count, + ) } fn infer_type_doc_function( @@ -486,73 +515,37 @@ fn infer_union( args_count: Option, ) -> InferCallFuncResult { // 此时一般是 signature + doc_function 的联合体 - let mut all_overloads = Vec::new(); - let mut base_signatures = Vec::new(); - + let mut overload_groups = Vec::new(); for ty in union.into_vec() { - match ty { - LuaType::Signature(signature_id) => { - if let Some(signature) = db.get_signature_index().get(&signature_id) { - // 处理 overloads - let overloads = if signature.is_generic() { - signature - .overloads - .iter() - .map(|func| { - Ok(Arc::new(instantiate_func_generic( - db, - cache, - func, - call_expr.clone(), - )?)) - }) - .collect::, _>>()? - } else { - signature.overloads.clone() - }; - all_overloads.extend(overloads); - - // 处理 signature 本身的函数类型 - let mut fake_doc_function = LuaFunctionType::new( - signature.async_state, - signature.is_colon_define, - signature.is_vararg, - signature.get_type_params(), - signature.get_return_type(), - ); - if signature.is_generic() { - fake_doc_function = instantiate_func_generic( - db, - cache, - &fake_doc_function, - call_expr.clone(), - )?; - } - base_signatures.push(Arc::new(fake_doc_function)); - } - } - LuaType::DocFunction(func) => { - let func_to_push = if func.contain_tpl() { - Arc::new(instantiate_func_generic( - db, - cache, - &func, - call_expr.clone(), - )?) - } else { - func.clone() - }; - base_signatures.push(func_to_push); - } - _ => {} + collect_callable_overload_groups(db, &ty, &mut overload_groups)?; + } + + let mut overloads = Vec::new(); + let mut base_signatures = Vec::new(); + for mut group in overload_groups { + if group.len() > 1 { + let base = group.pop().expect("signature group contains base function"); + overloads.extend(group); + base_signatures.push(base); + } else { + base_signatures.extend(group); } } - all_overloads.extend(base_signatures); - if all_overloads.is_empty() { + overloads.extend(base_signatures); + if overloads.is_empty() { return Err(InferFailReason::None); } - resolve_signature(db, cache, all_overloads, call_expr, false, args_count) + resolve_filtered_overloads(db, cache, overloads, call_expr, args_count) +} + +fn collect_callable_overloads( + db: &DbIndex, + callable_type: &LuaType, +) -> Result>, InferFailReason> { + let mut overload_groups = Vec::new(); + collect_callable_overload_groups(db, callable_type, &mut overload_groups)?; + Ok(overload_groups.into_iter().flatten().collect()) } fn infer_intersection( @@ -841,4 +834,35 @@ mod tests { assert_eq!(ws.expr_ty("ok"), ws.ty("boolean")); assert_eq!(ws.expr_ty("payload"), ws.ty("string")); } + + #[test] + fn test_union_call_ignores_unresolved_alias_member() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@type MissingAlias | fun(): integer + local run + + result = run() + "#, + ); + + assert_eq!(ws.expr_ty("result"), ws.ty("integer")); + } + + #[test] + fn test_union_call_breaks_recursive_alias_cycle() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias A A | fun(): integer + ---@type A + local run + + result = run() + "#, + ); + + assert_eq!(ws.expr_ty("result"), ws.ty("integer")); + } } diff --git a/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs b/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs index bb8344e41..11d3d3595 100644 --- a/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs @@ -16,7 +16,7 @@ use super::{ infer::{InferCallFuncResult, InferFailReason}, }; -use resolve_signature_by_args::resolve_signature_by_args; +pub(crate) use resolve_signature_by_args::{callable_accepts_args, resolve_signature_by_args}; pub fn resolve_signature( db: &DbIndex, diff --git a/crates/emmylua_code_analysis/src/semantic/overload_resolve/resolve_signature_by_args.rs b/crates/emmylua_code_analysis/src/semantic/overload_resolve/resolve_signature_by_args.rs index 00a26369e..2fda0c1a4 100644 --- a/crates/emmylua_code_analysis/src/semantic/overload_resolve/resolve_signature_by_args.rs +++ b/crates/emmylua_code_analysis/src/semantic/overload_resolve/resolve_signature_by_args.rs @@ -6,6 +6,33 @@ use crate::{ semantic::infer::InferCallFuncResult, }; +pub(crate) fn callable_accepts_args( + db: &DbIndex, + func: &LuaFunctionType, + expr_types: &[LuaType], + is_colon_call: bool, + arg_count: Option, +) -> bool { + let arg_count = arg_count.unwrap_or(expr_types.len()); + if func.get_params().len() < arg_count && !is_func_last_param_variadic(func) { + return false; + } + + for (arg_index, expr_type) in expr_types.iter().enumerate() { + let param_type = match get_call_arg_param(func, arg_index, is_colon_call) { + CallArgParam::Skip => continue, + CallArgParam::Present { param_type, .. } => param_type, + CallArgParam::Missing => return false, + }; + + if !param_type.is_any() && check_type_compact(db, ¶m_type, expr_type).is_err() { + return false; + } + } + + true +} + pub fn resolve_signature_by_args( db: &DbIndex, overloads: &[Arc], @@ -43,41 +70,22 @@ pub fn resolve_signature_by_args( None => continue, Some(func) => func, }; - let param_len = func.get_params().len(); - if param_len < arg_count && !is_func_last_param_variadic(func) { + if func.get_params().len() < arg_count && !is_func_last_param_variadic(func) { *opt_func = None; continue; } - let colon_define = func.is_colon_define(); - let mut param_index = arg_index; - match (colon_define, is_colon_call) { - (true, false) => { - if param_index == 0 { - continue; - } - param_index -= 1; - } - (false, true) => { - param_index += 1; - } - _ => {} - } - let param_type = if param_index < param_len { - let param_info = func.get_params().get(param_index); - param_info - .map(|it| it.1.clone().unwrap_or(LuaType::Any)) - .unwrap_or(LuaType::Any) - } else if let Some(last_param_info) = func.get_params().last() { - if last_param_info.0 == "..." { - last_param_info.1.clone().unwrap_or(LuaType::Any) - } else { + let (param_index, param_type) = match get_call_arg_param(func, arg_index, is_colon_call) + { + CallArgParam::Skip => continue, + CallArgParam::Present { + param_index, + param_type, + } => (param_index, param_type), + CallArgParam::Missing => { *opt_func = None; continue; } - } else { - *opt_func = None; - continue; }; let match_result = if param_type.is_any() { @@ -144,35 +152,15 @@ pub fn resolve_signature_by_args( None => continue, Some(func) => func, }; - let param_len = func.get_params().len(); - let colon_define = func.is_colon_define(); - let mut param_index = param_index; - match (colon_define, is_colon_call) { - (true, false) => { - if param_index == 0 { - continue; - } - param_index -= 1; - } - (false, true) => { - param_index += 1; - } - _ => {} - } - let param_type = if param_index < param_len { - let param_info = func.get_params().get(param_index); - param_info - .map(|it| it.1.clone().unwrap_or(LuaType::Any)) - .unwrap_or(LuaType::Any) - } else if let Some(last_param_info) = func.get_params().last() { - if last_param_info.0 == "..." { - last_param_info.1.clone().unwrap_or(LuaType::Any) - } else { - return Ok(func.clone()); - } - } else { - return Ok(func.clone()); - }; + let (param_index, param_type) = + match get_call_arg_param(func, param_index, is_colon_call) { + CallArgParam::Skip => continue, + CallArgParam::Present { + param_index, + param_type, + } => (param_index, param_type), + CallArgParam::Missing => return Ok(func.clone()), + }; let match_result = if param_type.is_any() { ParamMatchResult::Any @@ -208,6 +196,44 @@ pub fn resolve_signature_by_args( Ok(best_match_result) } +fn get_call_arg_param( + func: &LuaFunctionType, + arg_index: usize, + is_colon_call: bool, +) -> CallArgParam { + let mut param_index = arg_index; + match (func.is_colon_define(), is_colon_call) { + (true, false) => { + if param_index == 0 { + return CallArgParam::Skip; + } + param_index -= 1; + } + (false, true) => { + param_index += 1; + } + _ => {} + } + + if let Some((_, ty)) = func.get_params().get(param_index) { + return CallArgParam::Present { + param_index, + param_type: ty.clone().unwrap_or(LuaType::Any), + }; + } + + if let Some((name, ty)) = func.get_params().last() + && name == "..." + { + return CallArgParam::Present { + param_index, + param_type: ty.clone().unwrap_or(LuaType::Any), + }; + } + + CallArgParam::Missing +} + fn is_func_last_param_variadic(func: &LuaFunctionType) -> bool { if let Some(last_param) = func.get_params().last() { last_param.0 == "..." @@ -222,3 +248,12 @@ enum ParamMatchResult { Any, Type, } + +enum CallArgParam { + Skip, + Present { + param_index: usize, + param_type: LuaType, + }, + Missing, +}