diff options
-rw-r--r-- | script/vm/runner.lua | 37 | ||||
-rw-r--r-- | test/type_inference/init.lua | 54 |
2 files changed, 79 insertions, 12 deletions
diff --git a/script/vm/runner.lua b/script/vm/runner.lua index 30d9e672..0ec26bef 100644 --- a/script/vm/runner.lua +++ b/script/vm/runner.lua @@ -56,7 +56,7 @@ function mt:_collect() end table.sort(self._objs, function (a, b) - return (a.range or a.start) < (b.range or b.start) + return (a.range or a.finish) < (b.range or b.start) end) end @@ -68,7 +68,7 @@ end function mt:_fastWard(pos, node) for i = self._index, #self._objs do local obj = self._objs[i] - if obj.start > pos then + if (obj.range or obj.finish) > pos then self._index = i return node, obj end @@ -105,7 +105,7 @@ function mt:_fastWard(pos, node) end end end - self._index = math.huge + self._index = #self._objs + 1 return node, nil end @@ -114,10 +114,12 @@ end ---@param outNode? vm.node ---@return vm.node function mt:_lookInto(action, topNode, outNode) - if action.type == 'setlocal' then - topNode = self:_fastWard(action.finish, topNode) + local set + local value = vm.getObjectValue(action) + if value then + set = action + action = value end - action = vm.getObjectValue(action) or action if action.type == 'function' or action.type == 'loop' or action.type == 'in' @@ -178,7 +180,7 @@ function mt:_lookInto(action, topNode, outNode) elseif action.op.type == 'or' then outNode = outNode or topNode:copy() local topNode1, outNode1 = self:_lookInto(action[1], topNode, outNode) - local topNode2, outNode2 = self:_lookInto(action[2], outNode1, outNode) + local topNode2, outNode2 = self:_lookInto(action[2], outNode1, outNode1:copy()) topNode = vm.createNode(topNode1, topNode2) outNode = outNode2 elseif action.op.type == '==' @@ -188,7 +190,7 @@ function mt:_lookInto(action, topNode, outNode) if action[i].type == 'getlocal' and action[i].node == self._loc then loc = action[i] checker = action[3-i] -- Copilot tells me use `3-i` instead of `i%2+1` - else + elseif action[2].type == 'getlocal' and action[2].node == self._loc then loc = action[3-i] checker = action[i] end @@ -214,10 +216,21 @@ function mt:_lookInto(action, topNode, outNode) elseif action.type == 'call' then if action.node.special == 'assert' and action.args and action.args[1] then topNode = self:_lookInto(action.args[1], topNode) + elseif action.args then + for _, arg in ipairs(action.args) do + self:_lookInto(arg, topNode) + end + end + elseif action.type == 'return' then + for _, rtn in ipairs(action) do + self:_lookInto(rtn, topNode) end end ::RETURN:: topNode = self:_fastWard(action.finish, topNode) + if set then + topNode = self:_fastWard(set.range or set.finish, topNode) + end return topNode, outNode end @@ -230,18 +243,18 @@ function mt:_launchBlock(block, node) return topNode end for _, action in ipairs(block) do - local finish = action.range or action.finish - if finish < top.start then + if (action.range or action.finish) < (top.range or top.finish) then goto CONTINUE end topNode = self:_lookInto(action, topNode) - topNode, top = self:_fastWard(action.finish, topNode) + topNode, top = self:_fastWard(action.range or action.finish, topNode) if not top then return topNode end ::CONTINUE:: end - topNode = self:_fastWard(block.finish, topNode) + -- `x = function () end`: don't touch `x` in the end of function + topNode = self:_fastWard(block.finish - 1, topNode) return topNode end diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua index 655f963d..a194ce2d 100644 --- a/test/type_inference/init.lua +++ b/test/type_inference/init.lua @@ -2428,3 +2428,57 @@ function x() end print(<?x?>) ]] + +TEST 'unknown' [[ +local x + +if x.field == 'haha' then + print(<?x?>) +end +]] + +TEST 'string' [[ +---@type string? +local t + +if not t or xxx then + return +end + +print(<?t?>) +]] + +TEST 'table' [[ +---@type table|nil +local t + +return function () + if not t then + return + end + + print(<?t?>) +end +]] + +TEST 'table' [[ +---@type table|nil +local t + +f(function () + if not t then + return + end + + print(<?t?>) +end) +]] + +TEST 'table' [[ +---@type table? +local t + +t = t or {} + +print(<?t?>) +]] |