summaryrefslogtreecommitdiff
path: root/script/core/collector.lua
blob: 57ae3adc78af31278cb0864f8f80091493a7c7a4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
local scope = require 'workspace.scope'
local ws    = require 'workspace'

-- collect[name][uri] = value
-- Every value registered under `name`, keyed by the uri that registered it.
local collect    = {}
-- subscribed[uri][name] = true
-- Reverse index of the names each uri has registered; lets dropUri find the
-- entries of `collect` that belong to a uri.
local subscribed = {}

-- Module table; the public functions below are attached to it.
local m = {}

--- Subscribe a name: record that `uri` provides `value` under `name`.
--- A later call with the same `uri` and `name` overwrites the stored value.
---@param uri uri
---@param name string
---@param value any # stored as `true` when omitted (nil); `false` is kept as-is
function m.subscribe(uri, name, value)
    -- Subscription part: remember which names this uri has registered.
    local names = subscribed[uri] or {}
    subscribed[uri] = names
    names[name] = true

    -- Collection part: store the value under the name, keyed by uri.
    local byUri = collect[name] or {}
    collect[name] = byUri
    if value ~= nil then
        byUri[uri] = value
    else
        -- nil cannot be stored in a table, so `true` marks the subscription.
        byUri[uri] = true
    end
end

--- Discard everything that was collected from the given uri.
--- Does nothing when the uri never subscribed to any name.
---@param uri uri
function m.dropUri(uri)
    local uriSubscribed = subscribed[uri]
    if not uriSubscribed then
        return
    end
    subscribed[uri] = nil
    for name in pairs(uriSubscribed) do
        local nameCollect = collect[name]
        -- Guard the lookup so a missing per-name table cannot raise
        -- "attempt to index a nil value".
        if nameCollect then
            nameCollect[uri] = nil
            -- Prune the per-name table once its last uri is gone, so empty
            -- tables do not pile up in `collect` until m.has is called.
            if next(nameCollect) == nil then
                collect[name] = nil
            end
        end
    end
end

--- Whether at least one uri currently has a value collected under `name`.
--- As a side effect, an empty per-name table is removed from `collect`.
---@param name string
---@return boolean
function m.has(name)
    local byUri = collect[name]
    if byUri and next(byUri) ~= nil then
        return true
    end
    if byUri then
        -- The table exists but holds no uri any more: drop it.
        collect[name] = nil
    end
    return false
end

-- Iterator that yields nothing; m.each returns it when the requested name has
-- no collected entries, so a generic `for` over it finishes immediately.
local function DUMMY_FUNCTION() end

--- Build an iterator over the entries of `nameCollect` (uri -> value) whose
--- uri is accepted by `scp:isChildUri` or `scp:isLinkedUri`.
--- Each step yields `value, uri`; it yields `nil, nil` once exhausted.
---@param nameCollect table
---@param scp scope
local function eachOfFolder(nameCollect, scp)
    local key, item

    return function ()
        while true do
            key, item = next(nameCollect, key)
            if not key then
                return nil, nil
            end
            if scp:isChildUri(key) then
                return item, key
            end
            if scp:isLinkedUri(key) then
                return item, key
            end
            -- Not visible from this scope: move on to the next entry.
        end
    end
end

--- Build an iterator over the entries of `nameCollect` (uri -> value) that are
--- accepted for the scope `scp`. Each step yields `value, uri`; it yields
--- `nil, nil` once exhausted.
---@param nameCollect table
---@param scp scope
local function eachOfLinked(nameCollect, scp)
    local key, item

    -- Decide whether the entry registered by `candidate` is yielded.
    local function accepts(candidate)
        -- NOTE(review): eachOfFolder accepts a uri when EITHER check passes,
        -- while here BOTH must pass - confirm the difference is intended.
        if scp:isChildUri(candidate) and scp:isLinkedUri(candidate) then
            return true
        end

        -- Resolve the scope of the candidate: folder first, then linked
        -- scope, then the fallback scope.
        local owner = scope.getFolder(candidate)
                   or scope.getLinkedScope(candidate)
                   or scope.fallback

        return owner == scp
            or owner:isChildUri(scp.uri)
            or owner:isLinkedUri(scp.uri)
    end

    return function ()
        repeat
            key, item = next(nameCollect, key)
            if not key then
                return nil, nil
            end
        until accepts(key)
        return item, key
    end
end

--- Build an iterator over the entries of `nameCollect` (uri -> value) that are
--- accepted for the scope `scp` (m.each passes `scope.fallback`). Each step
--- yields `value, uri`; it yields `nil, nil` once exhausted.
---@param nameCollect table
---@param scp scope
local function eachOfFallback(nameCollect, scp)
    local key, item

    -- Decide whether the entry registered by `candidate` is yielded.
    local function accepts(candidate)
        if scp:isLinkedUri(candidate) then
            return true
        end

        -- Otherwise the candidate must resolve to exactly this scope:
        -- folder first, then linked scope, then the fallback scope.
        local owner = scope.getFolder(candidate)
                   or scope.getLinkedScope(candidate)
                   or scope.fallback

        return owner == scp
    end

    return function ()
        repeat
            key, item = next(nameCollect, key)
            if not key then
                return nil, nil
            end
        until accepts(key)
        return item, key
    end
end

--- Iterate the subscriptions of a name as seen from `uri`.
--- Returns an iterator function for use in a generic `for`; each step yields
--- `value, uri` of one collected entry. When nothing is collected under
--- `name`, the returned iterator yields nothing.
---@param uri  uri
---@param name string
function m.each(uri, name)
    local byUri = collect[name]
    if not byUri then
        return DUMMY_FUNCTION
    end

    -- Pick the filter from the scope of `uri`: a folder scope wins, then a
    -- linked scope (only looked up when there is no folder), else fallback.
    local folder = scope.getFolder(uri)
    local linked = not folder and scope.getLinkedScope(uri)

    if folder then
        return eachOfFolder(byUri, folder)
    elseif linked then
        return eachOfLinked(byUri, linked)
    else
        return eachOfFallback(byUri, scope.fallback)
    end
end

-- Hand the module table (subscribe / dropUri / has / each) to `require` callers.
return m