summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author最萌小汐 <sumneko@hotmail.com>2022-12-15 20:23:26 +0800
committer最萌小汐 <sumneko@hotmail.com>2022-12-15 20:23:26 +0800
commita744e3439e165be13f8f0ba5262b8f1efec0d86d (patch)
treea2a00988add4df42e67b9f2169c6ee2317bb0058
parent244a19d365d8b3d3881492e08fefa9847bd11a2f (diff)
downloadlua-language-server-a744e3439e165be13f8f0ba5262b8f1efec0d86d.zip
stash
-rw-r--r--script/vm/tracer.lua459
-rw-r--r--test/type_inference/init.lua9
2 files changed, 317 insertions, 151 deletions
diff --git a/script/vm/tracer.lua b/script/vm/tracer.lua
index 0c5b6939..c2b7ae39 100644
--- a/script/vm/tracer.lua
+++ b/script/vm/tracer.lua
@@ -8,11 +8,14 @@ local util = require 'utility'
---@field package _casts? parser.object[]
---@class vm.tracer
----@field source parser.object
----@field assigns parser.object[]
----@field nodes table<parser.object, vm.node|false>
----@field main parser.object
----@field uri uri
+---@field source parser.object
+---@field assigns parser.object[]
+---@field assignMap table<parser.object, true>
+---@field careMap table<parser.object, true>
+---@field mark table<parser.object, true>
+---@field nodes table<parser.object, vm.node|false>
+---@field main parser.object
+---@field uri uri
local mt = {}
mt.__index = mt
@@ -31,41 +34,54 @@ function mt:getCasts()
return root._casts
end
----@param obj parser.object
----@param mark table
-function mt:collectBlock(obj, mark)
+---@param obj parser.object
+function mt:collectAssign(obj)
while true do
local block = guide.getParentBlock(obj)
if not block then
return
end
obj = block
- if mark[obj] then
+ if self.assignMap[obj] then
return
end
if obj == self.main then
return
end
- mark[obj] = true
+ self.assignMap[obj] = true
self.assigns[#self.assigns+1] = obj
end
end
+---@param obj parser.object
+function mt:collectCare(obj)
+ while true do
+ if self.careMap[obj] then
+ return
+ end
+ if obj == self.main then
+ return
+ end
+ self.careMap[obj] = true
+ obj = obj.parent
+ end
+end
+
function mt:collectLocal()
local startPos = self.source.start
local finishPos = 0
- local mark = {}
-
self.assigns[#self.assigns+1] = self.source
+ self.assignMap[self.source] = true
for _, obj in ipairs(self.source.ref) do
if obj.type == 'setlocal' then
self.assigns[#self.assigns+1] = obj
- self:collectBlock(obj, mark)
+ self.assignMap[obj] = true
+ self:collectCare(obj)
end
if obj.type == 'getlocal' then
- self:collectBlock(obj, mark)
+ self:collectCare(obj)
end
end
@@ -84,167 +100,308 @@ function mt:collectLocal()
end)
end
----@param block parser.object
----@param pos integer
+---@param start integer
+---@param finish integer
---@return parser.object?
-function mt:getLastAssign(block, pos)
- if not block then
- return nil
- end
+function mt:getLastAssign(start, finish)
local assign
for _, obj in ipairs(self.assigns) do
- if obj.start >= pos then
+ if obj.start < start then
+ goto CONTINUE
+ end
+ if obj.start >= finish then
break
end
local objBlock = guide.getParentBlock(obj)
if not objBlock then
break
end
- if objBlock == block then
+ if objBlock.start <= finish
+ and objBlock.finish >= finish then
assign = obj
end
+ ::CONTINUE::
end
return assign
end
----@param source parser.object
----@return vm.node?
-function mt:narrow(source)
- local node = self:getNode(source)
- if not node then
- return nil
- end
-
- if source.type == 'getlocal' then
- node = node:copy()
- node:setTruthy()
- end
-
- return node
-end
-
----@param source parser.object
----@return vm.node?
-function mt:calcGet(source)
- local parent = source.parent
- if parent.type == 'filter' then
- return self:calcGet(parent)
+---@param action parser.object
+---@param topNode vm.node
+---@param outNode? vm.node
+---@return vm.node topNode
+---@return vm.node outNode
+function mt:lookIntoChild(action, topNode, outNode)
+ if not self.careMap[action]
+ or self.mark[action] then
+ return topNode, outNode or topNode
end
- if parent.type == 'ifblock' then
- local parentBlock = guide.getParentBlock(parent.parent)
- if parentBlock then
- local lastAssign = self:getLastAssign(parentBlock, parent.start)
- local node = self:getNode(lastAssign or parentBlock)
- return node
+ self.mark[action] = true
+ if action.type == 'getlocal' then
+ if action.node == self.source then
+ self.nodes[action] = topNode
+ if outNode then
+ topNode = topNode:copy():setTruthy()
+ outNode = outNode:copy():setFalsy()
+ end
end
- end
- if parent.type == 'unary' then
- return self:calcGet(parent)
- end
- return nil
-end
-
----@param source parser.object
----@return vm.node?
-function mt:calcNode(source)
- if source.type == 'getlocal' then
- if source.node ~= self.source then
- return nil
+ elseif action.type == 'filter' then
+ return self:lookIntoChild(action.exp, topNode, outNode)
+ elseif action.type == 'function' then
+ self:lookIntoBlock(action, 0, topNode:copy())
+ elseif action.type == 'unary' then
+ if not action[1] then
+ goto RETURN
end
- local block = guide.getParentBlock(source)
- if not block then
- return nil
+ if action.op.type == 'not' then
+ outNode = outNode or topNode:copy()
+ outNode, topNode = self:lookIntoChild(action[1], topNode, outNode)
+ outNode = outNode:copy()
+ end
+ elseif action.type == 'binary' then
+ if not action[1] or not action[2] then
+ goto RETURN
+ end
+ if action.op.type == 'and' then
+ topNode = self:lookIntoChild(action[1], topNode, topNode:copy())
+ topNode = self:lookIntoChild(action[2], topNode, topNode:copy())
+ elseif action.op.type == 'or' then
+ outNode = outNode or topNode:copy()
+ local topNode1, outNode1 = self:lookIntoChild(action[1], topNode, outNode)
+ local topNode2, outNode2 = self:lookIntoChild(action[2], outNode1, outNode1:copy())
+ topNode = vm.createNode(topNode1, topNode2)
+ outNode = outNode2:copy()
+ elseif action.op.type == '=='
+ or action.op.type == '~=' then
+ local handler, checker
+ for i = 1, 2 do
+ if guide.isLiteral(action[i]) then
+ checker = action[i]
+ handler = action[3-i] -- Copilot tells me use `3-i` instead of `i%2+1`
+ end
+ end
+ if not handler then
+ goto RETURN
+ end
+ if handler.type == 'getlocal'
+ and handler.node == self.source then
+ -- if x == y then
+ topNode = self:lookIntoChild(handler, topNode, outNode)
+ local checkerNode = vm.compileNode(checker)
+ local checkerName = vm.getNodeName(checker)
+ if checkerName then
+ topNode = topNode:copy()
+ if action.op.type == '==' then
+ topNode:narrow(self.uri, checkerName)
+ if outNode then
+ outNode:removeNode(checkerNode)
+ end
+ else
+ topNode:removeNode(checkerNode)
+ if outNode then
+ outNode:narrow(self.uri, checkerName)
+ end
+ end
+ end
+ elseif handler.type == 'call'
+ and checker.type == 'string'
+ and handler.node.special == 'type'
+ and handler.args
+ and handler.args[1]
+ and handler.args[1].type == 'getlocal'
+ and handler.args[1].node == self.source then
+ -- if type(x) == 'string' then
+ self:lookIntoChild(handler, topNode:copy())
+ if action.op.type == '==' then
+ topNode:narrow(self.uri, checker[1])
+ if outNode then
+ outNode:remove(checker[1])
+ end
+ else
+ topNode:remove(checker[1])
+ if outNode then
+ outNode:narrow(self.uri, checker[1])
+ end
+ end
+ elseif handler.type == 'getlocal'
+ and checker.type == 'string' then
+ local nodeValue = vm.getObjectValue(handler.node)
+ if nodeValue
+ and nodeValue.type == 'select'
+ and nodeValue.sindex == 1 then
+ local call = nodeValue.vararg
+ if call
+ and call.type == 'call'
+ and call.node.special == 'type'
+ and call.args
+ and call.args[1]
+ and call.args[1].type == 'getlocal'
+ and call.args[1].node == self.source then
+ -- `local tp = type(x);if tp == 'string' then`
+ if action.op.type == '==' then
+ topNode:narrow(self.uri, checker[1])
+ if outNode then
+ outNode:remove(checker[1])
+ end
+ else
+ topNode:remove(checker[1])
+ if outNode then
+ outNode:narrow(self.uri, checker[1])
+ end
+ end
+ end
+ end
+ end
end
- local lastAssign = self:getLastAssign(block, source.start)
+ elseif action.type == 'loop'
+ or action.type == 'in'
+ or action.type == 'repeat'
+ or action.type == 'for'
+ or action.type == 'do' then
+ self:lookIntoBlock(action, 0, topNode:copy())
+ local lastAssign = self:getLastAssign(action.start, action.finish)
if lastAssign then
local node = self:getNode(lastAssign)
- return node
- end
- local node = self:calcGet(source)
- if node then
- return node
+ if node then
+ topNode = node:copy()
+ end
end
- end
- if source.type == 'setlocal' then
- if source.node ~= self.source then
- return nil
+ elseif action.type == 'while' then
+ local blockNode, mainNode
+ if action.filter then
+ blockNode, mainNode = self:lookIntoChild(action.filter, topNode:copy(), topNode:copy())
+ else
+ blockNode = topNode:copy()
+ mainNode = topNode:copy()
end
- local node = vm.compileNode(source)
- return node
- end
- if source.type == 'local'
- or source.type == 'self' then
- if source ~= self.source then
- return nil
+ self:lookIntoBlock(action, 0, blockNode:copy())
+ local lastAssign = self:getLastAssign(action.start, action.finish)
+ if lastAssign then
+ local node = self:getNode(lastAssign)
+ if node then
+ topNode = mainNode:merge(node)
+ end
end
- local node = vm.compileNode(source)
- return node
- end
- if source.type == 'filter' then
- local node = self:narrow(source.exp)
- return node
- end
- if source.type == 'do' then
- local lastAssign = self:getLastAssign(source, source.finish)
- local node = self:getNode(lastAssign or source.parent)
- return node
- end
- if source.type == 'ifblock' then
- local filter = source.filter
- if filter then
- local node = self:getNode(filter)
- return node
+ if action.filter then
+ -- look into filter again
+ guide.eachSource(action.filter, function (src)
+ self.mark[src] = nil
+ end)
+ blockNode, topNode = self:lookIntoChild(action.filter, topNode:copy(), topNode:copy())
end
- end
- if source.type == 'if' then
- local parentBlock = guide.getParentBlock(source)
- if not parentBlock then
- return nil
- end
- local lastAssign = self:getLastAssign(parentBlock, source.start)
- local outNode = self:getNode(lastAssign or source.parent) or vm.createNode()
- for _, block in ipairs(source) do
- local blockNode = self:getNode(block)
- if not blockNode then
- goto CONTINUE
+ elseif action.type == 'if' then
+ local hasElse
+ local mainNode = topNode:copy()
+ local blockNodes = {}
+ for _, subBlock in ipairs(action) do
+ local blockNode = mainNode:copy()
+ if subBlock.filter then
+ blockNode, mainNode = self:lookIntoChild(subBlock.filter, blockNode, mainNode)
+ else
+ hasElse = true
+ mainNode:clear()
end
- if block.hasReturn
- or block.hasError
- or block.hasBreak then
- outNode:removeNode(blockNode)
- goto CONTINUE
+ self:lookIntoBlock(subBlock, 0, blockNode:copy())
+ local neverReturn = subBlock.hasReturn
+ or subBlock.hasGoTo
+ or subBlock.hasBreak
+ or subBlock.hasError
+ if not neverReturn then
+ local lastAssign = self:getLastAssign(subBlock.start, subBlock.finish)
+ if lastAssign then
+ local node = self:getNode(lastAssign)
+ if node then
+ blockNodes[#blockNodes+1] = node
+ end
+ end
end
- local blockAssign = self:getLastAssign(block, block.finish)
- if not blockAssign then
- goto CONTINUE
- end
- local blockAssignNode = self:getNode(blockAssign)
- if not blockAssignNode then
- goto CONTINUE
+ end
+ if not hasElse and not topNode:hasKnownType() then
+ mainNode:merge(vm.declareGlobal('type', 'unknown'))
+ end
+ for _, blockNode in ipairs(blockNodes) do
+ mainNode:merge(blockNode)
+ end
+ topNode = mainNode
+ elseif action.type == 'call' then
+ if action.node.special == 'assert' and action.args and action.args[1] then
+ topNode = self:lookIntoChild(action.args[1], topNode, topNode:copy())
+ end
+ elseif action.type == 'paren' then
+ topNode, outNode = self:lookIntoChild(action.exp, topNode, outNode)
+ elseif action.type == 'setlocal' then
+ if action.value then
+ self:lookIntoChild(action.value, topNode)
+ end
+ elseif action.type == 'local' then
+ if action.value
+ and action.ref
+ and action.value.type == 'select' then
+ local index = action.value.sindex
+ local call = action.value.vararg
+ if index == 1
+ and call.type == 'call'
+ and call.node
+ and call.node.special == 'type'
+ and call.args then
+ local getLoc = call.args[1]
+ if getLoc
+ and getLoc.type == 'getlocal'
+ and getLoc.node == self.source then
+ for _, ref in ipairs(action.ref) do
+ self:collectCare(ref)
+ end
+ end
end
- outNode:removeNode(blockNode)
- outNode:merge(blockAssignNode)
- ::CONTINUE::
end
end
- if source.type == 'unary' then
- if source.op.type == 'not' then
- local node = self:getNode(source[1])
- if node then
- node = node:copy()
- node:setFalsy()
- return node
- end
+ ::RETURN::
+ guide.eachChild(action, function (src)
+ if self.careMap[src] then
+ self:lookIntoChild(src, topNode)
end
+ end)
+ return topNode, outNode or topNode
+end
+
+---@param block parser.object
+---@param start integer
+---@param node vm.node
+function mt:lookIntoBlock(block, start, node)
+ for _, action in ipairs(block) do
+ if action.start < start then
+ goto CONTINUE
+ end
+ if self.careMap[action] then
+ node = self:lookIntoChild(action, node)
+ end
+ if self.assignMap[action] then
+ break
+ end
+ ::CONTINUE::
end
+end
- local block = guide.getParentBlock(source)
- if not block then
- return nil
+---@param source parser.object
+function mt:calcNode(source)
+ if source.type == 'getlocal' then
+ local lastAssign = self:getLastAssign(0, source.start)
+ if not lastAssign then
+ return
+ end
+ self:calcNode(lastAssign)
+ return
+ end
+ if source.type == 'local'
+ or source.type == 'self'
+ or source.type == 'setlocal' then
+ local node = vm.compileNode(source)
+ self.nodes[source] = node
+ local parentBlock = guide.getParentBlock(source)
+ if parentBlock then
+ self:lookIntoBlock(parentBlock, source.finish, node)
+ end
+ return
end
- local lastAssign = self:getLastAssign(block, source.start)
- local node = self:getNode(lastAssign or source.parent)
- return node
end
---@param source parser.object
@@ -259,11 +416,8 @@ function mt:getNode(source)
return nil
end
self.nodes[source] = false
- local node = self:calcNode(source)
- if node then
- self.nodes[source] = node
- end
- return node
+ self:calcNode(source)
+ return self.nodes[source] or nil
end
---@param source parser.object
@@ -277,11 +431,14 @@ local function createTracer(source)
return nil
end
local tracer = setmetatable({
- source = source,
- assigns = {},
- nodes = {},
- main = main,
- uri = guide.getUri(source),
+ source = source,
+ assigns = {},
+ assignMap = {},
+ careMap = {},
+ mark = {},
+ nodes = {},
+ main = main,
+ uri = guide.getUri(source),
}, mt)
source._tracer = tracer
diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua
index eb2d2e5c..8ae65e48 100644
--- a/test/type_inference/init.lua
+++ b/test/type_inference/init.lua
@@ -1871,6 +1871,15 @@ if xxx and x then
end
]]
+TEST 'unknown' [[
+---@type integer?
+local x
+
+if not x and x then
+ print(<?x?>)
+end
+]]
+
TEST 'integer' [[
---@type integer?
local x