diff options
-rw-r--r-- | script/vm/runner.lua | 274 | ||||
-rw-r--r-- | test/type_inference/init.lua | 11 |
2 files changed, 156 insertions, 129 deletions
diff --git a/script/vm/runner.lua b/script/vm/runner.lua index 9f9a53a2..4a26947a 100644 --- a/script/vm/runner.lua +++ b/script/vm/runner.lua @@ -57,7 +57,7 @@ function mt:_collect() end table.sort(self._objs, function (a, b) - return (a.range or a.start) < (b.range or b.start) + return a.start < b.start end) end @@ -69,17 +69,12 @@ end function mt:_fastWard(pos, node) for i = self._index, #self._objs do local obj = self._objs[i] - if (obj.range or obj.finish) > pos then + if obj.finish > pos then self._index = i return node, obj end if obj.type == 'getlocal' then self._callback(obj, node) - elseif obj.type == 'setlocal' then - local newNode = self._callback(obj, node) - if newNode then - node = newNode:copy() - end elseif obj.type == 'doc.cast' then node = node:copy() for _, cast in ipairs(obj.casts) do @@ -110,126 +105,64 @@ function mt:_fastWard(pos, node) return node, nil end ----@param action parser.object +---@param exp parser.object ---@param topNode vm.node ---@param outNode? vm.node ---@return vm.node topNode ---@return vm.node outNode -function mt:_lookInto(action, topNode, outNode) - if not action then +function mt:_lookIntoExp(exp, topNode, outNode) + if not exp then return topNode, outNode or topNode end - if self._mark[action] then + if self._mark[exp] then return topNode, outNode or topNode end - self._mark[action] = true + self._mark[exp] = true local top = self._objs[self._index] if not top then return topNode, outNode or topNode end - if not guide.isInRange(action, top.finish) - -- trick for `local tp = type(x);if tp == 'string' then` - and action.type ~= 'binary' then - return topNode, outNode or topNode - end - local set - local value = vm.getObjectValue(action) - if value then - set = action - action = value - end - if action.type == 'function' then - self:_launchBlock(action, topNode:copy()) - elseif action.type == 'loop' - or action.type == 'in' - or action.type == 'repeat' - or action.type == 'for' then - topNode = self:_launchBlock(action, topNode:copy()) - elseif action.type == 'while' then - local blockNode, mainNode = self:_lookInto(action.filter, topNode:copy(), topNode:copy()) - if action.filter then - self:_fastWard(action.filter.finish, blockNode) - end - blockNode = self:_launchBlock(action, blockNode:copy()) - if mainNode then - topNode = mainNode:merge(blockNode) - end - elseif action.type == 'if' then - local hasElse - local mainNode = topNode:copy() - local blockNodes = {} - for _, subBlock in ipairs(action) do - local blockNode = mainNode:copy() - if subBlock.filter then - blockNode, mainNode = self:_lookInto(subBlock.filter, blockNode, mainNode) - self:_fastWard(subBlock.filter.finish, blockNode) - else - hasElse = true - mainNode:clear() - end - blockNode = self:_launchBlock(subBlock, blockNode:copy()) - local neverReturn = subBlock.hasReturn - or subBlock.hasGoTo - or subBlock.hasBreak - or subBlock.hasError - if not neverReturn then - blockNodes[#blockNodes+1] = blockNode - end - end - if not hasElse and not topNode:hasKnownType() then - mainNode:merge(vm.declareGlobal('type', 'unknown')) - end - for _, blockNode in ipairs(blockNodes) do - mainNode:merge(blockNode) - end - topNode = mainNode - elseif action.type == 'getlocal' then - if action.node == self._loc then - topNode = self:_fastWard(action.finish, topNode) - topNode = topNode:copy():setTruthy() - if outNode then - outNode:setFalsy() - end - end - elseif action.type == 'unary' then - if not action[1] then + if exp.type == 'function' then + self:_launchBlock(exp, topNode:copy()) + elseif exp.type == 'unary' then + if not exp[1] then goto RETURN end - if action.op.type == 'not' then + if exp.op.type == 'not' then outNode = outNode or topNode:copy() - outNode, topNode = self:_lookInto(action[1], topNode, outNode) + outNode, topNode = self:_lookIntoExp(exp[1], topNode, outNode) end - elseif action.type == 'binary' then - if not action[1] or not action[2] then + elseif exp.type == 'binary' then + if not exp[1] or not exp[2] then goto RETURN end - if action.op.type == 'and' then - topNode = self:_lookInto(action[1], topNode) - topNode = self:_lookInto(action[2], topNode) - elseif action.op.type == 'or' then + if exp.op.type == 'and' then + topNode = self:_lookIntoExp(exp[1], topNode) + topNode = self:_lookIntoExp(exp[2], topNode) + elseif exp.op.type == 'or' then outNode = outNode or topNode:copy() - local topNode1, outNode1 = self:_lookInto(action[1], topNode, outNode) - local topNode2, outNode2 = self:_lookInto(action[2], outNode1, outNode1:copy()) + local topNode1, outNode1 = self:_lookIntoExp(exp[1], topNode, outNode) + local topNode2, outNode2 = self:_lookIntoExp(exp[2], outNode1, outNode1:copy()) topNode = vm.createNode(topNode1, topNode2) outNode = outNode2:copy() - elseif action.op.type == '==' - or action.op.type == '~=' then - local exp, checker + elseif exp.op.type == '==' + or exp.op.type == '~=' then + local handler, checker for i = 1, 2 do - if guide.isLiteral(action[i]) then - checker = action[i] - exp = action[3-i] -- Copilot tells me use `3-i` instead of `i%2+1` + if guide.isLiteral(exp[i]) then + checker = exp[i] + handler = exp[3-i] -- Copilot tells me use `3-i` instead of `i%2+1` end end - if not exp then + if not handler then goto RETURN end - if exp.type == 'getlocal' - and exp.node == self._loc then + if handler.type == 'getlocal' + and handler.node == self._loc then -- if x == y then self:_fastWard(exp.finish, topNode:copy()) local checkerNode = vm.compileNode(checker) - if action.op.type == '==' then + if exp.op.type == '==' then topNode = checkerNode if outNode then outNode:removeNode(topNode) @@ -240,16 +173,16 @@ function mt:_lookInto(action, topNode, outNode) outNode = checkerNode end end - elseif exp.type == 'call' + elseif handler.type == 'call' and checker.type == 'string' - and exp.node.special == 'type' - and exp.args - and exp.args[1] - and exp.args[1].type == 'getlocal' - and exp.args[1].node == self._loc then + and handler.node.special == 'type' + and handler.args + and handler.args[1] + and handler.args[1].type == 'getlocal' + and handler.args[1].node == self._loc then -- if type(x) == 'string' then self:_fastWard(exp.finish, topNode:copy()) - if action.op.type == '==' then + if exp.op.type == '==' then topNode:narrow(checker[1]) if outNode then outNode:remove(checker[1]) @@ -260,9 +193,9 @@ function mt:_lookInto(action, topNode, outNode) outNode:narrow(checker[1]) end end - elseif exp.type == 'getlocal' + elseif handler.type == 'getlocal' and checker.type == 'string' then - local nodeValue = vm.getObjectValue(exp.node) + local nodeValue = vm.getObjectValue(handler.node) if nodeValue and nodeValue.type == 'select' and nodeValue.sindex == 1 then @@ -275,7 +208,7 @@ function mt:_lookInto(action, topNode, outNode) and call.args[1].type == 'getlocal' and call.args[1].node == self._loc then -- `local tp = type(x);if tp == 'string' then` - if action.op.type == '==' then + if exp.op.type == '==' then topNode:narrow(checker[1]) if outNode then outNode:remove(checker[1]) @@ -290,33 +223,117 @@ function mt:_lookInto(action, topNode, outNode) end end end + elseif exp.type == 'getlocal' then + if exp.node == self._loc then + topNode = self:_fastWard(exp.finish, topNode) + topNode = topNode:copy():setTruthy() + if outNode then + outNode:setFalsy() + end + end + elseif exp.type == 'paren' then + topNode, outNode = self:_lookIntoExp(exp.exp, topNode, outNode) + elseif exp.type == 'getindex' then + self:_lookIntoExp(exp.index, topNode) + elseif exp.type == 'table' then + for _, field in ipairs(exp) do + self:_lookIntoAction(field, topNode) + end + end + ::RETURN:: + topNode = self:_fastWard(exp.finish, topNode) + return topNode, outNode or topNode +end + +---@param action parser.object +---@param topNode vm.node +---@return vm.node topNode +function mt:_lookIntoAction(action, topNode) + if not action then + return topNode + end + if self._mark[action] then + return topNode + end + self._mark[action] = true + local top = self._objs[self._index] + if not top then + return topNode + end + if not guide.isInRange(action, top.finish) + -- trick for `local tp = type(x);if tp == 'string' then` + and action.type ~= 'binary' then + return topNode + end + local value = vm.getObjectValue(action) + if value then + self:_lookIntoExp(value, topNode:copy()) + end + if action.type == 'setlocal' then + local newTopNode = self._callback(action, topNode) + if newTopNode then + topNode = newTopNode + end + elseif action.type == 'function' then + self:_launchBlock(action, topNode:copy()) + elseif action.type == 'loop' + or action.type == 'in' + or action.type == 'repeat' + or action.type == 'for' then + topNode = self:_launchBlock(action, topNode:copy()) + elseif action.type == 'while' then + local blockNode, mainNode = self:_lookIntoExp(action.filter, topNode:copy(), topNode:copy()) + if action.filter then + self:_fastWard(action.filter.finish, blockNode) + end + blockNode = self:_launchBlock(action, blockNode:copy()) + if mainNode then + topNode = mainNode:merge(blockNode) + end + elseif action.type == 'if' then + local hasElse + local mainNode = topNode:copy() + local blockNodes = {} + for _, subBlock in ipairs(action) do + local blockNode = mainNode:copy() + if subBlock.filter then + blockNode, mainNode = self:_lookIntoExp(subBlock.filter, blockNode, mainNode) + self:_fastWard(subBlock.filter.finish, blockNode) + else + hasElse = true + mainNode:clear() + end + blockNode = self:_launchBlock(subBlock, blockNode:copy()) + local neverReturn = subBlock.hasReturn + or subBlock.hasGoTo + or subBlock.hasBreak + or subBlock.hasError + if not neverReturn then + blockNodes[#blockNodes+1] = blockNode + end + end + if not hasElse and not topNode:hasKnownType() then + mainNode:merge(vm.declareGlobal('type', 'unknown')) + end + for _, blockNode in ipairs(blockNodes) do + mainNode:merge(blockNode) + end + topNode = mainNode elseif action.type == 'call' then if action.node.special == 'assert' and action.args and action.args[1] then - topNode = self:_lookInto(action.args[1], topNode) + topNode = self:_lookIntoExp(action.args[1], topNode) elseif action.args then for _, arg in ipairs(action.args) do - self:_lookInto(arg, topNode) + self:_lookIntoExp(arg, topNode) end end - elseif action.type == 'paren' then - topNode, outNode = self:_lookInto(action.exp, topNode, outNode) elseif action.type == 'return' then for _, rtn in ipairs(action) do - self:_lookInto(rtn, topNode) - end - elseif action.type == 'getindex' then - self:_lookInto(action.index, topNode) - elseif action.type == 'table' then - for _, field in ipairs(action) do - self:_lookInto(field, topNode) + self:_lookIntoExp(rtn, topNode) end end - ::RETURN:: topNode = self:_fastWard(action.finish, topNode) - if set then - topNode = self:_fastWard(set.range or set.finish, topNode) - end - return topNode, outNode or topNode + return topNode end ---@param block parser.object @@ -328,18 +345,17 @@ function mt:_launchBlock(block, node) return topNode end for _, action in ipairs(block) do - if (action.range or action.finish) < (top.range or top.finish) then + if (action.range or action.finish) < top.finish then goto CONTINUE end - topNode = self:_lookInto(action, topNode) - topNode, top = self:_fastWard(action.range or action.finish, topNode) + topNode = self:_lookIntoAction(action, topNode) + topNode, top = self:_fastWard(action.finish, topNode) if not top then return topNode end ::CONTINUE:: end - -- `x = function () end`: don't touch `x` in the end of function - topNode = self:_fastWard(block.finish - 1, topNode) + topNode = self:_fastWard(block.finish, topNode) return topNode end diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua index 0c9be21e..37cf0324 100644 --- a/test/type_inference/init.lua +++ b/test/type_inference/init.lua @@ -3270,3 +3270,14 @@ local b local <?c?> = a or b ]] + +TEST 'number|table|nil' [[ +---@type table|nil +local a + +---@type number|nil +local b + +local c = a and b +local <?d?> = a or b +]] |