Skip to content

Commit 2873cca

Browse files
committed
infer parameter type by return
resolve #1202
1 parent 96b3202 commit 2873cca

File tree

3 files changed

+35
-1
lines changed

3 files changed

+35
-1
lines changed

changelog.md

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ server will generate `doc.json` and `doc.md` in `LOGPATH`.
2828
---@type myClass
2929
local class
3030

31-
print(class.a.b.c.e.f.g) --> infered as integer
31+
print(class.a.b.c.e.f.g) --> inferred as integer
3232
```
3333
* `CHG` [#1582] the following diagnostics consider `overload`
3434
* `missing-return`
@@ -58,6 +58,14 @@ server will generate `doc.json` and `doc.md` in `LOGPATH`.
5858

5959
local arr = x(cb) --> `arr` is inferred as `integer[]`
6060
```
61+
* `CHG` [#1202] infer parameter type by expected returned function of parent function
62+
```lua
63+
---@return fun(x: integer)
64+
local function f()
65+
return function (x) --> `x` is inferred as `integer`
66+
end
67+
end
68+
```
6169
* `FIX` [#1567]
6270
* `FIX` [#1593]
6371
* `FIX` [#1595]
@@ -70,6 +78,7 @@ server will generate `doc.json` and `doc.md` in `LOGPATH`.
7078

7179
[#1153]: https://github.com/sumneko/lua-language-server/issues/1153
7280
[#1177]: https://github.com/sumneko/lua-language-server/issues/1177
81+
[#1202]: https://github.com/sumneko/lua-language-server/issues/1202
7382
[#1458]: https://github.com/sumneko/lua-language-server/issues/1458
7483
[#1557]: https://github.com/sumneko/lua-language-server/issues/1557
7584
[#1558]: https://github.com/sumneko/lua-language-server/issues/1558

script/vm/compiler.lua

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,6 +1059,7 @@ local function compileLocal(source)
10591059
end
10601060
if source.parent.type == 'funcargs' and not hasMarkDoc and not hasMarkParam then
10611061
local func = source.parent.parent
1062+
-- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number
10621063
local funcNode = vm.compileNode(func)
10631064
local hasDocArg
10641065
for n in funcNode:eachObject() do
@@ -1158,6 +1159,22 @@ local compilerSwitch = util.switch()
11581159
local call = source.parent.parent
11591160
vm.compileCallArg(source, call)
11601161
end
1162+
1163+
-- function f() return function (<?x?>) end end
1164+
if source.parent.type == 'return' then
1165+
for i, ret in ipairs(source.parent) do
1166+
if ret == source then
1167+
local func = guide.getParentFunction(source.parent)
1168+
if func then
1169+
local returnObj = vm.getReturnOfFunction(func, i)
1170+
if returnObj then
1171+
vm.setNode(source, vm.compileNode(returnObj))
1172+
end
1173+
end
1174+
break
1175+
end
1176+
end
1177+
end
11611178
end)
11621179
: case 'paren'
11631180
: call(function (source)

test/type_inference/init.lua

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3858,3 +3858,11 @@ local cb
38583858
38593859
local <?arr?> = x(cb)
38603860
]]
3861+
3862+
TEST 'integer' [[
3863+
---@return fun(x: integer)
3864+
local function f()
3865+
return function (<?x?>)
3866+
end
3867+
end
3868+
]]

0 commit comments

Comments
 (0)