Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -370,13 +370,6 @@ fn infer_binary_type(analyzer: &mut DocAnalyzer, binary_type: &LuaDocBinaryType)
if let Some((left, right)) = binary_type.get_types() {
let left_type = infer_type(analyzer, left);
let right_type = infer_type(analyzer, right);
if left_type.is_unknown() {
return right_type;
}
if right_type.is_unknown() {
return left_type;
}

if let Some(op) = binary_type.get_op_token() {
match op.get_op() {
LuaTypeBinaryOperator::Union => match (left_type, right_type) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,31 +208,33 @@ fn analyze_lambda_returns(
pub fn analyze_return_point(
db: &DbIndex,
cache: &mut LuaInferCache,
return_points: &Vec<LuaReturnPoint>,
return_points: &[LuaReturnPoint],
) -> Result<Vec<LuaDocReturnInfo>, InferFailReason> {
let mut return_type = LuaType::Unknown;
let mut return_type = None;
for point in return_points {
match point {
LuaReturnPoint::Expr(expr) => {
let expr_type = infer_expr(db, cache, expr.clone())?;
return_type = union_return_expr(db, return_type, expr_type);
}
let point_type = match point {
LuaReturnPoint::Expr(expr) => Some(infer_expr(db, cache, expr.clone())?),
LuaReturnPoint::MuliExpr(exprs) => {
let mut multi_return = vec![];
let mut multi_return = Vec::with_capacity(exprs.len());
for expr in exprs {
let expr_type = infer_expr(db, cache, expr.clone())?;
multi_return.push(expr_type);
multi_return.push(infer_expr(db, cache, expr.clone())?);
}
let typ = LuaType::Variadic(VariadicType::Multi(multi_return).into());
return_type = union_return_expr(db, return_type, typ);
}
LuaReturnPoint::Nil => {
return_type = union_return_expr(db, return_type, LuaType::Nil);
Some(LuaType::Variadic(VariadicType::Multi(multi_return).into()))
}
_ => {}
LuaReturnPoint::Nil => Some(LuaType::Nil),
_ => None,
};

if let Some(point_type) = point_type {
return_type = Some(match return_type {
Some(return_type) => union_return_expr(db, return_type, point_type),
None => point_type,
});
}
}

let return_type = return_type.unwrap_or(LuaType::Unknown);

Ok(vec![LuaDocReturnInfo {
type_ref: return_type,
description: None,
Expand All @@ -242,10 +244,6 @@ pub fn analyze_return_point(
}

fn union_return_expr(db: &DbIndex, left: LuaType, right: LuaType) -> LuaType {
if left == LuaType::Unknown {
return right;
}

match (&left, &right) {
(LuaType::Variadic(left_variadic), LuaType::Variadic(right_variadic)) => {
match (&left_variadic.deref(), &right_variadic.deref()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,15 +205,15 @@ fn find_custom_type_function_member(
let key = LuaMemberKey::from_index_key(db, cache, &index_key)?;
if let Some(member_item) = db.get_member_index().get_member_item(&owner, &key) {
let index_member_id = get_member_id(cache, &index_expr);
let mut result_type = LuaType::Unknown;
let mut result_type = LuaType::Never;
for member_id in member_item.get_member_ids() {
if index_member_id != member_id
&& let Some(type_cache) = db.get_type_index().get_type_cache(&member_id.into())
{
result_type = TypeOps::Union.apply(db, &result_type, type_cache.as_type());
}
}
if !result_type.is_unknown() {
if !result_type.is_never() {
return Ok(result_type);
}
}
Expand Down Expand Up @@ -540,7 +540,7 @@ fn find_member_by_index_table(
.get_member_index()
.get_members(&LuaMemberOwner::Element(table_range.clone()));
if let Some(members) = members {
let mut result_type = LuaType::Unknown;
let mut result_type = LuaType::Never;
for member in members {
let member_key_type = match member.get_key() {
LuaMemberKey::Name(s) => LuaType::StringConst(s.clone().into()),
Expand All @@ -558,7 +558,7 @@ fn find_member_by_index_table(
}
}

if !result_type.is_unknown() {
if !result_type.is_never() {
if matches!(
key_type,
LuaType::String | LuaType::Number | LuaType::Integer
Expand Down Expand Up @@ -698,7 +698,7 @@ fn find_member_by_index_union(
infer_guard: &InferGuardRef,
deep_guard: &mut DeepLevel,
) -> FunctionTypeResult {
let mut member_type = LuaType::Unknown;
let mut member_type = LuaType::Never;
for member in union.into_vec() {
let result = find_function_type_by_operator(
db,
Expand All @@ -719,7 +719,7 @@ fn find_member_by_index_union(
}
}

if member_type.is_unknown() {
if member_type.is_never() {
return Err(InferFailReason::FieldNotFound);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,8 @@ fn resolve_closure_member_type(
.get(&closure_params.signature_id)
.ok_or(InferFailReason::None)?;
let mut final_params = signature.get_type_params().to_vec();
let mut final_ret = LuaType::Unknown;
let mut final_ret = LuaType::Never;
let mut has_final_ret = false;

let mut multi_function_type = Vec::new();
for typ in union_types.into_vec() {
Expand Down Expand Up @@ -369,17 +370,21 @@ fn resolve_closure_member_type(

break;
}
let new_type = TypeOps::Union.apply(
db,
final_param.1.as_ref().unwrap_or(&LuaType::Unknown),
param.1.as_ref().unwrap_or(&LuaType::Unknown),
);
final_params[idx] = (final_param.0.clone(), Some(new_type));
let new_type = match (&final_param.1, &param.1) {
(Some(final_type), Some(param_type)) => {
Some(TypeOps::Union.apply(db, final_type, param_type))
}
(Some(final_type), None) => Some(final_type.clone()),
(None, Some(param_type)) => Some(param_type.clone()),
(None, None) => None,
};
final_params[idx] = (final_param.0.clone(), new_type);
} else {
final_params.push((param.0.clone(), param.1.clone()));
}
}

has_final_ret = true;
final_ret = TypeOps::Union.apply(db, &final_ret, doc_func.get_ret());
}

Expand All @@ -389,6 +394,12 @@ fn resolve_closure_member_type(
param.1 = Some(variadic_type);
}

let final_ret = if !has_final_ret {
LuaType::Unknown
} else {
final_ret
};

resolve_doc_function(
db,
closure_params,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,4 +80,49 @@ mod test {
"#,
));
}

#[test]
fn test_inferred_return_preserves_never() {
let mut ws = VirtualWorkspace::new();

ws.def(
r#"
---@return { y: number } & { y: string }
local function impossible() end

local function f()
return impossible().y
end

result = f()
"#,
);

assert_eq!(ws.expr_ty("result"), ws.ty("never"));
}

#[test]
fn test_member_doc_return_preserves_never() {
let mut ws = VirtualWorkspace::new();

ws.def(
r#"
---@return { y: number } & { y: string }
local function impossible() end

---@class ClosureTest
---@field e fun(): never
---@field e fun(): never
local Test

function Test.e()
return impossible().y
end

result = Test.e()
"#,
);

assert_eq!(ws.expr_ty("result"), ws.ty("never"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -324,4 +324,48 @@ mod test {
value_ty
);
}

#[test]
fn test_union_member_access_preserves_never() {
let mut ws = VirtualWorkspace::new();

ws.def(
r#"
---@class A
---@field y never

---@class B
---@field y never

---@return A|B
local function make() end

local value = make()

result = value.y
"#,
);

assert_eq!(ws.expr_ty("result"), ws.ty("never"));
}

#[test]
fn test_table_expr_index_preserves_never() {
let mut ws = VirtualWorkspace::new();

ws.def(
r#"
---@return { y: number } & { y: string }
local function impossible() end

local t = {
a = impossible().y,
}

result = t["a"]
"#,
);

assert_eq!(ws.expr_ty("result"), ws.ty("never"));
}
}
24 changes: 24 additions & 0 deletions crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,4 +175,28 @@ mod test {
);
assert_eq!(ws.expr_ty("narrowed"), ws.ty("integer"));
}

#[test]
fn test_pcall_any_callable_splits_success_unknown_and_failure_string() {
let mut ws = VirtualWorkspace::new_with_init_std_lib();

ws.def(
r#"
---@type any
local x

local ok, result = pcall(x)
outside = result
if ok then
success = result
else
failure = result
end
"#,
);

assert_eq!(ws.expr_ty("outside"), ws.ty("unknown|string"));
assert_eq!(ws.expr_ty("success"), ws.ty("unknown"));
assert_eq!(ws.expr_ty("failure"), ws.ty("string"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,71 @@ mod test {
assert_eq!(ws.expr_ty("d"), ws.ty("boolean"));
}

#[test]
fn test_return_overload_narrow_with_overlapping_target_union() {
let mut ws = VirtualWorkspace::new();

ws.def(
r#"
---@param ok boolean
---@return boolean
---@return string|number
---@return_overload true, string
---@return_overload false, string|number
local function pick(ok)
if ok then
return true, "value"
end
return false, 1
end

local cond ---@type boolean
local ok, result = pick(cond)

if ok then
success_branch = result
else
failure_branch = result
end
"#,
);

assert_eq!(ws.expr_ty("success_branch"), ws.ty("string"));
assert_eq!(ws.expr_ty("failure_branch"), ws.ty("string|number"));
}

#[test]
fn test_return_overload_narrow_with_overlapping_supertype_target() {
let mut ws = VirtualWorkspace::new();

ws.def(
r#"
---@param ok boolean
---@return boolean
---@return number
---@return_overload true, integer
---@return_overload false, number
local function pick(ok)
if ok then
return true, 1
end
return false, 1.5
end

local cond ---@type boolean
local ok, result = pick(cond)

if ok then
success_branch = result
else
failure_branch = result
end
"#,
);

assert_eq!(ws.expr_ty("success_branch"), ws.ty("integer"));
}

#[test]
fn test_return_overload_reassign_clears_multi_return_mapping() {
let mut ws = VirtualWorkspace::new();
Expand Down
Loading
Loading