Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,243 @@ mod test {

assert_eq!(ws.expr_ty("result"), ws.ty("string"));
}

#[test]
fn test_apply_return_infer_prefers_informative_callback_overloads_and_keeps_param_shape() {
let mut ws = VirtualWorkspace::new();
ws.def(
r#"
---@generic A, R
---@param f fun(x: A): R
---@param x A
---@return R
local function apply(f, x)
return f(x)
end

---@overload fun<T>(cb: fun(): T): T
---@overload fun(cb: function): boolean
local function run(cb) end

---@overload fun<T>(cb: fun(x: integer): T): integer
---@overload fun<T>(cb: fun(x: string): T): string
---@overload fun(cb: function): boolean
local function classify(cb) end

---@overload fun<T>(cb: fun(): T): { value: T }
---@overload fun(cb: function): boolean
local function wrap(cb) end

local source ---@type table

---@return integer
local function cb_concrete()
return 1
end

---@type fun(): unknown
local cb_unknown

---@type fun(x: integer): unknown
local cb_param_unknown

---@type fun(x: string): unknown
local cb_param_unknown_string

---@param x integer
local function cb_param_unresolved(x)
return source.missing
end

---@param x string
local function cb_param_unresolved_string(x)
return source.missing
end

local function cb_named_unresolved()
return source.missing
end

run_concrete = apply(run, cb_concrete)

run_unknown = apply(run, cb_unknown)

run_unresolved = apply(run, function()
return source.missing
end)

run_named_unresolved = apply(run, cb_named_unresolved)

wrap_named_unresolved = apply(wrap, cb_named_unresolved)

classify_unknown = apply(classify, cb_param_unknown)

classify_unresolved = apply(classify, cb_param_unresolved)

classify_string_unknown = apply(classify, cb_param_unknown_string)

classify_string_unresolved = apply(classify, cb_param_unresolved_string)
"#,
);

// The callback return is concrete, so `T` is inferred as `integer` and the generic
// overload is more informative than the `function -> boolean` fallback.
assert_eq!(ws.expr_ty("run_concrete"), ws.ty("integer"));

// `fun(): unknown` matches the generic parameter shape, but its return stays opaque, so
// the concrete `function -> boolean` fallback is the best available result.
assert_eq!(ws.expr_ty("run_unknown"), ws.ty("boolean"));

// An unresolved closure return should be treated the same as an explicit `unknown`
// return, so this should resolve to the same fallback `boolean`.
assert_eq!(ws.expr_ty("run_unresolved"), ws.ty("boolean"));

// The named-callback path should detect the unresolved closure signature too, so this
// should stay aligned with the inline unresolved callback case above.
assert_eq!(ws.expr_ty("run_named_unresolved"), ws.ty("boolean"));

// If the instantiated return still embeds the unresolved template, the generic overload
// should be discarded and the concrete fallback should win.
assert_eq!(ws.expr_ty("wrap_named_unresolved"), ws.ty("boolean"));

// The callback's parameter type is known, so the generic `fun(x: integer): T` overload
// should still win even though the callback return is only `unknown`.
assert_eq!(ws.expr_ty("classify_unknown"), ws.ty("integer"));

// The callback return is unresolved, but its parameter is still `integer`, so overload
// ranking should keep using that known shape and pick the generic integer branch.
assert_eq!(ws.expr_ty("classify_unresolved"), ws.ty("integer"));

// The callback's parameter type is `string`, so overload selection should not fall back
// to the first generic branch when the callback return is only `unknown`.
assert_eq!(ws.expr_ty("classify_string_unknown"), ws.ty("string"));

// The same `string`-parameter branch should still win when the callback return is
// unresolved and carried through a named callback value.
assert_eq!(ws.expr_ty("classify_string_unresolved"), ws.ty("string"));
}

#[test]
fn test_apply_return_infer_leaves_result_unknown_when_no_callable_member_matches_arg_shape() {
let mut ws = VirtualWorkspace::new();
ws.def(
r#"
---@generic A, R
---@param f fun(x: A): R
---@param x A
---@return R
local function apply(f, x)
return f(x)
end

---@alias FnInt fun(x: integer): integer
---@alias FnString fun(x: string): string

---@type FnInt | FnString
local run

---@type boolean
local b

result = apply(run, b)
"#,
);

let result_ty = ws.expr_ty("result");
assert_eq!(result_ty, ws.ty("unknown"));
}

#[test]
fn test_apply_return_infer_keeps_only_arity_compatible_fallbacks() {
let mut ws = VirtualWorkspace::new();
ws.def(
r#"
---@generic A, B, R
---@param f fun(x: A, y: B): R
---@param x A
---@param y B
---@return R
local function apply2(f, x, y)
return f(x, y)
end

---@overload fun(x: integer): integer
---@param x integer
---@param y string
---@return string
local function run(x, y) end

local source ---@type table

result = apply2(run, 1, source.missing)
"#,
);

let result_ty = ws.expr_ty("result");
assert_eq!(ws.humanize_type(result_ty), "string");
}

