summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--script-beta/vm/eachRef.lua619
-rw-r--r--test-beta/references/init.lua4
2 files changed, 313 insertions, 310 deletions
diff --git a/script-beta/vm/eachRef.lua b/script-beta/vm/eachRef.lua
index 3cf48987..7e04f69c 100644
--- a/script-beta/vm/eachRef.lua
+++ b/script-beta/vm/eachRef.lua
@@ -4,164 +4,7 @@ local vm = require 'vm.vm'
local library = require 'library'
local await = require 'await'
-local STATE_USED = 1 << 0
-local STATE_LOCAL = 1 << 1
-local STATE_NODE = 1 << 2
-local STATE_LABEL = 1 << 3
-
-local function markFlag(state, source, flag)
- local flags = state[source] or 0
- if flags & flag ~= 0 then
- return false
- end
- state[source] = flags | flag
- return true
-end
-
-local function ofSelf(state, loc, callback)
- -- self 的2个特殊引用位置:
- -- 1. 当前方法定义时的对象(mt)
- local method = loc.method
- local node = method.node
- vm.refOf(state, node, callback)
- -- 2. 调用该方法时传入的对象
-end
-
-local function ofLocal(state, loc, callback)
- if not markFlag(state, loc, STATE_LOCAL) then
- return
- end
- -- 方法中的 self 使用了一个虚拟的定义位置
- if loc.tag ~= 'self' then
- callback(loc, 'declare')
- vm.refOf(state, loc, callback)
- end
- local refs = loc.ref
- if refs then
- for i = 1, #refs do
- local ref = refs[i]
- if ref.type == 'getlocal' then
- callback(ref, 'get')
- vm.refOf(state, ref, callback)
- if loc.tag == '_ENV' then
- local parent = ref.parent
- if parent.type == 'getfield'
- or parent.type == 'getindex' then
- if guide.getKeyName(parent) == '_G' then
- callback(parent, 'declare')
- vm.refOf(state, ref, callback)
- end
- end
- end
- elseif ref.type == 'setlocal' then
- callback(ref, 'set')
- vm.refOf(state, ref, callback)
- elseif ref.type == 'getglobal' then
- if loc.tag == '_ENV' then
- if guide.getName(ref) == '_G' then
- callback(ref, 'get')
- vm.refOf(state, ref, callback)
- end
- end
- end
- end
- end
- if loc.tag == 'self' then
- ofSelf(state, loc, callback)
- end
-end
-
-local function ofGlobal(state, source, callback)
- local key = guide.getKeyName(source)
- local node = source.node
- if not markFlag(state, node, STATE_NODE) then
- return
- end
- if node.tag == '_ENV' then
- local uris = files.findGlobals(key)
- for i = 1, #uris do
- local uri = uris[i]
- local ast = files.getAst(uri)
- local globals = vm.getGlobals(ast.ast)
- if globals and globals[key] then
- for _, info in ipairs(globals[key]) do
- callback(info)
- vm.refOf(state, info.source, callback)
- end
- end
- end
- else
- -- 重载了 _ENV
- vm.eachField(node, function (info)
- if key == info.key then
- callback(info)
- vm.refOf(state, info.source, callback)
- end
- end)
- end
-end
-
-local function ofField(state, source, callback)
- local parent = source.parent
- local key = guide.getKeyName(source)
- local node
- if parent.type == 'tablefield'
- or parent.type == 'tableindex' then
- node = parent.parent
- else
- node = parent.node
- end
- if not markFlag(state, node, STATE_NODE) then
- return
- end
- vm.eachField(node, function (info)
- if key == info.key then
- callback(info)
- vm.refOf(source, state, callback)
- end
- end)
-end
-
-local function ofLabel(state, source, callback)
- if not markFlag(state, source, STATE_LABEL) then
- return
- end
- callback(source, 'set')
- if source.ref then
- for _, ref in ipairs(source.ref) do
- callback(ref, 'get')
- end
- end
-end
-
-local function ofGoTo(state, source, callback)
- local name = source[1]
- local label = guide.getLabel(source, name)
- if label then
- ofLabel(state, label, callback)
- end
-end
-
-local function ofValue(state, source, callback)
- callback(source, 'value')
-end
-
-local function ofIndex(state, source, callback)
- local parent = source.parent
- if not parent then
- return
- end
- if parent.type == 'setindex'
- or parent.type == 'getindex'
- or parent.type == 'tableindex' then
- ofField(state, source, callback)
- end
-end
-
-local function ofCallRecv(state, func, index, callback, offset)
- if not markFlag(state, func, STATE_USED) then
- return
- end
+local function ofCall(func, index, callback, offset)
offset = offset or 0
vm.eachRef(func, function (info)
local src = info.source
@@ -171,30 +14,49 @@ local function ofCallRecv(state, func, index, callback, offset)
end
if returns then
-- 搜索函数第 index 个返回值
- for i = 1, #returns do
- local rtn = returns[i]
+ for _, rtn in ipairs(returns) do
local val = rtn[index-offset]
if val then
- vm.refOf(state, val, callback)
+ vm.eachRef(val, callback)
end
end
end
end)
end
-local function ofSpecialCallRecv(state, call, func, index, callback, offset)
+local function ofCallSelect(call, index, callback)
+ local slc = call.parent
+ if slc.index == index then
+ vm.eachRef(slc.parent, callback)
+ return
+ end
+ if call.extParent then
+ for i = 1, #call.extParent do
+ slc = call.extParent[i]
+ if slc.index == index then
+ vm.eachRef(slc.parent, callback)
+ return
+ end
+ end
+ end
+end
+
+local function ofSpecialCall(call, func, index, callback, offset)
local name = func.special
offset = offset or 0
if name == 'setmetatable' then
if index == 1 + offset then
local args = call.args
if args[1+offset] then
- vm.refOf(state, args[1+offset], callback)
+ vm.eachRef(args[1+offset], callback)
end
if args[2+offset] then
vm.eachField(args[2+offset], function (info)
if info.key == 's|__index' then
- vm.refOf(state, info.source, callback)
+ vm.eachRef(info.source, callback)
+ if info.value then
+ vm.eachRef(info.value, callback)
+ end
end
end)
end
@@ -205,12 +67,11 @@ local function ofSpecialCallRecv(state, call, func, index, callback, offset)
local result = vm.getLinkUris(call)
if result then
local myUri = guide.getRoot(call).uri
- for i = 1, #result do
- local uri = result[i]
+ for _, uri in ipairs(result) do
if not files.eq(uri, myUri) then
local ast = files.getAst(uri)
if ast then
- ofCallRecv(state, ast.ast, 1, callback)
+ ofCall(ast.ast, 1, callback)
end
end
end
@@ -222,7 +83,10 @@ local function ofSpecialCallRecv(state, call, func, index, callback, offset)
local objName = args[1+offset][1]
local lib = library.library[objName]
if lib then
- callback(lib, 'value')
+ callback {
+ source = lib,
+ mode = 'value',
+ }
end
end
end
@@ -231,13 +95,12 @@ local function ofSpecialCallRecv(state, call, func, index, callback, offset)
or name == 'xpcall' then
if index >= 2-offset then
local args = call.args
- if args[1+offset]
- and markFlag(state, args[1+offset], STATE_USED) then
+ if args[1+offset] then
vm.eachRef(args[1+offset], function (info)
local src = info.source
if src.type == 'function' then
- ofCallRecv(state, src, index, callback, 1+offset)
- ofSpecialCallRecv(state, call, src, index, callback, 1+offset)
+ ofCall(src, index, callback, 1+offset)
+ ofSpecialCall(call, src, index, callback, 1+offset)
end
end)
end
@@ -245,17 +108,59 @@ local function ofSpecialCallRecv(state, call, func, index, callback, offset)
end
end
--- 自己是函数调用的接收者,引用函数定义的返回值
-local function ofSelect(state, source, callback)
+local function asSetValue(value, callback)
+ if value.type == 'field'
+ or value.type == 'method' then
+ value = value.parent
+ end
+ local parent = value.parent
+ if not parent then
+ return
+ end
+ if parent.type == 'local'
+ or parent.type == 'setglobal'
+ or parent.type == 'setlocal'
+ or parent.type == 'setfield'
+ or parent.type == 'setmethod'
+ or parent.type == 'setindex'
+ or parent.type == 'tablefield'
+ or parent.type == 'tableindex' then
+ if parent.value == value then
+ vm.eachRef(parent, callback)
+ if guide.getName(parent) == '__index' then
+ if parent.type == 'tablefield'
+ or parent.type == 'tableindex' then
+ local t = parent.parent
+ local args = t.parent
+ if args[2] == t then
+ local call = args.parent
+ local func = call.node
+ if func.special == 'setmetatable' then
+ vm.eachRef(args[1], callback)
+ end
+ end
+ end
+ end
+ end
+ end
+end
+
+local function ofSelect(source, callback)
+ -- 检查函数返回值
local call = source.vararg
if call.type == 'call' then
- ofCallRecv(state, call.node, source.index, callback)
- ofSpecialCallRecv(state, call, call.node, source.index, callback)
+ ofCall(call.node, source.index, callback)
+ ofSpecialCall(call, call.node, source.index, callback)
end
end
-local function ofMain(state, source, callback)
- callback(source, 'main')
+local function ofSelf(loc, callback)
+ -- self 的2个特殊引用位置:
+ -- 1. 当前方法定义时的对象(mt)
+ local method = loc.method
+ local node = method.node
+ vm.eachRef(node, callback)
+ -- 2. 调用该方法时传入的对象
end
local function getCallRecvs(call)
@@ -267,8 +172,7 @@ local function getCallRecvs(call)
local recvs = {}
recvs[1] = parent.parent
if extParent then
- for i = 1, #extParent do
- local p = extParent[i]
+ for _, p in ipairs(extParent) do
recvs[#recvs+1] = p.parent
end
end
@@ -276,7 +180,7 @@ local function getCallRecvs(call)
end
--- 自己作为函数的参数
-local function checkAsArg(state, source, callback)
+local function asArg(source, callback)
local parent = source.parent
if not parent then
return
@@ -290,13 +194,16 @@ local function checkAsArg(state, source, callback)
if parent[2] then
vm.eachField(parent[2], function (info)
if info.key == 's|__index' then
- vm.refOf(state, info.source, callback)
+ vm.eachRef(info.source, callback)
+ if info.value then
+ vm.eachRef(info.value, callback)
+ end
end
end)
end
local recvs = getCallRecvs(call)
if recvs and recvs[1] then
- vm.refOf(state, recvs[1], callback)
+ vm.eachRef(recvs[1], callback)
end
vm.setMeta(source, parent[2])
end
@@ -304,25 +211,8 @@ local function checkAsArg(state, source, callback)
end
end
-local function ofCallSelect(state, call, index, callback)
- local slc = call.parent
- if slc.index == index then
- vm.refOf(state, slc.parent, callback)
- return
- end
- if call.extParent then
- for i = 1, #call.extParent do
- slc = call.extParent[i]
- if slc.index == index then
- vm.refOf(state, slc.parent, callback)
- return
- end
- end
- end
-end
-
--- 自己作为函数的返回值
-local function checkAsReturn(state, source, callback)
+local function asReturn(source, callback)
local parent = source.parent
if source.type == 'field'
or source.type == 'method' then
@@ -338,16 +228,15 @@ local function checkAsReturn(state, source, callback)
if not uris then
return
end
- for i = 1, #uris do
- local uri = uris[i]
+ for _, uri in ipairs(uris) do
local ast = files.getAst(uri)
if ast then
local links = vm.getLinks(ast.ast)
if links then
for linkUri, calls in pairs(links) do
if files.eq(linkUri, myUri) then
- for j = 1, #calls do
- ofCallSelect(state, calls[j], 1, callback)
+ for i = 1, #calls do
+ ofCallSelect(calls[i], 1, callback)
end
end
end
@@ -373,161 +262,267 @@ local function checkAsReturn(state, source, callback)
end
local recvs = getCallRecvs(call)
if recvs and recvs[index] then
- vm.refOf(state, recvs[index], callback)
+ vm.eachRef(recvs[index], callback)
elseif index == 1 then
- callback(call, 'call')
+ callback {
+ type = 'call',
+ source = call,
+ }
end
end)
end
end
-local function checkAsParen(state, source, callback)
- if source.parent and source.parent.type == 'paren' then
- vm.refOf(state, source.parent, callback)
+local function ofLocal(loc, callback)
+ -- 方法中的 self 使用了一个虚拟的定义位置
+ if loc.tag ~= 'self' then
+ callback {
+ source = loc,
+ mode = 'declare',
+ }
+ end
+ if loc.ref then
+ for _, ref in ipairs(loc.ref) do
+ if ref.type == 'getlocal' then
+ callback {
+ source = ref,
+ mode = 'get',
+ }
+ vm.eachRef(ref, callback)
+ elseif ref.type == 'setlocal' then
+ callback {
+ source = ref,
+ mode = 'set',
+ }
+ vm.eachRef(ref, callback)
+ if ref.value then
+ vm.eachRef(ref.value, callback)
+ end
+ end
+ end
+ end
+ if loc.tag == 'self' then
+ ofSelf(loc, callback)
+ end
+ if loc.value then
+ vm.eachRef(loc.value, callback)
+ end
+ if loc.tag == '_ENV' and loc.ref then
+ for _, ref in ipairs(loc.ref) do
+ if ref.type == 'getlocal' then
+ local parent = ref.parent
+ if parent.type == 'getfield'
+ or parent.type == 'getindex' then
+ if guide.getKeyName(parent) == '_G' then
+ callback {
+ source = parent,
+ mode = 'get',
+ }
+ end
+ end
+ elseif ref.type == 'getglobal' then
+ if guide.getName(ref) == '_G' then
+ callback {
+ source = ref,
+ mode = 'get',
+ }
+ end
+ end
+ end
end
end
-local function checkValue(state, source, callback)
- if source.value then
- vm.refOf(state, source.value, callback)
+local function ofGlobal(source, callback)
+ local key = guide.getKeyName(source)
+ local node = source.node
+ if node.tag == '_ENV' then
+ local uris = files.findGlobals(key)
+ for _, uri in ipairs(uris) do
+ local ast = files.getAst(uri)
+ local globals = vm.getGlobals(ast.ast)
+ if globals and globals[key] then
+ for _, info in ipairs(globals[key]) do
+ callback(info)
+ if info.value then
+ vm.eachRef(info.value, callback)
+ end
+ end
+ end
+ end
+ else
+ vm.eachField(node, function (info)
+ if key == info.key then
+ callback {
+ source = info.source,
+ mode = info.mode,
+ }
+ if info.value then
+ vm.eachRef(info.value, callback)
+ end
+ end
+ end)
end
end
-local function checkSetValue(state, value, callback)
- if value.type == 'field'
- or value.type == 'method' then
- value = value.parent
- end
- local parent = value.parent
- if not parent then
+local function ofField(source, callback)
+ if not source then
return
end
- if parent.type == 'local'
- or parent.type == 'setglobal'
- or parent.type == 'setlocal'
- or parent.type == 'setfield'
- or parent.type == 'setmethod'
- or parent.type == 'setindex'
- or parent.type == 'tablefield'
+ local parent = source.parent
+ local key = guide.getKeyName(source)
+ if parent.type == 'tablefield'
or parent.type == 'tableindex' then
- if parent.value == value then
- vm.refOf(state, parent, callback)
- if guide.getName(parent) == '__index' then
- if parent.type == 'tablefield'
- or parent.type == 'tableindex' then
- local t = parent.parent
- local args = t.parent
- if args[2] == t then
- local call = args.parent
- local func = call.node
- if func.special == 'setmetatable' then
- vm.refOf(state, args[1], callback)
- end
- end
+ local tbl = parent.parent
+ vm.eachField(tbl, function (info)
+ if key == info.key then
+ callback {
+ source = info.source,
+ mode = info.mode,
+ }
+ vm.eachRef(info.source, callback)
+ if info.value then
+ vm.eachRef(info.value, callback)
end
end
- end
+ end)
+ else
+ local node = parent.node
+ vm.eachField(node, function (info)
+ if key == info.key then
+ callback {
+ source = info.source,
+ mode = info.mode,
+ }
+ vm.eachRef(info.source, callback)
+ if info.value then
+ vm.eachRef(info.value, callback)
+ end
+ end
+ end)
end
end
-local function ofInParen(state, source, callback)
- vm.refOf(state, source, callback)
+local function ofLiteral(source, callback)
+ local parent = source.parent
+ if not parent then
+ return
+ end
+ if parent.type == 'setindex'
+ or parent.type == 'getindex'
+ or parent.type == 'tableindex' then
+ ofField(source, callback)
+ end
end
-local function applyCache(cache, callback, max)
- await.delay(function ()
- return files.globalVersion
- end)
- if max then
- if max > #cache then
- max = #cache
+local function ofLabel(source, callback)
+ callback {
+ source = source,
+ mode = 'set',
+ }
+ if source.ref then
+ for _, ref in ipairs(source.ref) do
+ callback {
+ source = ref,
+ mode = 'get',
+ }
end
- else
- max = #cache
end
- for i = 1, max do
- local res = callback(cache[i])
- if res ~= nil then
- return res
- end
+end
+
+local function ofGoTo(source, callback)
+ local name = source[1]
+ local label = guide.getLabel(source, name)
+ if label then
+ ofLabel(label, callback)
end
end
-local function eachRef(source, result)
- local mark = {}
- vm.refOf({}, source, function (src, mode)
- local info
- if src.mode then
- info = src
- src = info.source
- end
- if mark[src] then
- return
- end
- mark[src] = true
- if info then
- result[#result+1] = info
- elseif mode then
- result[#result+1] = {
- source = src,
- mode = mode,
- }
- end
- end)
- return result
+local function ofMain(source, callback)
+ callback {
+ source = source,
+ mode = 'main',
+ }
end
-function vm.refOf(state, source, callback)
- if not markFlag(state, source, STATE_USED) then
- return
+local function asParen(source, callback)
+ if source.parent and source.parent.type == 'paren' then
+ vm.eachRef(source.parent, callback)
end
+end
+
+local function ofSelfValue(source, callback)
+ callback {
+ source = source,
+ mode = 'value',
+ }
+end
+
+local function eachRef(source, callback)
local stype = source.type
- if stype == 'local' then
- ofLocal(state, source, callback)
+ if stype == 'local' then
+ ofLocal(source, callback)
elseif stype == 'getlocal'
or stype == 'setlocal' then
- ofLocal(state, source.node, callback)
+ ofLocal(source.node, callback)
elseif stype == 'setglobal'
or stype == 'getglobal' then
- ofGlobal(state, source, callback)
+ ofGlobal(source, callback)
elseif stype == 'field'
or stype == 'method' then
- ofField(state, source, callback)
+ ofField(source, callback)
elseif stype == 'setfield'
or stype == 'getfield'
or stype == 'tablefield' then
- ofField(state, source.field, callback)
+ ofField(source.field, callback)
elseif stype == 'setmethod'
or stype == 'getmethod' then
- ofField(state, source.method, callback)
- elseif stype == 'goto' then
- ofGoTo(state, source, callback)
- elseif stype == 'label' then
- ofLabel(state, source, callback)
+ ofField(source.method, callback)
elseif stype == 'number'
or stype == 'boolean'
or stype == 'string' then
- ofIndex(state, source, callback)
- ofValue(state, source, callback)
+ ofLiteral(source, callback)
+ ofSelfValue(source, callback)
+ elseif stype == 'goto' then
+ ofGoTo(source, callback)
+ elseif stype == 'label' then
+ ofLabel(source, callback)
elseif stype == 'table'
or stype == 'function'
or stype == 'nil' then
- ofValue(state, source, callback)
+ ofSelfValue(source, callback)
elseif stype == 'select' then
- ofSelect(state, source, callback)
+ ofSelect(source, callback)
elseif stype == 'call' then
- ofCallRecv(state, source.node, 1, callback)
- ofSpecialCallRecv(state, source, source.node, 1, callback)
+ ofCall(source.node, 1, callback)
+ ofSpecialCall(source, source.node, 1, callback)
elseif stype == 'main' then
- ofMain(state, source, callback)
+ ofMain(source, callback)
elseif stype == 'paren' then
- ofInParen(state, source.exp, callback)
+ eachRef(source.exp, callback)
+ end
+ asArg(source, callback)
+ asReturn(source, callback)
+ asParen(source, callback)
+ asSetValue(source, callback)
+end
+
+local function applyCache(cache, callback, max)
+ await.delay(function ()
+ return files.globalVersion
+ end)
+ if max then
+ if max > #cache then
+ max = #cache
+ end
+ else
+ max = #cache
+ end
+ for i = 1, max do
+ local res = callback(cache[i])
+ if res ~= nil then
+ return res
+ end
end
- checkValue(state, source, callback)
- checkSetValue(state, source, callback)
- checkAsParen(state, source, callback)
- checkAsReturn(state, source, callback)
- checkAsArg(state, source, callback)
end
--- 判断2个对象是否拥有相同的引用
@@ -557,7 +552,15 @@ function vm.eachRef(source, callback, max)
end
cache = {}
vm.cache.eachRef[source] = cache
- eachRef(source, cache)
+ local mark = {}
+ eachRef(source, function (info)
+ local src = info.source
+ if mark[src] then
+ return
+ end
+ mark[src] = true
+ cache[#cache+1] = info
+ end)
unlock()
for i = 1, #cache do
local src = cache[i].source
diff --git a/test-beta/references/init.lua b/test-beta/references/init.lua
index 557ac15a..c6727eca 100644
--- a/test-beta/references/init.lua
+++ b/test-beta/references/init.lua
@@ -149,7 +149,7 @@ print(obj.<?x?>)
TEST [[
local <!x!>
local function f()
- return x
+ return <!x!>
end
local <?y?> = f()
]]
@@ -166,7 +166,7 @@ TEST [[
local <!x!>
local function f()
return function ()
- return x
+ return <!x!>
end
end
local <?y?> = f()()