summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--script/parser/guide.lua47
-rw-r--r--script/vm/compiler.lua8
-rw-r--r--script/vm/runner.lua280
3 files changed, 156 insertions, 179 deletions
diff --git a/script/parser/guide.lua b/script/parser/guide.lua
index b783a9e1..f782c43a 100644
--- a/script/parser/guide.lua
+++ b/script/parser/guide.lua
@@ -198,6 +198,43 @@ end
return f
end})
+local eachChildMap = setmetatable({}, {__index = function (self, name)
+ local defs = childMap[name]
+ if not defs then
+ self[name] = false
+ return false
+ end
+ local text = {}
+ text[#text+1] = 'local obj, callback = ...'
+ for _, def in ipairs(defs) do
+ if def == '#' then
+ text[#text+1] = [[
+for i = 1, #obj do
+ callback(obj[i])
+end
+]]
+ elseif type(def) == 'string' and def:sub(1, 1) == '#' then
+ local key = def:sub(2)
+ text[#text+1] = ([[
+local childs = obj.%s
+if childs then
+ for i = 1, #childs do
+ callback(childs[i])
+ end
+end
+]]):format(key)
+ elseif type(def) == 'string' then
+ text[#text+1] = ('callback(obj.%s)'):format(def)
+ else
+ text[#text+1] = ('callback(obj[%q])'):format(def)
+ end
+ end
+ local buf = table.concat(text, '\n')
+ local f = load(buf, buf, 't')
+ self[name] = f
+ return f
+end})
+
m.actionMap = {
['main'] = {'#'},
['repeat'] = {'#'},
@@ -752,6 +789,16 @@ function m.eachSource(ast, callback)
end
end
+---@param source parser.object
+---@param callback fun(src: parser.object)
+function m.eachChild(source, callback)
+ local f = eachChildMap[source.type]
+ if not f then
+ return
+ end
+ f(source, callback)
+end
+
--- 获取指定的 special
function m.eachSpecialOf(ast, name, callback)
local root = m.getRoot(ast)
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index 1ba20785..6932a8c8 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -1313,14 +1313,6 @@ local compilerSwitch = util.switch()
else
vm.setNode(src, vm.compileNode(src.value), true)
end
- elseif src.value
- and src.value.type == 'binary'
- and src.value.op and src.value.op.type == 'or'
- and src.value[1] and src.value[1].type == 'getlocal' and src.value[1].node == source then
- -- x = x or 1
- vm.setNode(src, vm.compileNode(src.value))
- else
- vm.setNode(src, node, true)
end
return vm.getNode(src)
elseif src.type == 'getlocal' then
diff --git a/script/vm/runner.lua b/script/vm/runner.lua
index 97f362f3..87619e54 100644
--- a/script/vm/runner.lua
+++ b/script/vm/runner.lua
@@ -2,13 +2,15 @@
local vm = require 'vm.vm'
local guide = require 'parser.guide'
----@alias vm.runner.callback fun(src: parser.object, node: vm.node)
+---@alias vm.runner.callback fun(src: parser.object, node?: vm.node)
---@class vm.runner
---@field _loc parser.object
---@field _objs parser.object[]
---@field _callback vm.runner.callback
---@field _mark table
+---@field _has table<parser.object, true>
+---@field _main parser.object
local mt = {}
mt.__index = mt
mt._index = 1
@@ -28,6 +30,20 @@ function mt:_getCasts()
return root._casts
end
+---@param obj parser.object
+function mt:_markHas(obj)
+ while true do
+ if self._has[obj] then
+ return
+ end
+ self._has[obj] = true
+ if obj == self._main then
+ return
+ end
+ obj = obj.parent
+ end
+end
+
function mt:_collect()
local startPos = self._loc.start
local finishPos = 0
@@ -59,99 +75,56 @@ function mt:_collect()
table.sort(self._objs, function (a, b)
return a.start < b.start
end)
-end
-
----@param pos integer
----@param node vm.node
----@return vm.node
----@return parser.object?
-function mt:_fastWard(pos, node)
- for i = self._index, #self._objs do
- local obj = self._objs[i]
- if obj.finish > pos then
- self._index = i
- return node, obj
- end
- if obj.type == 'getlocal' then
- self._callback(obj, node)
- elseif obj.type == 'doc.cast' then
- node = node:copy()
- for _, cast in ipairs(obj.casts) do
- if cast.mode == '+' then
- if cast.optional then
- node:addOptional()
- end
- if cast.extends then
- node:merge(vm.compileNode(cast.extends))
- end
- elseif cast.mode == '-' then
- if cast.optional then
- node:removeOptional()
- end
- if cast.extends then
- node:removeNode(vm.compileNode(cast.extends))
- end
- else
- if cast.extends then
- node:clear()
- node:merge(vm.compileNode(cast.extends))
- end
- end
- end
- end
+ for _, obj in ipairs(self._objs) do
+ self:_markHas(obj)
end
- self._index = #self._objs + 1
- return node, nil
end
----@param exp parser.object
+---@param action parser.object
---@param topNode vm.node
---@param outNode? vm.node
---@return vm.node topNode
---@return vm.node outNode
-function mt:_lookIntoExp(exp, topNode, outNode)
- if not exp then
- return topNode, outNode or topNode
+function mt:_lookIntoChild(action, topNode, outNode)
+ if not self._has[action]
+ or self._mark[action] then
+ return topNode, topNode or outNode
end
- if self._mark[exp] then
- return topNode, outNode or topNode
- end
- self._mark[exp] = true
- local top = self._objs[self._index]
- if not top then
- return topNode, outNode or topNode
- end
- if exp.type == 'function' then
- self:_launchBlock(exp, topNode:copy())
- elseif exp.type == 'unary' then
- if not exp[1] then
+ self._mark[action] = true
+ if action.type == 'getlocal' then
+ self._callback(action, topNode)
+ topNode = topNode:copy():setTruthy()
+ elseif action.type == 'function' then
+ self:_lookIntoBlock(action, topNode:copy())
+ elseif action.type == 'unary' then
+ if not action[1] then
goto RETURN
end
- if exp.op.type == 'not' then
+ if action.op.type == 'not' then
outNode = outNode or topNode:copy()
- outNode, topNode = self:_lookIntoExp(exp[1], topNode, outNode)
+ outNode, topNode = self:_lookIntoChild(action[1], topNode, outNode)
end
- elseif exp.type == 'binary' then
- if not exp[1] or not exp[2] then
+ elseif action.type == 'binary' then
+ if not action[1] or not action[2] then
goto RETURN
end
- if exp.op.type == 'and' then
- topNode = self:_lookIntoExp(exp[1], topNode)
- topNode = self:_lookIntoExp(exp[2], topNode)
- elseif exp.op.type == 'or' then
+ if action.op.type == 'and' then
+ topNode = self:_lookIntoChild(action[1], topNode)
+ topNode = self:_lookIntoChild(action[2], topNode)
+ elseif action.op.type == 'or' then
outNode = outNode or topNode:copy()
- local topNode1, outNode1 = self:_lookIntoExp(exp[1], topNode, outNode)
- local topNode2, outNode2 = self:_lookIntoExp(exp[2], outNode1, outNode1:copy())
+ local topNode1, outNode1 = self:_lookIntoChild(action[1], topNode, outNode)
+ local topNode2, outNode2 = self:_lookIntoChild(action[2], outNode1, outNode1:copy())
topNode = vm.createNode(topNode1, topNode2)
outNode = outNode2:copy()
- elseif exp.op.type == '=='
- or exp.op.type == '~=' then
+ elseif action.op.type == '=='
+ or action.op.type == '~=' then
local handler, checker
for i = 1, 2 do
- if guide.isLiteral(exp[i]) then
- checker = exp[i]
- handler = exp[3-i] -- Copilot tells me use `3-i` instead of `i%2+1`
+ if guide.isLiteral(action[i]) then
+ checker = action[i]
+ handler = action[3-i] -- Copilot tells me use `3-i` instead of `i%2+1`
end
end
if not handler then
@@ -160,9 +133,9 @@ function mt:_lookIntoExp(exp, topNode, outNode)
if handler.type == 'getlocal'
and handler.node == self._loc then
-- if x == y then
- self:_fastWard(exp.finish, topNode:copy())
+ self:_lookIntoChild(handler, topNode:copy())
local checkerNode = vm.compileNode(checker)
- if exp.op.type == '==' then
+ if action.op.type == '==' then
topNode = checkerNode
if outNode then
outNode:removeNode(topNode)
@@ -181,8 +154,8 @@ function mt:_lookIntoExp(exp, topNode, outNode)
and handler.args[1].type == 'getlocal'
and handler.args[1].node == self._loc then
-- if type(x) == 'string' then
- self:_fastWard(exp.finish, topNode:copy())
- if exp.op.type == '==' then
+ self:_lookIntoChild(handler, topNode:copy())
+ if action.op.type == '==' then
topNode:narrow(checker[1])
if outNode then
outNode:remove(checker[1])
@@ -208,7 +181,7 @@ function mt:_lookIntoExp(exp, topNode, outNode)
and call.args[1].type == 'getlocal'
and call.args[1].node == self._loc then
-- `local tp = type(x);if tp == 'string' then`
- if exp.op.type == '==' then
+ if action.op.type == '==' then
topNode:narrow(checker[1])
if outNode then
outNode:remove(checker[1])
@@ -223,72 +196,17 @@ function mt:_lookIntoExp(exp, topNode, outNode)
end
end
end
- elseif exp.type == 'getlocal' then
- if exp.node == self._loc then
- topNode = self:_fastWard(exp.finish, topNode)
- topNode = topNode:copy():setTruthy()
- if outNode then
- outNode:setFalsy()
- end
- end
- elseif exp.type == 'paren' then
- topNode, outNode = self:_lookIntoExp(exp.exp, topNode, outNode)
- elseif exp.type == 'getindex' then
- self:_lookIntoExp(exp.index, topNode)
- elseif exp.type == 'table' then
- for _, field in ipairs(exp) do
- self:_lookIntoAction(field, topNode)
- end
- end
- ::RETURN::
- topNode = self:_fastWard(exp.finish, topNode)
- return topNode, outNode or topNode
-end
-
----@param action parser.object
----@param topNode vm.node
----@return vm.node topNode
-function mt:_lookIntoAction(action, topNode)
- if not action then
- return topNode
- end
- if self._mark[action] then
- return topNode
- end
- self._mark[action] = true
- local top = self._objs[self._index]
- if not top then
- return topNode
- end
- if not guide.isInRange(action, top.finish)
- -- trick for `local tp = type(x);if tp == 'string' then`
- and action.type ~= 'binary' then
- return topNode
- end
- local value = vm.getObjectValue(action)
- if value then
- self:_lookIntoExp(value, topNode:copy())
- end
- if action.type == 'setlocal' then
- if action.node == self._loc then
- local newTopNode = self._callback(action, topNode)
- if newTopNode then
- topNode = newTopNode
- end
- end
- elseif action.type == 'function' then
- self:_launchBlock(action, topNode:copy())
elseif action.type == 'loop'
or action.type == 'in'
or action.type == 'repeat'
or action.type == 'for' then
- topNode = self:_launchBlock(action, topNode:copy())
+ topNode = self:_lookIntoBlock(action, topNode:copy())
elseif action.type == 'while' then
- local blockNode, mainNode = self:_lookIntoExp(action.filter, topNode:copy(), topNode:copy())
+ local blockNode, mainNode = self:_lookIntoChild(action.filter, topNode:copy(), topNode:copy())
if action.filter then
- self:_fastWard(action.filter.finish, blockNode)
+ self:_lookIntoChild(action.filter, topNode)
end
- blockNode = self:_launchBlock(action, blockNode:copy())
+ blockNode = self:_lookIntoBlock(action, blockNode:copy())
if mainNode then
topNode = mainNode:merge(blockNode)
end
@@ -299,13 +217,12 @@ function mt:_lookIntoAction(action, topNode)
for _, subBlock in ipairs(action) do
local blockNode = mainNode:copy()
if subBlock.filter then
- blockNode, mainNode = self:_lookIntoExp(subBlock.filter, blockNode, mainNode)
- self:_fastWard(subBlock.filter.finish, blockNode)
+ blockNode, mainNode = self:_lookIntoChild(subBlock.filter, blockNode, mainNode)
else
hasElse = true
mainNode:clear()
end
- blockNode = self:_launchBlock(subBlock, blockNode:copy())
+ blockNode = self:_lookIntoBlock(subBlock, blockNode:copy())
local neverReturn = subBlock.hasReturn
or subBlock.hasGoTo
or subBlock.hasBreak
@@ -323,51 +240,77 @@ function mt:_lookIntoAction(action, topNode)
topNode = mainNode
elseif action.type == 'call' then
if action.node.special == 'assert' and action.args and action.args[1] then
- topNode = self:_lookIntoExp(action.args[1], topNode)
- elseif action.args then
- for _, arg in ipairs(action.args) do
- self:_lookIntoExp(arg, topNode)
+ topNode = self:_lookIntoChild(action.args[1], topNode)
+ end
+ elseif action.type == 'setlocal' then
+ if action.node == self._loc then
+ if action.value then
+ self:_lookIntoChild(action.value, topNode)
end
+ topNode = self._callback(action)
end
- elseif action.type == 'return' then
- for _, rtn in ipairs(action) do
- self:_lookIntoExp(rtn, topNode)
+ elseif action.type == 'doc.cast' then
+ topNode = topNode:copy()
+ for _, cast in ipairs(action.casts) do
+ if cast.mode == '+' then
+ if cast.optional then
+ topNode:addOptional()
+ end
+ if cast.extends then
+ topNode:merge(vm.compileNode(cast.extends))
+ end
+ elseif cast.mode == '-' then
+ if cast.optional then
+ topNode:removeOptional()
+ end
+ if cast.extends then
+ topNode:removeNode(vm.compileNode(cast.extends))
+ end
+ else
+ if cast.extends then
+ topNode:clear()
+ topNode:merge(vm.compileNode(cast.extends))
+ end
+ end
end
end
- topNode = self:_fastWard(action.finish, topNode)
- return topNode
+ guide.eachChild(action, function (src)
+ if self._has[src] then
+ self:_lookIntoChild(src, topNode)
+ end
+ end)
+ ::RETURN::
+ return topNode, outNode or topNode
end
----@param block parser.object
----@param node vm.node
----@return vm.node
-function mt:_launchBlock(block, node)
- local topNode, top = self:_fastWard(block.start, node)
- if not top then
+---@param block parser.object
+---@param topNode vm.node
+---@return vm.node topNode
+function mt:_lookIntoBlock(block, topNode)
+ if not self._has[block] then
return topNode
end
for _, action in ipairs(block) do
- if (action.range or action.finish) < top.finish then
- goto CONTINUE
+ if self._has[action] then
+ topNode = self:_lookIntoChild(action, topNode)
end
- topNode = self:_lookIntoAction(action, topNode)
- topNode, top = self:_fastWard(action.finish, topNode)
- if not top then
- return topNode
- end
- ::CONTINUE::
end
- topNode = self:_fastWard(block.finish, topNode)
return topNode
end
---@param loc parser.object
---@param callback vm.runner.callback
function vm.launchRunner(loc, callback)
+ local main = guide.getParentBlock(loc)
+ if not main then
+ return
+ end
local self = setmetatable({
_loc = loc,
_objs = {},
_mark = {},
+ _has = {},
+ _main = main,
_callback = callback,
}, mt)
@@ -376,10 +319,5 @@ function vm.launchRunner(loc, callback)
if #self._objs == 0 then
return
end
- local main = guide.getParentBlock(loc)
- if not main then
- return
- end
- local topNode = self:_launchBlock(main, vm.getNode(loc):copy())
- self:_fastWard(math.maxinteger, topNode)
+ self:_lookIntoBlock(main, vm.getNode(loc):copy())
end