#[test]
fn test_apply_return_infer_keeps_same_arity_overload_returns_when_tail_is_unknown() {
let mut ws = VirtualWorkspace::new();
ws.def(
r#"
---@generic A, B, R
---@param f fun(x: A, y: B): R
---@param x A
---@param y B
---@return R
local function apply2(f, x, y)
return f(x, y)
end

---@overload fun(x: integer, y: number): number
---@param x integer
---@param y string
---@return string
local function run(x, y) end

local source ---@type table

result = apply2(run, 1, source.missing)
"#,
);

let result_ty = ws.expr_ty("result");
assert_eq!(result_ty, ws.ty("number|string"));
}

#[test]
fn test_union_call_ignores_non_matching_generic_callable_member() {
let mut ws = VirtualWorkspace::new();
ws.def(
r#"
---@type (fun<T: string>(x: T): T) | fun(x: integer): integer
local run

result = run(1)
"#,
);

let result_ty = ws.expr_ty("result");
assert_eq!(ws.humanize_type(result_ty), "integer");
}

#[test]
fn test_union_call_ignores_non_matching_generic_alias_member() {
let mut ws = VirtualWorkspace::new();
ws.def(
r#"
---@alias GenericStr<T: string> fun(x: T): T

---@type GenericStr | fun(x: integer): integer
local run

result = run(1)
"#,
);

let result_ty = ws.expr_ty("result");
assert_eq!(ws.humanize_type(result_ty), "integer");
}
}
106 changes: 105 additions & 1 deletion crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#[cfg(test)]
mod test {
use crate::{DiagnosticCode, VirtualWorkspace};
use crate::{DiagnosticCode, LuaType, VirtualWorkspace};

const STACKED_PCALL_ALIAS_GUARDS: usize = 180;

Expand Down Expand Up @@ -199,4 +199,108 @@ mod test {
assert_eq!(ws.expr_ty("success"), ws.ty("unknown"));
assert_eq!(ws.expr_ty("failure"), ws.ty("string"));
}

#[test]
fn test_return_overload_unions_callable_union_members_with_same_domain() {
let mut ws = VirtualWorkspace::new_with_init_std_lib();

ws.def(
r#"
---@alias FnA fun(x: integer): integer
---@alias FnB fun(x: integer): boolean

---@type FnA | FnB
local run

_, a = pcall(run, 1)
"#,
);

assert_eq!(
ws.expr_ty("a"),
LuaType::from_vec(vec![LuaType::Integer, LuaType::Boolean, LuaType::String])
);
}

#[test]
fn test_return_overload_callable_union_is_not_order_dependent() {
let mut ws = VirtualWorkspace::new_with_init_std_lib();

ws.def(
r#"
---@alias FnA fun(x: integer): integer
---@alias FnB fun(x: integer): boolean

---@type FnB | FnA
local run

_, a = pcall(run, 1)
"#,
);

assert_eq!(
ws.expr_ty("a"),
LuaType::from_vec(vec![LuaType::Integer, LuaType::Boolean, LuaType::String])
);
}

#[test]
fn test_return_overload_unions_callable_intersection_members_with_same_domain() {
let mut ws = VirtualWorkspace::new_with_init_std_lib();

ws.def(
r#"
---@alias FnA fun(x: integer): integer
---@alias FnB fun(x: integer): boolean

---@type FnA & FnB
local run

_, a = pcall(run, 1)
"#,
);

assert_eq!(
ws.expr_ty("a"),
LuaType::from_vec(vec![LuaType::Integer, LuaType::Boolean, LuaType::String])
);
}

#[test]
fn test_return_overload_narrows_callable_union_member_by_arg_shape() {
let mut ws = VirtualWorkspace::new_with_init_std_lib();

ws.def(
r#"
---@alias FnA fun(x: integer): integer
---@alias FnB fun(x: string): integer

---@type FnA | FnB
local run

_, a = pcall(run, 1)
"#,
);

assert_eq!(ws.expr_ty("a"), ws.ty("integer|string"));
}

#[test]
fn test_return_overload_narrows_callable_intersection_member_by_arg_shape() {
let mut ws = VirtualWorkspace::new_with_init_std_lib();

ws.def(
r#"
---@alias FnA fun(x: integer): integer
---@alias FnB fun(x: string): boolean

---@type FnA & FnB
local run

_, a = pcall(run, 1)
"#,
);

assert_eq!(ws.expr_ty("a"), ws.ty("integer|string"));
}
}
Loading
Loading