diff options
-rw-r--r-- | changelog.md | 2 | ||||
-rw-r--r-- | script/vm/node.lua | 13 | ||||
-rw-r--r-- | script/vm/runner.lua | 31 | ||||
-rw-r--r-- | script/vm/type.lua | 16 | ||||
-rw-r--r-- | test/type_inference/init.lua | 9 |
5 files changed, 45 insertions, 26 deletions
diff --git a/changelog.md b/changelog.md index b33ece88..38ea7d92 100644 --- a/changelog.md +++ b/changelog.md @@ -4,10 +4,12 @@ * `FIX` incorrect type check for generic with nil * `FIX` [#1676] * `FIX` [#1677] +* `FIX` [#1679] * `FIX` [#1680] [#1676]: https://github.com/sumneko/lua-language-server/issues/1676 [#1677]: https://github.com/sumneko/lua-language-server/issues/1677 +[#1679]: https://github.com/sumneko/lua-language-server/issues/1679 [#1680]: https://github.com/sumneko/lua-language-server/issues/1680 ## 3.6.1 diff --git a/script/vm/node.lua b/script/vm/node.lua index 3866e56a..c07269ab 100644 --- a/script/vm/node.lua +++ b/script/vm/node.lua @@ -246,7 +246,8 @@ function mt:remove(name) or (c.type == 'doc.type.table' and name == 'table') or (c.type == 'doc.type.array' and name == 'table') or (c.type == 'doc.type.sign' and name == c.node[1]) - or (c.type == 'doc.type.function' and name == 'function') then + or (c.type == 'doc.type.function' and name == 'function') + or (c.type == 'doc.type.string' and name == 'string') then table.remove(self, index) self[c] = nil end @@ -254,9 +255,10 @@ function mt:remove(name) return self end +---@param uri uri ---@param name string -function mt:narrow(name) - if name ~= 'nil' and self.optional == true then +function mt:narrow(uri, name) + if self.optional == true then self.optional = nil end for index = #self, 1, -1 do @@ -267,12 +269,13 @@ function mt:narrow(name) or (c.type == 'doc.type.table' and name == 'table') or (c.type == 'doc.type.array' and name == 'table') or (c.type == 'doc.type.sign' and name == c.node[1]) - or (c.type == 'doc.type.function' and name == 'function') then + or (c.type == 'doc.type.function' and name == 'function') + or (c.type == 'doc.type.string' and name == 'string') then goto CONTINUE end if c.type == 'global' and c.cate == 'type' then if (c.name == name) - or (c.name == 'integer' and name == 'number') then + or (vm.isSubType(uri, c.name, name)) then goto CONTINUE end end diff --git a/script/vm/runner.lua b/script/vm/runner.lua index 7363f0e5..250be481 100644 --- a/script/vm/runner.lua +++ b/script/vm/runner.lua @@ -11,6 +11,7 @@ local guide = require 'parser.guide' ---@field _mark table ---@field _has table<parser.object, true> ---@field _main parser.object +---@field _uri uri local mt = {} mt.__index = mt mt._index = 1 @@ -168,15 +169,18 @@ function mt:_lookIntoChild(action, topNode, outNode) -- if x == y then topNode = self:_lookIntoChild(handler, topNode, outNode) local checkerNode = vm.compileNode(checker) - if action.op.type == '==' then - topNode = checkerNode - if outNode then - outNode:removeNode(topNode) - end - else - topNode:removeNode(checkerNode) - if outNode then - outNode = checkerNode + local checkerName = vm.getNodeName(checker) + if checkerName then + if action.op.type == '==' then + topNode:narrow(self._uri, checkerName) + if outNode then + outNode:removeNode(checkerNode) + end + else + topNode:removeNode(checkerNode) + if outNode then + outNode:narrow(self._uri, checkerName) + end end end elseif handler.type == 'call' @@ -189,14 +193,14 @@ function mt:_lookIntoChild(action, topNode, outNode) -- if type(x) == 'string' then self:_lookIntoChild(handler, topNode:copy()) if action.op.type == '==' then - topNode:narrow(checker[1]) + topNode:narrow(self._uri, checker[1]) if outNode then outNode:remove(checker[1]) end else topNode:remove(checker[1]) if outNode then - outNode:narrow(checker[1]) + outNode:narrow(self._uri, checker[1]) end end elseif handler.type == 'getlocal' @@ -215,14 +219,14 @@ function mt:_lookIntoChild(action, topNode, outNode) and call.args[1].node == self._loc then -- `local tp = type(x);if tp == 'string' then` if action.op.type == '==' then - topNode:narrow(checker[1]) + topNode:narrow(self._uri, checker[1]) if outNode then outNode:remove(checker[1]) end else topNode:remove(checker[1]) if outNode then - outNode:narrow(checker[1]) + outNode:narrow(self._uri, checker[1]) end end end @@ -352,6 +356,7 @@ function vm.launchRunner(loc, callback) _mark = {}, _has = {}, _main = main, + _uri = guide.getUri(loc), _callback = callback, }, mt) diff --git a/script/vm/type.lua b/script/vm/type.lua index 32b8e50b..126f0a83 100644 --- a/script/vm/type.lua +++ b/script/vm/type.lua @@ -9,7 +9,7 @@ local lang = require 'language' ---@param object vm.node.object ---@return string? -local function getNodeName(object) +function vm.getNodeName(object) if object.type == 'global' and object.cate == 'type' then ---@cast object vm.global return object.name @@ -90,14 +90,14 @@ local function checkParentEnum(parentName, child, uri, mark, errs) return false else ---@cast child parser.object - local childName = getNodeName(child) + local childName = vm.getNodeName(child) if childName == 'number' or childName == 'integer' or childName == 'boolean' or childName == 'string' then for _, enum in ipairs(enums) do for nd in vm.compileNode(enum):eachObject() do - if childName == getNodeName(nd) and nd[1] == child[1] then + if childName == vm.getNodeName(nd) and nd[1] == child[1] then return true end end @@ -284,7 +284,7 @@ function vm.isSubType(uri, child, parent, mark, errs) if config.get(uri, 'Lua.type.weakUnionCheck') then local hasKnownType = 0 for n in child:eachObject() do - if getNodeName(n) then + if vm.getNodeName(n) then hasKnownType = hasKnownType + 1 if vm.isSubType(uri, n, parent, mark, errs) == true then return true @@ -303,7 +303,7 @@ function vm.isSubType(uri, child, parent, mark, errs) else local weakNil = config.get(uri, 'Lua.type.weakNilCheck') for n in child:eachObject() do - local nodeName = getNodeName(n) + local nodeName = vm.getNodeName(n) if nodeName and not (nodeName == 'nil' and weakNil) and vm.isSubType(uri, n, parent, mark, errs) == false then @@ -329,7 +329,7 @@ function vm.isSubType(uri, child, parent, mark, errs) end ---@cast child vm.node.object - local childName = getNodeName(child) + local childName = vm.getNodeName(child) if childName == 'any' or childName == 'unknown' then return true @@ -349,7 +349,7 @@ function vm.isSubType(uri, child, parent, mark, errs) elseif parent.type == 'vm.node' then local hasKnownType = 0 for n in parent:eachObject() do - if getNodeName(n) then + if vm.getNodeName(n) then hasKnownType = hasKnownType + 1 if vm.isSubType(uri, child, n, mark, errs) == true then return true @@ -377,7 +377,7 @@ function vm.isSubType(uri, child, parent, mark, errs) ---@cast parent vm.node.object - local parentName = getNodeName(parent) + local parentName = vm.getNodeName(parent) if parentName == 'any' or parentName == 'unknown' then return true diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua index 40f9891d..4a60766b 100644 --- a/test/type_inference/init.lua +++ b/test/type_inference/init.lua @@ -3965,3 +3965,12 @@ local t local <?n?> = t[2] ]] + +TEST 'N' [[ +---@class N: number +local x + +if x == 0.1 then + print(<?x?>) +end +]] |