summaryrefslogtreecommitdiff
path: root/script
diff options
context:
space:
mode:
Diffstat (limited to 'script')
-rw-r--r--script/vm/runner.lua128
1 files changed, 80 insertions, 48 deletions
diff --git a/script/vm/runner.lua b/script/vm/runner.lua
index 87619e54..6c3c69a1 100644
--- a/script/vm/runner.lua
+++ b/script/vm/runner.lua
@@ -6,7 +6,7 @@ local guide = require 'parser.guide'
---@class vm.runner
---@field _loc parser.object
----@field _objs parser.object[]
+---@field _casts parser.object[]
---@field _callback vm.runner.callback
---@field _mark table
---@field _has table<parser.object, true>
@@ -51,34 +51,60 @@ function mt:_collect()
for _, ref in ipairs(self._loc.ref) do
if ref.type == 'getlocal'
or ref.type == 'setlocal' then
- self._objs[#self._objs+1] = ref
- if ref.start > finishPos then
- finishPos = ref.start
+ self:_markHas(ref)
+ if ref.finish > finishPos then
+ finishPos = ref.finish
end
end
end
- if #self._objs == 0 then
- return
- end
-
local casts = self:_getCasts()
for _, cast in ipairs(casts) do
if cast.loc[1] == self._loc[1]
and cast.start > startPos
and cast.finish < finishPos
and guide.getLocal(self._loc, self._loc[1], cast.start) == self._loc then
- self._objs[#self._objs+1] = cast
+ self._casts[#self._casts+1] = cast
+ self:_markHas(cast)
end
end
+end
- table.sort(self._objs, function (a, b)
- return a.start < b.start
- end)
-
- for _, obj in ipairs(self._objs) do
- self:_markHas(obj)
+---@param pos integer
+---@param topNode vm.node
+---@return vm.node
+function mt:_fastWardCasts(pos, topNode)
+ for i = self._index, #self._casts do
+ local action = self._casts[i]
+ if action.start > pos then
+ self._index = i
+ return topNode
+ end
+ topNode = topNode:copy()
+ for _, cast in ipairs(action.casts) do
+ if cast.mode == '+' then
+ if cast.optional then
+ topNode:addOptional()
+ end
+ if cast.extends then
+ topNode:merge(vm.compileNode(cast.extends))
+ end
+ elseif cast.mode == '-' then
+ if cast.optional then
+ topNode:removeOptional()
+ end
+ if cast.extends then
+ topNode:removeNode(vm.compileNode(cast.extends))
+ end
+ else
+ if cast.extends then
+ topNode:clear()
+ topNode:merge(vm.compileNode(cast.extends))
+ end
+ end
+ end
end
+ return topNode
end
---@param action parser.object
@@ -93,8 +119,13 @@ function mt:_lookIntoChild(action, topNode, outNode)
end
self._mark[action] = true
if action.type == 'getlocal' then
- self._callback(action, topNode)
- topNode = topNode:copy():setTruthy()
+ if action.node == self._loc then
+ self._callback(action, topNode)
+ if outNode then
+ topNode = topNode:copy():setTruthy()
+ outNode = outNode:copy():setFalsy()
+ end
+ end
elseif action.type == 'function' then
self:_lookIntoBlock(action, topNode:copy())
elseif action.type == 'unary' then
@@ -110,8 +141,9 @@ function mt:_lookIntoChild(action, topNode, outNode)
goto RETURN
end
if action.op.type == 'and' then
- topNode = self:_lookIntoChild(action[1], topNode)
- topNode = self:_lookIntoChild(action[2], topNode)
+ local dummyNode = topNode:copy()
+ topNode = self:_lookIntoChild(action[1], topNode, dummyNode)
+ topNode = self:_lookIntoChild(action[2], topNode, dummyNode)
elseif action.op.type == 'or' then
outNode = outNode or topNode:copy()
local topNode1, outNode1 = self:_lookIntoChild(action[1], topNode, outNode)
@@ -133,7 +165,7 @@ function mt:_lookIntoChild(action, topNode, outNode)
if handler.type == 'getlocal'
and handler.node == self._loc then
-- if x == y then
- self:_lookIntoChild(handler, topNode:copy())
+ topNode = self:_lookIntoChild(handler, topNode, outNode)
local checkerNode = vm.compileNode(checker)
if action.op.type == '==' then
topNode = checkerNode
@@ -202,9 +234,12 @@ function mt:_lookIntoChild(action, topNode, outNode)
or action.type == 'for' then
topNode = self:_lookIntoBlock(action, topNode:copy())
elseif action.type == 'while' then
- local blockNode, mainNode = self:_lookIntoChild(action.filter, topNode:copy(), topNode:copy())
+ local blockNode, mainNode
if action.filter then
- self:_lookIntoChild(action.filter, topNode)
+ blockNode, mainNode = self:_lookIntoChild(action.filter, topNode:copy(), topNode:copy())
+ else
+ blockNode = topNode:copy()
+ mainNode = topNode:copy()
end
blockNode = self:_lookIntoBlock(action, blockNode:copy())
if mainNode then
@@ -240,8 +275,10 @@ function mt:_lookIntoChild(action, topNode, outNode)
topNode = mainNode
elseif action.type == 'call' then
if action.node.special == 'assert' and action.args and action.args[1] then
- topNode = self:_lookIntoChild(action.args[1], topNode)
+ topNode = self:_lookIntoChild(action.args[1], topNode, topNode:copy())
end
+ elseif action.type == 'paren' then
+ topNode, outNode = self:_lookIntoChild(action.exp, topNode, outNode)
elseif action.type == 'setlocal' then
if action.node == self._loc then
if action.value then
@@ -249,27 +286,23 @@ function mt:_lookIntoChild(action, topNode, outNode)
end
topNode = self._callback(action)
end
- elseif action.type == 'doc.cast' then
- topNode = topNode:copy()
- for _, cast in ipairs(action.casts) do
- if cast.mode == '+' then
- if cast.optional then
- topNode:addOptional()
- end
- if cast.extends then
- topNode:merge(vm.compileNode(cast.extends))
- end
- elseif cast.mode == '-' then
- if cast.optional then
- topNode:removeOptional()
- end
- if cast.extends then
- topNode:removeNode(vm.compileNode(cast.extends))
- end
- else
- if cast.extends then
- topNode:clear()
- topNode:merge(vm.compileNode(cast.extends))
+ elseif action.type == 'local' then
+ if action.value
+ and action.value.type == 'select' then
+ local index = action.value.sindex
+ local call = action.value.vararg
+ if index == 1
+ and call.type == 'call'
+ and call.node
+ and call.node.special == 'type'
+ and call.args then
+ local getLoc = call.args[1]
+ if getLoc
+ and getLoc.type == 'getlocal'
+ and getLoc.node == self._loc then
+ for _, ref in ipairs(action.ref) do
+ self:_markHas(ref)
+ end
end
end
end
@@ -292,9 +325,11 @@ function mt:_lookIntoBlock(block, topNode)
end
for _, action in ipairs(block) do
if self._has[action] then
+ topNode = self:_fastWardCasts(action.start, topNode)
topNode = self:_lookIntoChild(action, topNode)
end
end
+ topNode = self:_fastWardCasts(block.finish, topNode)
return topNode
end
@@ -307,7 +342,7 @@ function vm.launchRunner(loc, callback)
end
local self = setmetatable({
_loc = loc,
- _objs = {},
+ _casts = {},
_mark = {},
_has = {},
_main = main,
@@ -316,8 +351,5 @@ function vm.launchRunner(loc, callback)
self:_collect()
- if #self._objs == 0 then
- return
- end
self:_lookIntoBlock(main, vm.getNode(loc):copy())
end