diff options
author | 最萌小汐 <sumneko@hotmail.com> | 2022-12-15 20:23:26 +0800 |
---|---|---|
committer | 最萌小汐 <sumneko@hotmail.com> | 2022-12-15 20:23:26 +0800 |
commit | a744e3439e165be13f8f0ba5262b8f1efec0d86d (patch) | |
tree | a2a00988add4df42e67b9f2169c6ee2317bb0058 | |
parent | 244a19d365d8b3d3881492e08fefa9847bd11a2f (diff) | |
download | lua-language-server-a744e3439e165be13f8f0ba5262b8f1efec0d86d.zip |
stash
-rw-r--r-- | script/vm/tracer.lua | 459 | ||||
-rw-r--r-- | test/type_inference/init.lua | 9 |
2 files changed, 317 insertions, 151 deletions
diff --git a/script/vm/tracer.lua b/script/vm/tracer.lua index 0c5b6939..c2b7ae39 100644 --- a/script/vm/tracer.lua +++ b/script/vm/tracer.lua @@ -8,11 +8,14 @@ local util = require 'utility' ---@field package _casts? parser.object[] ---@class vm.tracer ----@field source parser.object ----@field assigns parser.object[] ----@field nodes table<parser.object, vm.node|false> ----@field main parser.object ----@field uri uri +---@field source parser.object +---@field assigns parser.object[] +---@field assignMap table<parser.object, true> +---@field careMap table<parser.object, true> +---@field mark table<parser.object, true> +---@field nodes table<parser.object, vm.node|false> +---@field main parser.object +---@field uri uri local mt = {} mt.__index = mt @@ -31,41 +34,54 @@ function mt:getCasts() return root._casts end ----@param obj parser.object ----@param mark table -function mt:collectBlock(obj, mark) +---@param obj parser.object +function mt:collectAssign(obj) while true do local block = guide.getParentBlock(obj) if not block then return end obj = block - if mark[obj] then + if self.assignMap[obj] then return end if obj == self.main then return end - mark[obj] = true + self.assignMap[obj] = true self.assigns[#self.assigns+1] = obj end end +---@param obj parser.object +function mt:collectCare(obj) + while true do + if self.careMap[obj] then + return + end + if obj == self.main then + return + end + self.careMap[obj] = true + obj = obj.parent + end +end + function mt:collectLocal() local startPos = self.source.start local finishPos = 0 - local mark = {} - self.assigns[#self.assigns+1] = self.source + self.assignMap[self.source] = true for _, obj in ipairs(self.source.ref) do if obj.type == 'setlocal' then self.assigns[#self.assigns+1] = obj - self:collectBlock(obj, mark) + self.assignMap[obj] = true + self:collectCare(obj) end if obj.type == 'getlocal' then - self:collectBlock(obj, mark) + self:collectCare(obj) end end @@ -84,167 +100,308 @@ function mt:collectLocal() end) end ----@param block parser.object ----@param pos integer +---@param start integer +---@param finish integer ---@return parser.object? -function mt:getLastAssign(block, pos) - if not block then - return nil - end +function mt:getLastAssign(start, finish) local assign for _, obj in ipairs(self.assigns) do - if obj.start >= pos then + if obj.start < start then + goto CONTINUE + end + if obj.start >= finish then break end local objBlock = guide.getParentBlock(obj) if not objBlock then break end - if objBlock == block then + if objBlock.start <= finish + and objBlock.finish >= finish then assign = obj end + ::CONTINUE:: end return assign end ----@param source parser.object ----@return vm.node? -function mt:narrow(source) - local node = self:getNode(source) - if not node then - return nil - end - - if source.type == 'getlocal' then - node = node:copy() - node:setTruthy() - end - - return node -end - ----@param source parser.object ----@return vm.node? -function mt:calcGet(source) - local parent = source.parent - if parent.type == 'filter' then - return self:calcGet(parent) +---@param action parser.object +---@param topNode vm.node +---@param outNode? vm.node +---@return vm.node topNode +---@return vm.node outNode +function mt:lookIntoChild(action, topNode, outNode) + if not self.careMap[action] + or self.mark[action] then + return topNode, outNode or topNode end - if parent.type == 'ifblock' then - local parentBlock = guide.getParentBlock(parent.parent) - if parentBlock then - local lastAssign = self:getLastAssign(parentBlock, parent.start) - local node = self:getNode(lastAssign or parentBlock) - return node + self.mark[action] = true + if action.type == 'getlocal' then + if action.node == self.source then + self.nodes[action] = topNode + if outNode then + topNode = topNode:copy():setTruthy() + outNode = outNode:copy():setFalsy() + end end - end - if parent.type == 'unary' then - return self:calcGet(parent) - end - return nil -end - ----@param source parser.object ----@return vm.node? -function mt:calcNode(source) - if source.type == 'getlocal' then - if source.node ~= self.source then - return nil + elseif action.type == 'filter' then + return self:lookIntoChild(action.exp, topNode, outNode) + elseif action.type == 'function' then + self:lookIntoBlock(action, 0, topNode:copy()) + elseif action.type == 'unary' then + if not action[1] then + goto RETURN end - local block = guide.getParentBlock(source) - if not block then - return nil + if action.op.type == 'not' then + outNode = outNode or topNode:copy() + outNode, topNode = self:lookIntoChild(action[1], topNode, outNode) + outNode = outNode:copy() + end + elseif action.type == 'binary' then + if not action[1] or not action[2] then + goto RETURN + end + if action.op.type == 'and' then + topNode = self:lookIntoChild(action[1], topNode, topNode:copy()) + topNode = self:lookIntoChild(action[2], topNode, topNode:copy()) + elseif action.op.type == 'or' then + outNode = outNode or topNode:copy() + local topNode1, outNode1 = self:lookIntoChild(action[1], topNode, outNode) + local topNode2, outNode2 = self:lookIntoChild(action[2], outNode1, outNode1:copy()) + topNode = vm.createNode(topNode1, topNode2) + outNode = outNode2:copy() + elseif action.op.type == '==' + or action.op.type == '~=' then + local handler, checker + for i = 1, 2 do + if guide.isLiteral(action[i]) then + checker = action[i] + handler = action[3-i] -- Copilot tells me use `3-i` instead of `i%2+1` + end + end + if not handler then + goto RETURN + end + if handler.type == 'getlocal' + and handler.node == self.source then + -- if x == y then + topNode = self:lookIntoChild(handler, topNode, outNode) + local checkerNode = vm.compileNode(checker) + local checkerName = vm.getNodeName(checker) + if checkerName then + topNode = topNode:copy() + if action.op.type == '==' then + topNode:narrow(self.uri, checkerName) + if outNode then + outNode:removeNode(checkerNode) + end + else + topNode:removeNode(checkerNode) + if outNode then + outNode:narrow(self.uri, checkerName) + end + end + end + elseif handler.type == 'call' + and checker.type == 'string' + and handler.node.special == 'type' + and handler.args + and handler.args[1] + and handler.args[1].type == 'getlocal' + and handler.args[1].node == self.source then + -- if type(x) == 'string' then + self:lookIntoChild(handler, topNode:copy()) + if action.op.type == '==' then + topNode:narrow(self.uri, checker[1]) + if outNode then + outNode:remove(checker[1]) + end + else + topNode:remove(checker[1]) + if outNode then + outNode:narrow(self.uri, checker[1]) + end + end + elseif handler.type == 'getlocal' + and checker.type == 'string' then + local nodeValue = vm.getObjectValue(handler.node) + if nodeValue + and nodeValue.type == 'select' + and nodeValue.sindex == 1 then + local call = nodeValue.vararg + if call + and call.type == 'call' + and call.node.special == 'type' + and call.args + and call.args[1] + and call.args[1].type == 'getlocal' + and call.args[1].node == self.source then + -- `local tp = type(x);if tp == 'string' then` + if action.op.type == '==' then + topNode:narrow(self.uri, checker[1]) + if outNode then + outNode:remove(checker[1]) + end + else + topNode:remove(checker[1]) + if outNode then + outNode:narrow(self.uri, checker[1]) + end + end + end + end + end end - local lastAssign = self:getLastAssign(block, source.start) + elseif action.type == 'loop' + or action.type == 'in' + or action.type == 'repeat' + or action.type == 'for' + or action.type == 'do' then + self:lookIntoBlock(action, 0, topNode:copy()) + local lastAssign = self:getLastAssign(action.start, action.finish) if lastAssign then local node = self:getNode(lastAssign) - return node - end - local node = self:calcGet(source) - if node then - return node + if node then + topNode = node:copy() + end end - end - if source.type == 'setlocal' then - if source.node ~= self.source then - return nil + elseif action.type == 'while' then + local blockNode, mainNode + if action.filter then + blockNode, mainNode = self:lookIntoChild(action.filter, topNode:copy(), topNode:copy()) + else + blockNode = topNode:copy() + mainNode = topNode:copy() end - local node = vm.compileNode(source) - return node - end - if source.type == 'local' - or source.type == 'self' then - if source ~= self.source then - return nil + self:lookIntoBlock(action, 0, blockNode:copy()) + local lastAssign = self:getLastAssign(action.start, action.finish) + if lastAssign then + local node = self:getNode(lastAssign) + if node then + topNode = mainNode:merge(node) + end end - local node = vm.compileNode(source) - return node - end - if source.type == 'filter' then - local node = self:narrow(source.exp) - return node - end - if source.type == 'do' then - local lastAssign = self:getLastAssign(source, source.finish) - local node = self:getNode(lastAssign or source.parent) - return node - end - if source.type == 'ifblock' then - local filter = source.filter - if filter then - local node = self:getNode(filter) - return node + if action.filter then + -- look into filter again + guide.eachSource(action.filter, function (src) + self.mark[src] = nil + end) + blockNode, topNode = self:lookIntoChild(action.filter, topNode:copy(), topNode:copy()) end - end - if source.type == 'if' then - local parentBlock = guide.getParentBlock(source) - if not parentBlock then - return nil - end - local lastAssign = self:getLastAssign(parentBlock, source.start) - local outNode = self:getNode(lastAssign or source.parent) or vm.createNode() - for _, block in ipairs(source) do - local blockNode = self:getNode(block) - if not blockNode then - goto CONTINUE + elseif action.type == 'if' then + local hasElse + local mainNode = topNode:copy() + local blockNodes = {} + for _, subBlock in ipairs(action) do + local blockNode = mainNode:copy() + if subBlock.filter then + blockNode, mainNode = self:lookIntoChild(subBlock.filter, blockNode, mainNode) + else + hasElse = true + mainNode:clear() end - if block.hasReturn - or block.hasError - or block.hasBreak then - outNode:removeNode(blockNode) - goto CONTINUE + self:lookIntoBlock(subBlock, 0, blockNode:copy()) + local neverReturn = subBlock.hasReturn + or subBlock.hasGoTo + or subBlock.hasBreak + or subBlock.hasError + if not neverReturn then + local lastAssign = self:getLastAssign(subBlock.start, subBlock.finish) + if lastAssign then + local node = self:getNode(lastAssign) + if node then + blockNodes[#blockNodes+1] = node + end + end end - local blockAssign = self:getLastAssign(block, block.finish) - if not blockAssign then - goto CONTINUE - end - local blockAssignNode = self:getNode(blockAssign) - if not blockAssignNode then - goto CONTINUE + end + if not hasElse and not topNode:hasKnownType() then + mainNode:merge(vm.declareGlobal('type', 'unknown')) + end + for _, blockNode in ipairs(blockNodes) do + mainNode:merge(blockNode) + end + topNode = mainNode + elseif action.type == 'call' then + if action.node.special == 'assert' and action.args and action.args[1] then + topNode = self:lookIntoChild(action.args[1], topNode, topNode:copy()) + end + elseif action.type == 'paren' then + topNode, outNode = self:lookIntoChild(action.exp, topNode, outNode) + elseif action.type == 'setlocal' then + if action.value then + self:lookIntoChild(action.value, topNode) + end + elseif action.type == 'local' then + if action.value + and action.ref + and action.value.type == 'select' then + local index = action.value.sindex + local call = action.value.vararg + if index == 1 + and call.type == 'call' + and call.node + and call.node.special == 'type' + and call.args then + local getLoc = call.args[1] + if getLoc + and getLoc.type == 'getlocal' + and getLoc.node == self.source then + for _, ref in ipairs(action.ref) do + self:collectCare(ref) + end + end end - outNode:removeNode(blockNode) - outNode:merge(blockAssignNode) - ::CONTINUE:: end end - if source.type == 'unary' then - if source.op.type == 'not' then - local node = self:getNode(source[1]) - if node then - node = node:copy() - node:setFalsy() - return node - end + ::RETURN:: + guide.eachChild(action, function (src) + if self.careMap[src] then + self:lookIntoChild(src, topNode) end + end) + return topNode, outNode or topNode +end + +---@param block parser.object +---@param start integer +---@param node vm.node +function mt:lookIntoBlock(block, start, node) + for _, action in ipairs(block) do + if action.start < start then + goto CONTINUE + end + if self.careMap[action] then + node = self:lookIntoChild(action, node) + end + if self.assignMap[action] then + break + end + ::CONTINUE:: end +end - local block = guide.getParentBlock(source) - if not block then - return nil +---@param source parser.object +function mt:calcNode(source) + if source.type == 'getlocal' then + local lastAssign = self:getLastAssign(0, source.start) + if not lastAssign then + return + end + self:calcNode(lastAssign) + return + end + if source.type == 'local' + or source.type == 'self' + or source.type == 'setlocal' then + local node = vm.compileNode(source) + self.nodes[source] = node + local parentBlock = guide.getParentBlock(source) + if parentBlock then + self:lookIntoBlock(parentBlock, source.finish, node) + end + return end - local lastAssign = self:getLastAssign(block, source.start) - local node = self:getNode(lastAssign or source.parent) - return node end ---@param source parser.object @@ -259,11 +416,8 @@ function mt:getNode(source) return nil end self.nodes[source] = false - local node = self:calcNode(source) - if node then - self.nodes[source] = node - end - return node + self:calcNode(source) + return self.nodes[source] or nil end ---@param source parser.object @@ -277,11 +431,14 @@ local function createTracer(source) return nil end local tracer = setmetatable({ - source = source, - assigns = {}, - nodes = {}, - main = main, - uri = guide.getUri(source), + source = source, + assigns = {}, + assignMap = {}, + careMap = {}, + mark = {}, + nodes = {}, + main = main, + uri = guide.getUri(source), }, mt) source._tracer = tracer diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua index eb2d2e5c..8ae65e48 100644 --- a/test/type_inference/init.lua +++ b/test/type_inference/init.lua @@ -1871,6 +1871,15 @@ if xxx and x then end ]] +TEST 'unknown' [[ +---@type integer? +local x + +if not x and x then + print(<?x?>) +end +]] + TEST 'integer' [[ ---@type integer? local x |