diff options
-rw-r--r-- | script/parser/guide.lua | 47 | ||||
-rw-r--r-- | script/vm/compiler.lua | 8 | ||||
-rw-r--r-- | script/vm/runner.lua | 280 |
3 files changed, 156 insertions, 179 deletions
diff --git a/script/parser/guide.lua b/script/parser/guide.lua index b783a9e1..f782c43a 100644 --- a/script/parser/guide.lua +++ b/script/parser/guide.lua @@ -198,6 +198,43 @@ end return f end}) +local eachChildMap = setmetatable({}, {__index = function (self, name) + local defs = childMap[name] + if not defs then + self[name] = false + return false + end + local text = {} + text[#text+1] = 'local obj, callback = ...' + for _, def in ipairs(defs) do + if def == '#' then + text[#text+1] = [[ +for i = 1, #obj do + callback(obj[i]) +end +]] + elseif type(def) == 'string' and def:sub(1, 1) == '#' then + local key = def:sub(2) + text[#text+1] = ([[ +local childs = obj.%s +if childs then + for i = 1, #childs do + callback(childs[i]) + end +end +]]):format(key) + elseif type(def) == 'string' then + text[#text+1] = ('callback(obj.%s)'):format(def) + else + text[#text+1] = ('callback(obj[%q])'):format(def) + end + end + local buf = table.concat(text, '\n') + local f = load(buf, buf, 't') + self[name] = f + return f +end}) + m.actionMap = { ['main'] = {'#'}, ['repeat'] = {'#'}, @@ -752,6 +789,16 @@ function m.eachSource(ast, callback) end end +---@param source parser.object +---@param callback fun(src: parser.object) +function m.eachChild(source, callback) + local f = eachChildMap[source.type] + if not f then + return + end + f(source, callback) +end + --- 获取指定的 special function m.eachSpecialOf(ast, name, callback) local root = m.getRoot(ast) diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index 1ba20785..6932a8c8 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -1313,14 +1313,6 @@ local compilerSwitch = util.switch() else vm.setNode(src, vm.compileNode(src.value), true) end - elseif src.value - and src.value.type == 'binary' - and src.value.op and src.value.op.type == 'or' - and src.value[1] and src.value[1].type == 'getlocal' and src.value[1].node == source then - -- x = x or 1 - vm.setNode(src, vm.compileNode(src.value)) - else - vm.setNode(src, node, true) end return vm.getNode(src) elseif src.type == 'getlocal' then diff --git a/script/vm/runner.lua b/script/vm/runner.lua index 97f362f3..87619e54 100644 --- a/script/vm/runner.lua +++ b/script/vm/runner.lua @@ -2,13 +2,15 @@ local vm = require 'vm.vm' local guide = require 'parser.guide' ----@alias vm.runner.callback fun(src: parser.object, node: vm.node) +---@alias vm.runner.callback fun(src: parser.object, node?: vm.node) ---@class vm.runner ---@field _loc parser.object ---@field _objs parser.object[] ---@field _callback vm.runner.callback ---@field _mark table +---@field _has table<parser.object, true> +---@field _main parser.object local mt = {} mt.__index = mt mt._index = 1 @@ -28,6 +30,20 @@ function mt:_getCasts() return root._casts end +---@param obj parser.object +function mt:_markHas(obj) + while true do + if self._has[obj] then + return + end + self._has[obj] = true + if obj == self._main then + return + end + obj = obj.parent + end +end + function mt:_collect() local startPos = self._loc.start local finishPos = 0 @@ -59,99 +75,56 @@ function mt:_collect() table.sort(self._objs, function (a, b) return a.start < b.start end) -end - ----@param pos integer ----@param node vm.node ----@return vm.node ----@return parser.object? -function mt:_fastWard(pos, node) - for i = self._index, #self._objs do - local obj = self._objs[i] - if obj.finish > pos then - self._index = i - return node, obj - end - if obj.type == 'getlocal' then - self._callback(obj, node) - elseif obj.type == 'doc.cast' then - node = node:copy() - for _, cast in ipairs(obj.casts) do - if cast.mode == '+' then - if cast.optional then - node:addOptional() - end - if cast.extends then - node:merge(vm.compileNode(cast.extends)) - end - elseif cast.mode == '-' then - if cast.optional then - node:removeOptional() - end - if cast.extends then - node:removeNode(vm.compileNode(cast.extends)) - end - else - if cast.extends then - node:clear() - node:merge(vm.compileNode(cast.extends)) - end - end - end - end + for _, obj in ipairs(self._objs) do + self:_markHas(obj) end - self._index = #self._objs + 1 - return node, nil end ----@param exp parser.object +---@param action parser.object ---@param topNode vm.node ---@param outNode? vm.node ---@return vm.node topNode ---@return vm.node outNode -function mt:_lookIntoExp(exp, topNode, outNode) - if not exp then - return topNode, outNode or topNode +function mt:_lookIntoChild(action, topNode, outNode) + if not self._has[action] + or self._mark[action] then + return topNode, topNode or outNode end - if self._mark[exp] then - return topNode, outNode or topNode - end - self._mark[exp] = true - local top = self._objs[self._index] - if not top then - return topNode, outNode or topNode - end - if exp.type == 'function' then - self:_launchBlock(exp, topNode:copy()) - elseif exp.type == 'unary' then - if not exp[1] then + self._mark[action] = true + if action.type == 'getlocal' then + self._callback(action, topNode) + topNode = topNode:copy():setTruthy() + elseif action.type == 'function' then + self:_lookIntoBlock(action, topNode:copy()) + elseif action.type == 'unary' then + if not action[1] then goto RETURN end - if exp.op.type == 'not' then + if action.op.type == 'not' then outNode = outNode or topNode:copy() - outNode, topNode = self:_lookIntoExp(exp[1], topNode, outNode) + outNode, topNode = self:_lookIntoChild(action[1], topNode, outNode) end - elseif exp.type == 'binary' then - if not exp[1] or not exp[2] then + elseif action.type == 'binary' then + if not action[1] or not action[2] then goto RETURN end - if exp.op.type == 'and' then - topNode = self:_lookIntoExp(exp[1], topNode) - topNode = self:_lookIntoExp(exp[2], topNode) - elseif exp.op.type == 'or' then + if action.op.type == 'and' then + topNode = self:_lookIntoChild(action[1], topNode) + topNode = self:_lookIntoChild(action[2], topNode) + elseif action.op.type == 'or' then outNode = outNode or topNode:copy() - local topNode1, outNode1 = self:_lookIntoExp(exp[1], topNode, outNode) - local topNode2, outNode2 = self:_lookIntoExp(exp[2], outNode1, outNode1:copy()) + local topNode1, outNode1 = self:_lookIntoChild(action[1], topNode, outNode) + local topNode2, outNode2 = self:_lookIntoChild(action[2], outNode1, outNode1:copy()) topNode = vm.createNode(topNode1, topNode2) outNode = outNode2:copy() - elseif exp.op.type == '==' - or exp.op.type == '~=' then + elseif action.op.type == '==' + or action.op.type == '~=' then local handler, checker for i = 1, 2 do - if guide.isLiteral(exp[i]) then - checker = exp[i] - handler = exp[3-i] -- Copilot tells me use `3-i` instead of `i%2+1` + if guide.isLiteral(action[i]) then + checker = action[i] + handler = action[3-i] -- Copilot tells me use `3-i` instead of `i%2+1` end end if not handler then @@ -160,9 +133,9 @@ function mt:_lookIntoExp(exp, topNode, outNode) if handler.type == 'getlocal' and handler.node == self._loc then -- if x == y then - self:_fastWard(exp.finish, topNode:copy()) + self:_lookIntoChild(handler, topNode:copy()) local checkerNode = vm.compileNode(checker) - if exp.op.type == '==' then + if action.op.type == '==' then topNode = checkerNode if outNode then outNode:removeNode(topNode) @@ -181,8 +154,8 @@ function mt:_lookIntoExp(exp, topNode, outNode) and handler.args[1].type == 'getlocal' and handler.args[1].node == self._loc then -- if type(x) == 'string' then - self:_fastWard(exp.finish, topNode:copy()) - if exp.op.type == '==' then + self:_lookIntoChild(handler, topNode:copy()) + if action.op.type == '==' then topNode:narrow(checker[1]) if outNode then outNode:remove(checker[1]) @@ -208,7 +181,7 @@ function mt:_lookIntoExp(exp, topNode, outNode) and call.args[1].type == 'getlocal' and call.args[1].node == self._loc then -- `local tp = type(x);if tp == 'string' then` - if exp.op.type == '==' then + if action.op.type == '==' then topNode:narrow(checker[1]) if outNode then outNode:remove(checker[1]) @@ -223,72 +196,17 @@ function mt:_lookIntoExp(exp, topNode, outNode) end end end - elseif exp.type == 'getlocal' then - if exp.node == self._loc then - topNode = self:_fastWard(exp.finish, topNode) - topNode = topNode:copy():setTruthy() - if outNode then - outNode:setFalsy() - end - end - elseif exp.type == 'paren' then - topNode, outNode = self:_lookIntoExp(exp.exp, topNode, outNode) - elseif exp.type == 'getindex' then - self:_lookIntoExp(exp.index, topNode) - elseif exp.type == 'table' then - for _, field in ipairs(exp) do - self:_lookIntoAction(field, topNode) - end - end - ::RETURN:: - topNode = self:_fastWard(exp.finish, topNode) - return topNode, outNode or topNode -end - ----@param action parser.object ----@param topNode vm.node ----@return vm.node topNode -function mt:_lookIntoAction(action, topNode) - if not action then - return topNode - end - if self._mark[action] then - return topNode - end - self._mark[action] = true - local top = self._objs[self._index] - if not top then - return topNode - end - if not guide.isInRange(action, top.finish) - -- trick for `local tp = type(x);if tp == 'string' then` - and action.type ~= 'binary' then - return topNode - end - local value = vm.getObjectValue(action) - if value then - self:_lookIntoExp(value, topNode:copy()) - end - if action.type == 'setlocal' then - if action.node == self._loc then - local newTopNode = self._callback(action, topNode) - if newTopNode then - topNode = newTopNode - end - end - elseif action.type == 'function' then - self:_launchBlock(action, topNode:copy()) elseif action.type == 'loop' or action.type == 'in' or action.type == 'repeat' or action.type == 'for' then - topNode = self:_launchBlock(action, topNode:copy()) + topNode = self:_lookIntoBlock(action, topNode:copy()) elseif action.type == 'while' then - local blockNode, mainNode = self:_lookIntoExp(action.filter, topNode:copy(), topNode:copy()) + local blockNode, mainNode = self:_lookIntoChild(action.filter, topNode:copy(), topNode:copy()) if action.filter then - self:_fastWard(action.filter.finish, blockNode) + self:_lookIntoChild(action.filter, topNode) end - blockNode = self:_launchBlock(action, blockNode:copy()) + blockNode = self:_lookIntoBlock(action, blockNode:copy()) if mainNode then topNode = mainNode:merge(blockNode) end @@ -299,13 +217,12 @@ function mt:_lookIntoAction(action, topNode) for _, subBlock in ipairs(action) do local blockNode = mainNode:copy() if subBlock.filter then - blockNode, mainNode = self:_lookIntoExp(subBlock.filter, blockNode, mainNode) - self:_fastWard(subBlock.filter.finish, blockNode) + blockNode, mainNode = self:_lookIntoChild(subBlock.filter, blockNode, mainNode) else hasElse = true mainNode:clear() end - blockNode = self:_launchBlock(subBlock, blockNode:copy()) + blockNode = self:_lookIntoBlock(subBlock, blockNode:copy()) local neverReturn = subBlock.hasReturn or subBlock.hasGoTo or subBlock.hasBreak @@ -323,51 +240,77 @@ function mt:_lookIntoAction(action, topNode) topNode = mainNode elseif action.type == 'call' then if action.node.special == 'assert' and action.args and action.args[1] then - topNode = self:_lookIntoExp(action.args[1], topNode) - elseif action.args then - for _, arg in ipairs(action.args) do - self:_lookIntoExp(arg, topNode) + topNode = self:_lookIntoChild(action.args[1], topNode) + end + elseif action.type == 'setlocal' then + if action.node == self._loc then + if action.value then + self:_lookIntoChild(action.value, topNode) end + topNode = self._callback(action) end - elseif action.type == 'return' then - for _, rtn in ipairs(action) do - self:_lookIntoExp(rtn, topNode) + elseif action.type == 'doc.cast' then + topNode = topNode:copy() + for _, cast in ipairs(action.casts) do + if cast.mode == '+' then + if cast.optional then + topNode:addOptional() + end + if cast.extends then + topNode:merge(vm.compileNode(cast.extends)) + end + elseif cast.mode == '-' then + if cast.optional then + topNode:removeOptional() + end + if cast.extends then + topNode:removeNode(vm.compileNode(cast.extends)) + end + else + if cast.extends then + topNode:clear() + topNode:merge(vm.compileNode(cast.extends)) + end + end end end - topNode = self:_fastWard(action.finish, topNode) - return topNode + guide.eachChild(action, function (src) + if self._has[src] then + self:_lookIntoChild(src, topNode) + end + end) + ::RETURN:: + return topNode, outNode or topNode end ----@param block parser.object ----@param node vm.node ----@return vm.node -function mt:_launchBlock(block, node) - local topNode, top = self:_fastWard(block.start, node) - if not top then +---@param block parser.object +---@param topNode vm.node +---@return vm.node topNode +function mt:_lookIntoBlock(block, topNode) + if not self._has[block] then return topNode end for _, action in ipairs(block) do - if (action.range or action.finish) < top.finish then - goto CONTINUE + if self._has[action] then + topNode = self:_lookIntoChild(action, topNode) end - topNode = self:_lookIntoAction(action, topNode) - topNode, top = self:_fastWard(action.finish, topNode) - if not top then - return topNode - end - ::CONTINUE:: end - topNode = self:_fastWard(block.finish, topNode) return topNode end ---@param loc parser.object ---@param callback vm.runner.callback function vm.launchRunner(loc, callback) + local main = guide.getParentBlock(loc) + if not main then + return + end local self = setmetatable({ _loc = loc, _objs = {}, _mark = {}, + _has = {}, + _main = main, _callback = callback, }, mt) @@ -376,10 +319,5 @@ function vm.launchRunner(loc, callback) if #self._objs == 0 then return end - local main = guide.getParentBlock(loc) - if not main then - return - end - local topNode = self:_launchBlock(main, vm.getNode(loc):copy()) - self:_fastWard(math.maxinteger, topNode) + self:_lookIntoBlock(main, vm.getNode(loc):copy()) end |