Skip to content

Commit 96b3202

Browse files
committed
infer type by function as parameters
resolve #1153
1 parent b6e7d27 commit 96b3202

File tree

3 files changed

+127
-5
lines changed

3 files changed

+127
-5
lines changed

changelog.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,18 @@ server will generate `doc.json` and `doc.md` in `LOGPATH`.
4646

4747
print(obj.initValue) --> `obj.initValue` is integer
4848
```
49+
* `CHG` [#1153] infer type by generic parameters or returns of function
50+
```lua
51+
---@generic T
52+
---@param f fun(x: T)
53+
---@return T[]
54+
local function x(f) end
55+
56+
---@type fun(x: integer)
57+
local cb
58+
59+
local arr = x(cb) --> `arr` is inferred as `integer[]`
60+
```
4961
* `FIX` [#1567]
5062
* `FIX` [#1593]
5163
* `FIX` [#1595]
@@ -56,6 +68,7 @@ server will generate `doc.json` and `doc.md` in `LOGPATH`.
5668
* `FIX` [#1640]
5769
* `FIX` [#1642]
5870

71+
[#1153]: https://github.com/sumneko/lua-language-server/issues/1153
5972
[#1177]: https://github.com/sumneko/lua-language-server/issues/1177
6073
[#1458]: https://github.com/sumneko/lua-language-server/issues/1458
6174
[#1557]: https://github.com/sumneko/lua-language-server/issues/1557

script/vm/sign.lua

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,26 @@ function mt:resolve(uri, args, removeGeneric)
2222
if not args then
2323
return nil
2424
end
25+
26+
---@type table<string, vm.node>
2527
local resolved = {}
2628

27-
---@param object vm.node.object
29+
---@param object vm.node|vm.node.object
2830
---@param node vm.node
2931
local function resolve(object, node)
32+
if object.type == 'vm.node' then
33+
for o in object:eachObject() do
34+
resolve(o, node)
35+
end
36+
return
37+
end
38+
if object.type == 'doc.type' then
39+
---@cast object parser.object
40+
resolve(vm.compileNode(object), node)
41+
return
42+
end
3043
if object.type == 'doc.generic.name' then
44+
---@type string
3145
local key = object[1]
3246
if object.literal then
3347
-- 'number' -> `T`
@@ -40,8 +54,21 @@ function mt:resolve(uri, args, removeGeneric)
4054
end
4155
else
4256
-- number -> T
43-
resolved[key] = vm.createNode(node, resolved[key])
57+
for n in node:eachObject() do
58+
if n.type ~= 'doc.generic.name'
59+
and n.type ~= 'generic' then
60+
if resolved[key] then
61+
resolved[key]:merge(n)
62+
else
63+
resolved[key] = vm.createNode(n)
64+
end
65+
end
66+
end
67+
if resolved[key] and node:isOptional() then
68+
resolved[key]:addOptional()
69+
end
4470
end
71+
return
4572
end
4673
if object.type == 'doc.type.array' then
4774
for n in node:eachObject() do
@@ -68,6 +95,7 @@ function mt:resolve(uri, args, removeGeneric)
6895
resolve(object.node, vm.compileNode(n[1]))
6996
end
7097
end
98+
return
7199
end
72100
if object.type == 'doc.type.table' then
73101
for _, ufield in ipairs(object.fields) do
@@ -105,6 +133,34 @@ function mt:resolve(uri, args, removeGeneric)
105133
end
106134
::CONTINUE::
107135
end
136+
return
137+
end
138+
if object.type == 'doc.type.function' then
139+
for i, arg in ipairs(object.args) do
140+
for n in node:eachObject() do
141+
if n.type == 'function'
142+
or n.type == 'doc.type.function' then
143+
---@cast n parser.object
144+
local farg = n.args and n.args[i]
145+
if farg then
146+
resolve(arg.extends, vm.compileNode(farg))
147+
end
148+
end
149+
end
150+
end
151+
for i, ret in ipairs(object.returns) do
152+
for n in node:eachObject() do
153+
if n.type == 'function'
154+
or n.type == 'doc.type.function' then
155+
---@cast n parser.object
156+
local fret = vm.getReturnOfFunction(n, i)
157+
if fret then
158+
resolve(ret, vm.compileNode(fret))
159+
end
160+
end
161+
end
162+
end
163+
return
108164
end
109165
end
110166

@@ -190,9 +246,7 @@ function mt:resolve(uri, args, removeGeneric)
190246
local knownTypes, genericNames = getSignInfo(sign)
191247
if not isAllResolved(genericNames) then
192248
local newArgNode = buildArgNode(argNode,sign, knownTypes)
193-
for n in sign:eachObject() do
194-
resolve(n, newArgNode)
195-
end
249+
resolve(sign, newArgNode)
196250
end
197251
end
198252

test/type_inference/init.lua

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3803,3 +3803,58 @@ local class
38033803
38043804
class.has.nested.<?fn?>()
38053805
]]
3806+
3807+
TEST 'integer[]' [[
3808+
---@generic T
3809+
---@param f fun(x: T)
3810+
---@return T[]
3811+
local function x(f) end
3812+
3813+
---@param x integer
3814+
local <?arr?> = x(function (x) end)
3815+
]]
3816+
3817+
TEST 'integer[]' [[
3818+
---@generic T
3819+
---@param f fun():T
3820+
---@return T[]
3821+
local function x(f) end
3822+
3823+
local <?arr?> = x(function ()
3824+
return 1
3825+
end)
3826+
]]
3827+
3828+
TEST 'integer[]' [[
3829+
---@generic T
3830+
---@param f fun():T
3831+
---@return T[]
3832+
local function x(f) end
3833+
3834+
---@return integer
3835+
local <?arr?> = x(function () end)
3836+
]]
3837+
3838+
TEST 'integer[]' [[
3839+
---@generic T
3840+
---@param f fun(x: T)
3841+
---@return T[]
3842+
local function x(f) end
3843+
3844+
---@type fun(x: integer)
3845+
local cb
3846+
3847+
local <?arr?> = x(cb)
3848+
]]
3849+
3850+
TEST 'integer[]' [[
3851+
---@generic T
3852+
---@param f fun():T
3853+
---@return T[]
3854+
local function x(f) end
3855+
3856+
---@type fun(): integer
3857+
local cb
3858+
3859+
local <?arr?> = x(cb)
3860+
]]

0 commit comments

Comments
 (0)