diff options
Diffstat (limited to 'script')
-rw-r--r-- | script/vm/runner.lua | 128 |
1 files changed, 80 insertions, 48 deletions
diff --git a/script/vm/runner.lua b/script/vm/runner.lua index 87619e54..6c3c69a1 100644 --- a/script/vm/runner.lua +++ b/script/vm/runner.lua @@ -6,7 +6,7 @@ local guide = require 'parser.guide' ---@class vm.runner ---@field _loc parser.object ----@field _objs parser.object[] +---@field _casts parser.object[] ---@field _callback vm.runner.callback ---@field _mark table ---@field _has table<parser.object, true> @@ -51,34 +51,60 @@ function mt:_collect() for _, ref in ipairs(self._loc.ref) do if ref.type == 'getlocal' or ref.type == 'setlocal' then - self._objs[#self._objs+1] = ref - if ref.start > finishPos then - finishPos = ref.start + self:_markHas(ref) + if ref.finish > finishPos then + finishPos = ref.finish end end end - if #self._objs == 0 then - return - end - local casts = self:_getCasts() for _, cast in ipairs(casts) do if cast.loc[1] == self._loc[1] and cast.start > startPos and cast.finish < finishPos and guide.getLocal(self._loc, self._loc[1], cast.start) == self._loc then - self._objs[#self._objs+1] = cast + self._casts[#self._casts+1] = cast + self:_markHas(cast) end end +end - table.sort(self._objs, function (a, b) - return a.start < b.start - end) - - for _, obj in ipairs(self._objs) do - self:_markHas(obj) +---@param pos integer +---@param topNode vm.node +---@return vm.node +function mt:_fastWardCasts(pos, topNode) + for i = self._index, #self._casts do + local action = self._casts[i] + if action.start > pos then + self._index = i + return topNode + end + topNode = topNode:copy() + for _, cast in ipairs(action.casts) do + if cast.mode == '+' then + if cast.optional then + topNode:addOptional() + end + if cast.extends then + topNode:merge(vm.compileNode(cast.extends)) + end + elseif cast.mode == '-' then + if cast.optional then + topNode:removeOptional() + end + if cast.extends then + topNode:removeNode(vm.compileNode(cast.extends)) + end + else + if cast.extends then + topNode:clear() + topNode:merge(vm.compileNode(cast.extends)) + end + end + end end + return topNode end ---@param action parser.object @@ -93,8 +119,13 @@ function mt:_lookIntoChild(action, topNode, outNode) end self._mark[action] = true if action.type == 'getlocal' then - self._callback(action, topNode) - topNode = topNode:copy():setTruthy() + if action.node == self._loc then + self._callback(action, topNode) + if outNode then + topNode = topNode:copy():setTruthy() + outNode = outNode:copy():setFalsy() + end + end elseif action.type == 'function' then self:_lookIntoBlock(action, topNode:copy()) elseif action.type == 'unary' then @@ -110,8 +141,9 @@ function mt:_lookIntoChild(action, topNode, outNode) goto RETURN end if action.op.type == 'and' then - topNode = self:_lookIntoChild(action[1], topNode) - topNode = self:_lookIntoChild(action[2], topNode) + local dummyNode = topNode:copy() + topNode = self:_lookIntoChild(action[1], topNode, dummyNode) + topNode = self:_lookIntoChild(action[2], topNode, dummyNode) elseif action.op.type == 'or' then outNode = outNode or topNode:copy() local topNode1, outNode1 = self:_lookIntoChild(action[1], topNode, outNode) @@ -133,7 +165,7 @@ function mt:_lookIntoChild(action, topNode, outNode) if handler.type == 'getlocal' and handler.node == self._loc then -- if x == y then - self:_lookIntoChild(handler, topNode:copy()) + topNode = self:_lookIntoChild(handler, topNode, outNode) local checkerNode = vm.compileNode(checker) if action.op.type == '==' then topNode = checkerNode @@ -202,9 +234,12 @@ function mt:_lookIntoChild(action, topNode, outNode) or action.type == 'for' then topNode = self:_lookIntoBlock(action, topNode:copy()) elseif action.type == 'while' then - local blockNode, mainNode = self:_lookIntoChild(action.filter, topNode:copy(), topNode:copy()) + local blockNode, mainNode if action.filter then - self:_lookIntoChild(action.filter, topNode) + blockNode, mainNode = self:_lookIntoChild(action.filter, topNode:copy(), topNode:copy()) + else + blockNode = topNode:copy() + mainNode = topNode:copy() end blockNode = self:_lookIntoBlock(action, blockNode:copy()) if mainNode then @@ -240,8 +275,10 @@ function mt:_lookIntoChild(action, topNode, outNode) topNode = mainNode elseif action.type == 'call' then if action.node.special == 'assert' and action.args and action.args[1] then - topNode = self:_lookIntoChild(action.args[1], topNode) + topNode = self:_lookIntoChild(action.args[1], topNode, topNode:copy()) end + elseif action.type == 'paren' then + topNode, outNode = self:_lookIntoChild(action.exp, topNode, outNode) elseif action.type == 'setlocal' then if action.node == self._loc then if action.value then @@ -249,27 +286,23 @@ function mt:_lookIntoChild(action, topNode, outNode) end topNode = self._callback(action) end - elseif action.type == 'doc.cast' then - topNode = topNode:copy() - for _, cast in ipairs(action.casts) do - if cast.mode == '+' then - if cast.optional then - topNode:addOptional() - end - if cast.extends then - topNode:merge(vm.compileNode(cast.extends)) - end - elseif cast.mode == '-' then - if cast.optional then - topNode:removeOptional() - end - if cast.extends then - topNode:removeNode(vm.compileNode(cast.extends)) - end - else - if cast.extends then - topNode:clear() - topNode:merge(vm.compileNode(cast.extends)) + elseif action.type == 'local' then + if action.value + and action.value.type == 'select' then + local index = action.value.sindex + local call = action.value.vararg + if index == 1 + and call.type == 'call' + and call.node + and call.node.special == 'type' + and call.args then + local getLoc = call.args[1] + if getLoc + and getLoc.type == 'getlocal' + and getLoc.node == self._loc then + for _, ref in ipairs(action.ref) do + self:_markHas(ref) + end end end end @@ -292,9 +325,11 @@ function mt:_lookIntoBlock(block, topNode) end for _, action in ipairs(block) do if self._has[action] then + topNode = self:_fastWardCasts(action.start, topNode) topNode = self:_lookIntoChild(action, topNode) end end + topNode = self:_fastWardCasts(block.finish, topNode) return topNode end @@ -307,7 +342,7 @@ function vm.launchRunner(loc, callback) end local self = setmetatable({ _loc = loc, - _objs = {}, + _casts = {}, _mark = {}, _has = {}, _main = main, @@ -316,8 +351,5 @@ function vm.launchRunner(loc, callback) self:_collect() - if #self._objs == 0 then - return - end self:_lookIntoBlock(main, vm.getNode(loc):copy()) end |