summaryrefslogtreecommitdiff
path: root/script/vm/runner.lua
blob: 87630512dc57f5c806110c4b9d60add8d732a1e7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
---@class vm
local vm        = require 'vm.vm'
local guide     = require 'parser.guide'

---Called for each visited object with the node in effect at that point.
---For a `setlocal`, a non-nil return value replaces the running node for
---the objects that follow.
---@alias vm.runner.callback fun(src: parser.object, node: vm.node): vm.node?

---State for walking one local's references (plus matching `---@cast` docs)
---in source order, reporting each one to `_callback`.
---@class vm.runner
---@field _loc      parser.object
---@field _objs     parser.object[]
---@field _callback vm.runner.callback
---@field _index    integer  # cursor into `_objs`: next object not yet visited
local mt = {
    -- Default cursor; an instance shadows it once it starts advancing.
    _index = 1,
}
mt.__index = mt

---Returns every `doc.cast` doc in this file that has a `loc` target.
---The list is built once and cached on the AST root as `_casts`.
---@return parser.object[]
function mt:_getCasts()
    local root  = guide.getRoot(self._loc)
    local cache = root._casts
    if cache then
        return cache
    end
    -- Publish the (still empty) cache before filling it, as before.
    cache = {}
    root._casts = cache
    for _, doc in ipairs(root.docs) do
        if doc.type == 'doc.cast' and doc.loc then
            cache[#cache + 1] = doc
        end
    end
    return cache
end

---Fills `self._objs` with the getlocal/setlocal references of `self._loc`.
---If there are any, it also adds the `---@cast` docs on the same local that lie
---strictly after the declaration's start and before the last reference's
---start, then sorts everything by position.
function mt:_collect()
    local refs = self._loc.ref
    -- NOTE(review): `ref` can be absent (presumably for a local that is never
    -- referenced — confirm); `ipairs(nil)` would raise, so bail out early.
    if not refs then
        return
    end

    local startPos  = self._loc.start
    local finishPos = 0

    for _, ref in ipairs(refs) do
        if ref.type == 'getlocal'
        or ref.type == 'setlocal' then
            self._objs[#self._objs+1] = ref
            -- Track the start of the last reference; casts past it are irrelevant.
            if ref.start > finishPos then
                finishPos = ref.start
            end
        end
    end

    if #self._objs == 0 then
        return
    end

    local casts = self:_getCasts()
    for _, cast in ipairs(casts) do
        -- Same name, inside (startPos, finishPos), and `guide.getLocal` must
        -- resolve the name at the cast's position to this very local
        -- (presumably to rule out a shadowing local of the same name).
        if  cast.loc[1] == self._loc[1]
        and cast.start > startPos
        and cast.finish < finishPos
        and guide.getLocal(self._loc, self._loc[1], cast.start) == self._loc then
            self._objs[#self._objs+1] = cast
        end
    end

    -- Order by `range` when an object has one, otherwise by `start`.
    table.sort(self._objs, function (a, b)
        return (a.range or a.start) < (b.range or b.start)
    end)
end


---Advances the cursor through `self._objs`, starting at `self._index`.
---The callback fires for every object whose start is at or before `pos`.
---A `setlocal` whose callback returns a node replaces the running node
---with a copy of that node.
---@param pos  integer
---@param node vm.node
---@return parser.object? top   first object starting after `pos`, or nil when all are consumed
---@return vm.node        node  the running node after advancing
function mt:_fastWard(pos, node)
    local objs = self._objs
    for i = self._index, #objs do
        local obj = objs[i]
        if obj.start > pos then
            self._index = i
            return obj, node
        end
        if obj.type == 'getlocal' then
            self._callback(obj, node)
        elseif obj.type == 'setlocal' then
            local newNode = self._callback(obj, node)
            if newNode then
                node = newNode:copy()
            end
        elseif obj.type == 'doc.cast' then
            -- `_collect` adds matching `---@cast` docs to `_objs`, but they are
            -- not applied here yet; skip them instead of raising.
            -- NOTE(review): apply the cast to `node` once cast handling exists.
        else
            error('unexpected type: ' .. obj.type)
        end
    end
    -- Everything has been consumed: move the cursor past the end so a later
    -- call (e.g. from an enclosing block after a nested function was walked)
    -- does not replay objects that were already reported.
    self._index = #objs + 1
    return nil, node
end

---If `action` (or the value `vm.getObjectValue` finds for it) is a function,
---walks that function's body starting from a copy of `topNode`.
---@param action  parser.object
---@param topNode vm.node
function mt:_lookInto(action, topNode)
    local target = vm.getObjectValue(action) or action
    if target.type ~= 'function' then
        return
    end
    self:_launchBlock(target, topNode:copy())
end

---Walks the statements of `block` in order. For each statement it descends
---into function values (via `_lookInto`) and advances the cursor past the
---statement, reporting the objects it passes. Stops early once every object
---has been consumed.
---@param block parser.object
---@param node  vm.node
function mt:_launchBlock(block, node)
    local top, topNode = self:_fastWard(block.start, node)
    if not top then
        return
    end
    for _, action in ipairs(block) do
        -- Use the same end position (`range` when present, else `finish`)
        -- both to skip statements and to advance the cursor. Advancing only
        -- to `finish` would leave objects between `finish` and `range`
        -- pending. A following function statement would then consume them
        -- inside its own copied node, and a `setlocal` update would never
        -- reach this block's node.
        -- NOTE(review): `range` appears to mark the end of an assigned value — confirm.
        local finish = action.range or action.finish
        if finish < top.start then
            goto CONTINUE
        end
        self:_lookInto(action, topNode)
        top, topNode = self:_fastWard(finish, topNode)
        if not top then
            return
        end
        ::CONTINUE::
    end
    self:_fastWard(block.finish, topNode)
end

---Walks every getlocal/setlocal of `loc` (and matching `---@cast` docs) in
---source order, calling `callback` with the node in effect at each one.
---Does nothing when `_collect` finds nothing to visit.
---@param loc parser.object
---@param callback vm.runner.callback
function vm.launchRunner(loc, callback)
    local runner = setmetatable({}, mt)
    runner._loc      = loc
    runner._objs     = {}
    runner._callback = callback

    runner:_collect()

    if #runner._objs > 0 then
        -- Start from the block containing `loc`, seeded with a copy of its node.
        runner:_launchBlock(guide.getParentBlock(loc), vm.getNode(loc):copy())
    end
end