diff --git a/changelog.md b/changelog.md index 343a2e40c..a8122e5dc 100644 --- a/changelog.md +++ b/changelog.md @@ -4,6 +4,12 @@ * `CHG` Modified the `ResolveRequire` function to pass the source URI as a third argument. * `CHG` Improved the output of test failures during development +* `FIX` Fix type inference for `x == nil and "default" or x` idiom [#2236](https://github.com/LuaLS/lua-language-server/issues/2236) +* `FIX` Fix type loss for assignments inside `if`/`for` blocks due to circular dependency in tracer [#2374](https://github.com/LuaLS/lua-language-server/issues/2374) [#2494](https://github.com/LuaLS/lua-language-server/issues/2494) +* `FIX` Resolve generic class method return types for `@param self list` pattern +* `FIX` Fix `ipairs(self)` type resolution in generic class methods +* `FIX` Fix double angle brackets in generic sign display (`list<>` -> `list`) +* `FIX` Fix nil crash in `getParentClass` for `doc.field` without class ## 3.17.1 `2026-01-20` diff --git a/script/core/diagnostics/no-unknown.lua b/script/core/diagnostics/no-unknown.lua index e706931ad..51f5858fa 100644 --- a/script/core/diagnostics/no-unknown.lua +++ b/script/core/diagnostics/no-unknown.lua @@ -26,11 +26,29 @@ return function (uri, callback) guide.eachSourceTypes(ast.ast, types, function (source) await.delay() if vm.getInfer(source):view(uri) == 'unknown' then - callback { - start = source.start, - finish = source.finish, - message = lang.script('DIAG_UNKNOWN'), - } + -- When a node only contains a 'variable' object whose base + -- declaration has a known type, this is a false positive caused + -- by circular dependency during compilation, not a true unknown. + local dominated = false + local node = vm.getNode(source) + if node then + for n in node:eachObject() do + if n.type == 'variable' and n.base and n.base.value then + local baseView = vm.getInfer(n.base):view(uri) + if baseView ~= 'unknown' then + dominated = true + break + end + end + end + end + if not dominated then + callback { + start = source.start, + finish = source.finish, + message = lang.script('DIAG_UNKNOWN'), + } + end end end) end diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index 5267a037b..19e1de0be 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -1671,7 +1671,13 @@ local function bindReturnOfFunction(source, mfunc, index, args) else local clonedObject = vm.cloneObject(nd, resolved) if clonedObject then - result:merge(vm.compileNode(clonedObject)) + if clonedObject.type == 'doc.generic.name' + and clonedObject._resolved + and vm.isResolvedToGeneric(clonedObject._resolved) then + result:merge(clonedObject) + else + result:merge(vm.compileNode(clonedObject)) + end end end end @@ -1686,13 +1692,25 @@ local function bindReturnOfFunction(source, mfunc, index, args) end end - if mfunc.type == 'function' then + if mfunc.type == 'function' or mfunc.type == 'doc.type.function' then local hasUnresolvedGeneric = false for rnode in returnNode:eachObject() do if vm.isGenericUnsolved(rnode) then hasUnresolvedGeneric = true break end + -- Also check inside doc.type.sign for unresolved generics + -- (e.g. list where T is not yet resolved) + if rnode.type == 'doc.type.sign' and rnode.signs then + guide.eachSourceType(rnode, 'doc.generic.name', function (src) + if not src._resolved then + hasUnresolvedGeneric = true + end + end) + if hasUnresolvedGeneric then + break + end + end end if hasUnresolvedGeneric then local sign = vm.getSign(mfunc) @@ -1760,6 +1778,12 @@ local function bindReturnOfFunction(source, mfunc, index, args) for rnode in returnNode:eachObject() do if rnode.type ~= 'doc.generic.name' then vm.setNode(source, rnode) + elseif rnode._resolved then + -- Allow generics that resolved to another generic type + -- parameter (e.g. V -> T in generic method's ipairs(self)). + if vm.isResolvedToGeneric(rnode._resolved) then + vm.setNode(source, rnode) + end end end if returnNode:isOptional() then diff --git a/script/vm/generic.lua b/script/vm/generic.lua index d2c75eafa..80602c277 100644 --- a/script/vm/generic.lua +++ b/script/vm/generic.lua @@ -1,5 +1,6 @@ ---@class vm local vm = require 'vm.vm' +local guide = require 'parser.guide' ---@class parser.object ---@field package _generic vm.generic @@ -130,16 +131,34 @@ local function cloneObject(source, resolved) end if source.type == 'doc.type.sign' and source.signs then local needsClone = false + -- Check if any sign parameter has a resolvable name with a concrete + -- (non-generic) resolved type. Skip cloning when the resolved value + -- is just another doc.generic.name (e.g. T -> T inside a method body), + -- which would cause double-wrapping in display (list<>). + local function hasConcreteResolution(name) + local rnode = resolved[name] + if not rnode then + return false + end + for rn in rnode:eachObject() do + if rn.type ~= 'doc.generic.name' and rn.type ~= 'generic' then + return true + end + end + return false + end for _, sign in ipairs(source.signs) do - if sign.type == 'doc.type' then - for _, tp in ipairs(sign.types) do - if tp.type == 'doc.type.name' and resolved[tp[1]] then + guide.eachSourceType(sign, 'doc.type.name', function (src) + if hasConcreteResolution(src[1]) then + needsClone = true + end + end) + if not needsClone then + guide.eachSourceType(sign, 'doc.generic.name', function (src) + if hasConcreteResolution(src[1]) then needsClone = true - break end - end - elseif sign.type == 'doc.type.name' and resolved[sign[1]] then - needsClone = true + end) end if needsClone then break end end @@ -176,8 +195,18 @@ function mt:resolve(uri, args) ---@cast nd -vm.global, -vm.variable local clonedObject = cloneObject(nd, resolved) if clonedObject then - local clonedNode = vm.compileNode(clonedObject) - result:merge(clonedNode) + -- When a generic resolves to another generic (e.g. V -> T + -- inside a generic method), keep the resolved wrapper so + -- the resolution chain is preserved and downstream filters + -- can distinguish "resolved to generic T" from "unresolved". + if clonedObject.type == 'doc.generic.name' + and clonedObject._resolved + and vm.isResolvedToGeneric(clonedObject._resolved) then + result:merge(clonedObject) + else + local clonedNode = vm.compileNode(clonedObject) + result:merge(clonedNode) + end end end end @@ -204,6 +233,20 @@ function vm.isGenericUnsolved(source) return false end +--- Check if a resolved node contains only generic name objects. +--- Used to distinguish "V resolved to generic T" (preserve wrapper) +--- from "V resolved to concrete string" (unwrap normally). +---@param node vm.node +---@return boolean +function vm.isResolvedToGeneric(node) + for rn in node:eachObject() do + if rn.type ~= 'doc.generic.name' then + return false + end + end + return true +end + ---@param source parser.object ---@param generic vm.generic function vm.setGeneric(source, generic) diff --git a/script/vm/infer.lua b/script/vm/infer.lua index 6f21a76ab..13fdc4338 100644 --- a/script/vm/infer.lua +++ b/script/vm/infer.lua @@ -142,7 +142,13 @@ local viewNodeSwitch;viewNodeSwitch = util.switch() infer._hasClass = true local buf = {} for i, sign in ipairs(source.signs) do - buf[i] = vm.getInfer(sign):view(uri) + local view = vm.getInfer(sign):view(uri) + -- Strip outer <> from generic names since the sign + -- already wraps parameters in <>, avoiding list<> + if view and view:sub(1, 1) == '<' and view:sub(-1) == '>' then + view = view:sub(2, -2) + end + buf[i] = view end local node = vm.compileNode(source) for c in node:eachObject() do diff --git a/script/vm/sign.lua b/script/vm/sign.lua index 3718391d1..8c343dafa 100644 --- a/script/vm/sign.lua +++ b/script/vm/sign.lua @@ -2,6 +2,53 @@ local guide = require 'parser.guide' ---@class vm local vm = require 'vm.vm' +--- Find a generic name referenced in a doc.type.table's fields +--- that exists in the given genericMap. +---@param tableType parser.object doc.type.table with fields +---@param genericMap table +---@return string? The matching generic key name +local function findGenericInTableFields(tableType, genericMap) + for _, field in ipairs(tableType.fields) do + if field.extends then + local found + guide.eachSourceType(field.extends, 'doc.generic.name', function (src) + if genericMap[src[1]] then + found = src[1] + end + end) + if found then + return found + end + end + end + return nil +end + +--- Search for a generic name in extends tables of a class definition. +--- For classes like `@class list: {[integer]:T}`, the [integer] field +--- lives in the extends doc.type.table, not in @field annotations. +---@param uri uri +---@param classGlobal vm.global +---@param genericMap table +---@return string? The class generic name that maps to the integer field +local function findGenericInExtendsTable(uri, classGlobal, genericMap) + for _, set in ipairs(classGlobal:getSets(uri)) do + if set.type ~= 'doc.class' or not set.extends then + goto CONTINUE + end + for _, ext in ipairs(set.extends) do + if ext.type == 'doc.type.table' and ext.fields then + local key = findGenericInTableFields(ext, genericMap) + if key then + return key + end + end + end + ::CONTINUE:: + end + return nil +end + ---@class vm.sign ---@field parent parser.object ---@field signList vm.node[] @@ -83,28 +130,67 @@ function mt:resolve(uri, args) return end if object.type == 'doc.type.array' then + -- If the argument contains a doc.type.sign (generic class like + -- list extending { [integer]: V }), resolve element type + -- exclusively through class generic map. This directly maps + -- the array element generic (V) to the sign parameter, even + -- when it's another generic name (T inside a method body). + local handled = false for n in node:eachObject() do - if n.type == 'doc.type.array' then - -- number[] -> T[] - resolve(object.node, vm.compileNode(n.node)) - end - if n.type == 'doc.type.table' then - -- { [integer]: number } -> T[] - local tvalueNode = vm.getTableValue(uri, node, 'integer', true) - if tvalueNode then - resolve(object.node, tvalueNode) + if n.type == 'doc.type.sign' and n.signs and n.node and n.node[1] then + local classGlobal = vm.getGlobal('type', n.node[1]) + if classGlobal then + local genericMap = vm.getClassGenericMap(uri, classGlobal, n.signs) + if genericMap and object.node and object.node.type == 'doc.generic.name' then + -- V[] matching list: look up [integer] field, + -- find which class generic it references, then + -- map V directly to the sign's concrete parameter + local vKey = object.node[1] + -- First try @field annotations + vm.getClassFields(uri, classGlobal, vm.declareGlobal('type', 'integer'), function (field) + if field.extends then + guide.eachSourceType(field.extends, 'doc.generic.name', function (src) + if genericMap[src[1]] then + resolved[vKey] = genericMap[src[1]] + handled = true + end + end) + end + end) + -- Also search extends tables (for @class list: {[integer]:T}) + if not handled then + local genericKey = findGenericInExtendsTable(uri, classGlobal, genericMap) + if genericKey then + resolved[vKey] = genericMap[genericKey] + handled = true + end + end + end end + if handled then break end end - if n.type == 'global' and n.cate == 'type' then - -- ---@field [integer]: number -> T[] - ---@cast n vm.global - vm.getClassFields(uri, n, vm.declareGlobal('type', 'integer'), function (field) - resolve(object.node, vm.compileNode(field.extends)) - end) - end - if n.type == 'table' and #n >= 1 then - -- { x } / { ... } -> T[] - resolve(object.node, vm.compileNode(n[1])) + end + if not handled then + for n in node:eachObject() do + if n.type == 'doc.type.array' then + -- number[] -> T[] + resolve(object.node, vm.compileNode(n.node)) + elseif n.type == 'doc.type.table' then + -- { [integer]: number } -> T[] + local tvalueNode = vm.getTableValue(uri, node, 'integer', true) + if tvalueNode then + resolve(object.node, tvalueNode) + end + elseif n.type == 'global' and n.cate == 'type' then + -- ---@field [integer]: number -> T[] + ---@cast n vm.global + vm.getClassFields(uri, n, vm.declareGlobal('type', 'integer'), function (field) + resolve(object.node, vm.compileNode(field.extends)) + end) + elseif n.type == 'table' and #n >= 1 then + -- { x } / { ... } -> T[] + resolve(object.node, vm.compileNode(n[1])) + end end end return @@ -176,6 +262,21 @@ function mt:resolve(uri, args) end return end + if object.type == 'doc.type.sign' and object.signs then + -- list -> list: match sign parameters positionally + for n in node:eachObject() do + if n.type == 'doc.type.sign' and n.signs + and n.node and object.node + and n.node[1] == object.node[1] then + for i, signParam in ipairs(object.signs) do + if n.signs[i] then + resolve(vm.compileNode(signParam), vm.compileNode(n.signs[i])) + end + end + end + end + return + end end ---@param sign vm.node @@ -191,7 +292,8 @@ function mt:resolve(uri, args) end if obj.type == 'doc.type.table' or obj.type == 'doc.type.function' - or obj.type == 'doc.type.array' then + or obj.type == 'doc.type.array' + or obj.type == 'doc.type.sign' then ---@cast obj parser.object local hasGeneric guide.eachSourceType(obj, 'doc.generic.name', function (src) @@ -203,7 +305,8 @@ function mt:resolve(uri, args) end end if obj.type == 'variable' - or obj.type == 'local' then + or obj.type == 'local' + or obj.type == 'self' then goto CONTINUE end local view = vm.getInfer(obj):view(uri) diff --git a/script/vm/tracer.lua b/script/vm/tracer.lua index cc6d10e59..2a578aa8b 100644 --- a/script/vm/tracer.lua +++ b/script/vm/tracer.lua @@ -638,9 +638,22 @@ local lookIntoChild = util.switch() tracer:lookIntoChild(action[2], topNode) return topNode, outNode end - if action.op.type == 'and' then - topNode = tracer:lookIntoChild(action[1], topNode, topNode:copy()) - topNode = tracer:lookIntoChild(action[2], topNode, topNode:copy()) + if action.op.type == 'and' then + local topNode1, outNode1 = tracer:lookIntoChild(action[1], topNode, topNode:copy()) + topNode = tracer:lookIntoChild(action[2], topNode1, topNode1:copy()) + -- When the right side of `and` is a truthy literal (string, number, + -- true, table, function), the `and` can only be false when the left + -- side is false. Propagate the narrowed outNode so that patterns + -- like `x == nil and "default" or x` correctly infer x as non-nil. + local tp2 = action[2].type + if tp2 == 'string' + or tp2 == 'number' + or tp2 == 'integer' + or tp2 == 'table' + or tp2 == 'function' + or (tp2 == 'boolean' and action[2][1] == true) then + outNode = outNode1 + end elseif action.op.type == 'or' then outNode = outNode or topNode:copy() local topNode1, outNode1 = tracer:lookIntoChild(action[1], topNode, outNode) @@ -844,12 +857,40 @@ function mt:calcNode(source) return end if self.assignMap[source] then + -- Guard against circular dependency: when this setlocal is already + -- being compiled (e.g. if-handler's getNode triggers calcNode for + -- a setlocal whose value is currently being compiled), skip + -- lookIntoBlock to avoid propagating incomplete types and setting + -- marks that would prevent later correct processing. + if self._compilingAssigns and self._compilingAssigns[source] then + self.nodes[source] = vm.compileNode(source) + return + end + if not self._compilingAssigns then + self._compilingAssigns = {} + end + self._compilingAssigns[source] = true local node = vm.compileNode(source) + -- When the compiled node has no known type (only contains a 'variable' + -- due to circular dependency), fall back to the variable's base + -- declaration node. This prevents incomplete nodes from propagating + -- through control flow analysis (e.g. if-blocks inside for-loops), + -- which would otherwise cause the type to degrade to 'unknown'. + if not node:hasKnownType() + and self.mode == 'local' + and self.source.type == 'variable' + and self.source.base then + local baseNode = vm.compileNode(self.source.base) + if baseNode:hasKnownType() then + node = baseNode + end + end self.nodes[source] = node local parentBlock = guide.getParentBlock(source) if parentBlock then self:lookIntoBlock(parentBlock, source.finish, node) end + self._compilingAssigns[source] = nil return end end diff --git a/script/vm/visible.lua b/script/vm/visible.lua index a6b2856dc..d110f2be2 100644 --- a/script/vm/visible.lua +++ b/script/vm/visible.lua @@ -105,6 +105,9 @@ end ---@return vm.global? function vm.getParentClass(source) if source.type == 'doc.field' then + if not source.class then + return nil + end return vm.getGlobalNode(source.class) end if source.type == 'setfield' diff --git a/test/type_inference/common.lua b/test/type_inference/common.lua index c54c82ee1..a5e48c6d8 100644 --- a/test/type_inference/common.lua +++ b/test/type_inference/common.lua @@ -733,7 +733,7 @@ local f2 = f(1) local i, = f2(true) ]] -TEST 'fun(table: table<, >, index?: ):, ' [[ +TEST 'fun(table: table, index?: ):, ' [[ ---@generic T: table, K, V ---@param t T ---@return fun(table: table, index?: K):K, V @@ -914,6 +914,43 @@ for _, in ipairs(t) do end ]] +TEST '' [[ +---@generic T: table, V +---@param t T +---@return fun(table: V[], i?: integer):integer, V +---@return T +---@return integer i +local function ipairs(t) end + +---@class list +---@field [integer] T + +---@generic T +---@param self list +function list:foo() + for _, in ipairs(self) do + end +end +]] + +TEST '' [[ +---@generic T: table, V +---@param t T +---@return fun(table: V[], i?: integer):integer, V +---@return T +---@return integer i +local function ipairs(t) end + +---@class listB: {[integer]:T} + +---@generic T +---@param self listB +function listB:foo() + for _, in ipairs(self) do + end +end +]] + TEST 'boolean' [[ ---@generic T: table, K, V ---@param t T @@ -4883,7 +4920,7 @@ local a, b, , d = unpack(t) ]] -- Test for overflow in circular resolve, only pass requirement is no overflow -TEST 'Callback<>|fun():fun():fun():Success, string' [[ +TEST 'Callback' [[ --- @alias Success fun(): Success --- @alias Callback fun(): Success, T @@ -4989,9 +5026,8 @@ local = w:unwrap() ]] -- Issue #1856: Generic class display format --- Current behavior shows list<>|{...} - the <> indicates an unresolved generic --- The resolved table type is also shown -TEST 'list<>|{ [integer]: string }' [[ +-- Generic class with resolved type parameters +TEST 'list|{ [integer]: string }' [[ ---@class list: {[integer]:T} ---@generic T