summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--script/vm/runner.lua274
-rw-r--r--test/type_inference/init.lua11
2 files changed, 156 insertions, 129 deletions
diff --git a/script/vm/runner.lua b/script/vm/runner.lua
index 9f9a53a2..4a26947a 100644
--- a/script/vm/runner.lua
+++ b/script/vm/runner.lua
@@ -57,7 +57,7 @@ function mt:_collect()
end
table.sort(self._objs, function (a, b)
- return (a.range or a.start) < (b.range or b.start)
+ return a.start < b.start
end)
end
@@ -69,17 +69,12 @@ end
function mt:_fastWard(pos, node)
for i = self._index, #self._objs do
local obj = self._objs[i]
- if (obj.range or obj.finish) > pos then
+ if obj.finish > pos then
self._index = i
return node, obj
end
if obj.type == 'getlocal' then
self._callback(obj, node)
- elseif obj.type == 'setlocal' then
- local newNode = self._callback(obj, node)
- if newNode then
- node = newNode:copy()
- end
elseif obj.type == 'doc.cast' then
node = node:copy()
for _, cast in ipairs(obj.casts) do
@@ -110,126 +105,64 @@ function mt:_fastWard(pos, node)
return node, nil
end
----@param action parser.object
+---@param exp parser.object
---@param topNode vm.node
---@param outNode? vm.node
---@return vm.node topNode
---@return vm.node outNode
-function mt:_lookInto(action, topNode, outNode)
- if not action then
+function mt:_lookIntoExp(exp, topNode, outNode)
+ if not exp then
return topNode, outNode or topNode
end
- if self._mark[action] then
+ if self._mark[exp] then
return topNode, outNode or topNode
end
- self._mark[action] = true
+ self._mark[exp] = true
local top = self._objs[self._index]
if not top then
return topNode, outNode or topNode
end
- if not guide.isInRange(action, top.finish)
- -- trick for `local tp = type(x);if tp == 'string' then`
- and action.type ~= 'binary' then
- return topNode, outNode or topNode
- end
- local set
- local value = vm.getObjectValue(action)
- if value then
- set = action
- action = value
- end
- if action.type == 'function' then
- self:_launchBlock(action, topNode:copy())
- elseif action.type == 'loop'
- or action.type == 'in'
- or action.type == 'repeat'
- or action.type == 'for' then
- topNode = self:_launchBlock(action, topNode:copy())
- elseif action.type == 'while' then
- local blockNode, mainNode = self:_lookInto(action.filter, topNode:copy(), topNode:copy())
- if action.filter then
- self:_fastWard(action.filter.finish, blockNode)
- end
- blockNode = self:_launchBlock(action, blockNode:copy())
- if mainNode then
- topNode = mainNode:merge(blockNode)
- end
- elseif action.type == 'if' then
- local hasElse
- local mainNode = topNode:copy()
- local blockNodes = {}
- for _, subBlock in ipairs(action) do
- local blockNode = mainNode:copy()
- if subBlock.filter then
- blockNode, mainNode = self:_lookInto(subBlock.filter, blockNode, mainNode)
- self:_fastWard(subBlock.filter.finish, blockNode)
- else
- hasElse = true
- mainNode:clear()
- end
- blockNode = self:_launchBlock(subBlock, blockNode:copy())
- local neverReturn = subBlock.hasReturn
- or subBlock.hasGoTo
- or subBlock.hasBreak
- or subBlock.hasError
- if not neverReturn then
- blockNodes[#blockNodes+1] = blockNode
- end
- end
- if not hasElse and not topNode:hasKnownType() then
- mainNode:merge(vm.declareGlobal('type', 'unknown'))
- end
- for _, blockNode in ipairs(blockNodes) do
- mainNode:merge(blockNode)
- end
- topNode = mainNode
- elseif action.type == 'getlocal' then
- if action.node == self._loc then
- topNode = self:_fastWard(action.finish, topNode)
- topNode = topNode:copy():setTruthy()
- if outNode then
- outNode:setFalsy()
- end
- end
- elseif action.type == 'unary' then
- if not action[1] then
+ if exp.type == 'function' then
+ self:_launchBlock(exp, topNode:copy())
+ elseif exp.type == 'unary' then
+ if not exp[1] then
goto RETURN
end
- if action.op.type == 'not' then
+ if exp.op.type == 'not' then
outNode = outNode or topNode:copy()
- outNode, topNode = self:_lookInto(action[1], topNode, outNode)
+ outNode, topNode = self:_lookIntoExp(exp[1], topNode, outNode)
end
- elseif action.type == 'binary' then
- if not action[1] or not action[2] then
+ elseif exp.type == 'binary' then
+ if not exp[1] or not exp[2] then
goto RETURN
end
- if action.op.type == 'and' then
- topNode = self:_lookInto(action[1], topNode)
- topNode = self:_lookInto(action[2], topNode)
- elseif action.op.type == 'or' then
+ if exp.op.type == 'and' then
+ topNode = self:_lookIntoExp(exp[1], topNode)
+ topNode = self:_lookIntoExp(exp[2], topNode)
+ elseif exp.op.type == 'or' then
outNode = outNode or topNode:copy()
- local topNode1, outNode1 = self:_lookInto(action[1], topNode, outNode)
- local topNode2, outNode2 = self:_lookInto(action[2], outNode1, outNode1:copy())
+ local topNode1, outNode1 = self:_lookIntoExp(exp[1], topNode, outNode)
+ local topNode2, outNode2 = self:_lookIntoExp(exp[2], outNode1, outNode1:copy())
topNode = vm.createNode(topNode1, topNode2)
outNode = outNode2:copy()
- elseif action.op.type == '=='
- or action.op.type == '~=' then
- local exp, checker
+ elseif exp.op.type == '=='
+ or exp.op.type == '~=' then
+ local handler, checker
for i = 1, 2 do
- if guide.isLiteral(action[i]) then
- checker = action[i]
- exp = action[3-i] -- Copilot tells me use `3-i` instead of `i%2+1`
+ if guide.isLiteral(exp[i]) then
+ checker = exp[i]
+ handler = exp[3-i] -- Copilot tells me use `3-i` instead of `i%2+1`
end
end
- if not exp then
+ if not handler then
goto RETURN
end
- if exp.type == 'getlocal'
- and exp.node == self._loc then
+ if handler.type == 'getlocal'
+ and handler.node == self._loc then
-- if x == y then
self:_fastWard(exp.finish, topNode:copy())
local checkerNode = vm.compileNode(checker)
- if action.op.type == '==' then
+ if exp.op.type == '==' then
topNode = checkerNode
if outNode then
outNode:removeNode(topNode)
@@ -240,16 +173,16 @@ function mt:_lookInto(action, topNode, outNode)
outNode = checkerNode
end
end
- elseif exp.type == 'call'
+ elseif handler.type == 'call'
and checker.type == 'string'
- and exp.node.special == 'type'
- and exp.args
- and exp.args[1]
- and exp.args[1].type == 'getlocal'
- and exp.args[1].node == self._loc then
+ and handler.node.special == 'type'
+ and handler.args
+ and handler.args[1]
+ and handler.args[1].type == 'getlocal'
+ and handler.args[1].node == self._loc then
-- if type(x) == 'string' then
self:_fastWard(exp.finish, topNode:copy())
- if action.op.type == '==' then
+ if exp.op.type == '==' then
topNode:narrow(checker[1])
if outNode then
outNode:remove(checker[1])
@@ -260,9 +193,9 @@ function mt:_lookInto(action, topNode, outNode)
outNode:narrow(checker[1])
end
end
- elseif exp.type == 'getlocal'
+ elseif handler.type == 'getlocal'
and checker.type == 'string' then
- local nodeValue = vm.getObjectValue(exp.node)
+ local nodeValue = vm.getObjectValue(handler.node)
if nodeValue
and nodeValue.type == 'select'
and nodeValue.sindex == 1 then
@@ -275,7 +208,7 @@ function mt:_lookInto(action, topNode, outNode)
and call.args[1].type == 'getlocal'
and call.args[1].node == self._loc then
-- `local tp = type(x);if tp == 'string' then`
- if action.op.type == '==' then
+ if exp.op.type == '==' then
topNode:narrow(checker[1])
if outNode then
outNode:remove(checker[1])
@@ -290,33 +223,117 @@ function mt:_lookInto(action, topNode, outNode)
end
end
end
+ elseif exp.type == 'getlocal' then
+ if exp.node == self._loc then
+ topNode = self:_fastWard(exp.finish, topNode)
+ topNode = topNode:copy():setTruthy()
+ if outNode then
+ outNode:setFalsy()
+ end
+ end
+ elseif exp.type == 'paren' then
+ topNode, outNode = self:_lookIntoExp(exp.exp, topNode, outNode)
+ elseif exp.type == 'getindex' then
+ self:_lookIntoExp(exp.index, topNode)
+ elseif exp.type == 'table' then
+ for _, field in ipairs(exp) do
+ self:_lookIntoAction(field, topNode)
+ end
+ end
+ ::RETURN::
+ topNode = self:_fastWard(exp.finish, topNode)
+ return topNode, outNode or topNode
+end
+
+---@param action parser.object
+---@param topNode vm.node
+---@return vm.node topNode
+function mt:_lookIntoAction(action, topNode)
+ if not action then
+ return topNode
+ end
+ if self._mark[action] then
+ return topNode
+ end
+ self._mark[action] = true
+ local top = self._objs[self._index]
+ if not top then
+ return topNode
+ end
+ if not guide.isInRange(action, top.finish)
+ -- trick for `local tp = type(x);if tp == 'string' then`
+ and action.type ~= 'binary' then
+ return topNode
+ end
+ local value = vm.getObjectValue(action)
+ if value then
+ self:_lookIntoExp(value, topNode:copy())
+ end
+ if action.type == 'setlocal' then
+ local newTopNode = self._callback(action, topNode)
+ if newTopNode then
+ topNode = newTopNode
+ end
+ elseif action.type == 'function' then
+ self:_launchBlock(action, topNode:copy())
+ elseif action.type == 'loop'
+ or action.type == 'in'
+ or action.type == 'repeat'
+ or action.type == 'for' then
+ topNode = self:_launchBlock(action, topNode:copy())
+ elseif action.type == 'while' then
+ local blockNode, mainNode = self:_lookIntoExp(action.filter, topNode:copy(), topNode:copy())
+ if action.filter then
+ self:_fastWard(action.filter.finish, blockNode)
+ end
+ blockNode = self:_launchBlock(action, blockNode:copy())
+ if mainNode then
+ topNode = mainNode:merge(blockNode)
+ end
+ elseif action.type == 'if' then
+ local hasElse
+ local mainNode = topNode:copy()
+ local blockNodes = {}
+ for _, subBlock in ipairs(action) do
+ local blockNode = mainNode:copy()
+ if subBlock.filter then
+ blockNode, mainNode = self:_lookIntoExp(subBlock.filter, blockNode, mainNode)
+ self:_fastWard(subBlock.filter.finish, blockNode)
+ else
+ hasElse = true
+ mainNode:clear()
+ end
+ blockNode = self:_launchBlock(subBlock, blockNode:copy())
+ local neverReturn = subBlock.hasReturn
+ or subBlock.hasGoTo
+ or subBlock.hasBreak
+ or subBlock.hasError
+ if not neverReturn then
+ blockNodes[#blockNodes+1] = blockNode
+ end
+ end
+ if not hasElse and not topNode:hasKnownType() then
+ mainNode:merge(vm.declareGlobal('type', 'unknown'))
+ end
+ for _, blockNode in ipairs(blockNodes) do
+ mainNode:merge(blockNode)
+ end
+ topNode = mainNode
elseif action.type == 'call' then
if action.node.special == 'assert' and action.args and action.args[1] then
- topNode = self:_lookInto(action.args[1], topNode)
+ topNode = self:_lookIntoExp(action.args[1], topNode)
elseif action.args then
for _, arg in ipairs(action.args) do
- self:_lookInto(arg, topNode)
+ self:_lookIntoExp(arg, topNode)
end
end
- elseif action.type == 'paren' then
- topNode, outNode = self:_lookInto(action.exp, topNode, outNode)
elseif action.type == 'return' then
for _, rtn in ipairs(action) do
- self:_lookInto(rtn, topNode)
- end
- elseif action.type == 'getindex' then
- self:_lookInto(action.index, topNode)
- elseif action.type == 'table' then
- for _, field in ipairs(action) do
- self:_lookInto(field, topNode)
+ self:_lookIntoExp(rtn, topNode)
end
end
- ::RETURN::
topNode = self:_fastWard(action.finish, topNode)
- if set then
- topNode = self:_fastWard(set.range or set.finish, topNode)
- end
- return topNode, outNode or topNode
+ return topNode
end
---@param block parser.object
@@ -328,18 +345,17 @@ function mt:_launchBlock(block, node)
return topNode
end
for _, action in ipairs(block) do
- if (action.range or action.finish) < (top.range or top.finish) then
+ if (action.range or action.finish) < top.finish then
goto CONTINUE
end
- topNode = self:_lookInto(action, topNode)
- topNode, top = self:_fastWard(action.range or action.finish, topNode)
+ topNode = self:_lookIntoAction(action, topNode)
+ topNode, top = self:_fastWard(action.finish, topNode)
if not top then
return topNode
end
::CONTINUE::
end
- -- `x = function () end`: don't touch `x` in the end of function
- topNode = self:_fastWard(block.finish - 1, topNode)
+ topNode = self:_fastWard(block.finish, topNode)
return topNode
end
diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua
index 0c9be21e..37cf0324 100644
--- a/test/type_inference/init.lua
+++ b/test/type_inference/init.lua
@@ -3270,3 +3270,14 @@ local b
local <?c?> = a or b
]]
+
+TEST 'number|table|nil' [[
+---@type table|nil
+local a
+
+---@type number|nil
+local b
+
+local c = a and b
+local <?d?> = a or b
+]]