diff options
Diffstat (limited to 'script-beta/vm/eachRef.lua')
-rw-r--r-- | script-beta/vm/eachRef.lua | 500 |
1 files changed, 500 insertions, 0 deletions
diff --git a/script-beta/vm/eachRef.lua b/script-beta/vm/eachRef.lua new file mode 100644 index 00000000..cfb2bef8 --- /dev/null +++ b/script-beta/vm/eachRef.lua @@ -0,0 +1,500 @@ +local guide = require 'parser.guide' +local files = require 'files' +local vm = require 'vm.vm' + +local function ofCall(func, index, callback) + vm.eachRef(func, function (info) + local src = info.source + local returns + if src.type == 'main' or src.type == 'function' then + returns = src.returns + end + if returns then + -- 搜索函数第 index 个返回值 + for _, rtn in ipairs(returns) do + local val = rtn[index] + if val then + callback { + source = val, + mode = 'return', + } + vm.eachRef(val, callback) + end + end + end + end) +end + +local function ofCallSelect(call, index, callback) + local slc = call.parent + if slc.index == index then + vm.eachRef(slc.parent, callback) + return + end + if call.extParent then + for i = 1, #call.extParent do + slc = call.extParent[i] + if slc.index == index then + vm.eachRef(slc.parent, callback) + return + end + end + end +end + +local function ofReturn(rtn, index, callback) + local func = guide.getParentFunction(rtn) + if not func then + return + end + -- 搜索函数调用的第 index 个接收值 + if func.type == 'main' then + local myUri = func.uri + local uris = files.findLinkTo(myUri) + if not uris then + return + end + for _, uri in ipairs(uris) do + local ast = files.getAst(uri) + if ast then + local links = vm.getLinks(ast.ast) + if links then + for linkUri, calls in pairs(links) do + if files.eq(linkUri, myUri) then + for i = 1, #calls do + ofCallSelect(calls[i], 1, callback) + end + end + end + end + end + end + else + vm.eachRef(func, function (info) + local source = info.source + local call = source.parent + if not call or call.type ~= 'call' then + return + end + ofCallSelect(call, index, callback) + end) + end +end + +local function ofSpecialCall(call, func, index, callback) + local name = func.special + if name == 'setmetatable' then + if index == 1 then + local args = call.args + if args[1] then + vm.eachRef(args[1], callback) + end + if args[2] then + vm.eachField(args[2], function (info) + if info.key == 's|__index' then + vm.eachRef(info.source, callback) + if info.value then + vm.eachRef(info.value, callback) + end + end + end) + end + end + elseif name == 'require' then + if index == 1 then + local result = vm.getLinkUris(call) + if result then + local myUri = guide.getRoot(call).uri + for _, uri in ipairs(result) do + if not files.eq(uri, myUri) then + local ast = files.getAst(uri) + if ast then + ofCall(ast.ast, 1, callback) + end + end + end + end + end + end +end + +local function ofValue(value, callback) + if value.type == 'select' then + -- 检查函数返回值 + local call = value.vararg + if call.type == 'call' then + ofCall(call.node, value.index, callback) + ofSpecialCall(call, call.node, value.index, callback) + end + return + end + + if value.type == 'table' + or value.type == 'string' + or value.type == 'number' + or value.type == 'boolean' + or value.type == 'nil' + or value.type == 'function' then + callback { + source = value, + mode = 'value', + } + end + + vm.eachRef(value, callback) + + local parent = value.parent + if parent.type == 'local' + or parent.type == 'setglobal' + or parent.type == 'setlocal' + or parent.type == 'setfield' + or parent.type == 'setmethod' + or parent.type == 'setindex' + or parent.type == 'tablefield' + or parent.type == 'tableindex' then + if parent.value == value then + vm.eachRef(parent, callback) + end + end + if parent.type == 'return' then + for i = 1, #parent do + if parent[i] == value then + ofReturn(parent, i, callback) + break + end + end + end +end + +local function ofSelf(loc, callback) + -- self 的2个特殊引用位置: + -- 1. 当前方法定义时的对象(mt) + local method = loc.method + local node = method.node + vm.eachRef(node, callback) + -- 2. 调用该方法时传入的对象 +end + +--- 自己作为赋值的值 +local function asValue(source, callback) + local parent = source.parent + if parent and parent.value == source then + if guide.getName(parent) == '__index' then + if parent.type == 'tablefield' + or parent.type == 'tableindex' then + local t = parent.parent + local args = t.parent + if args[2] == t then + local call = args.parent + local func = call.node + if func.special == 'setmetatable' then + vm.eachRef(args[1], callback) + end + end + end + end + end +end + +local function getCallRecvs(call) + local parent = call.parent + if parent.type ~= 'select' then + return nil + end + local exParent = call.exParent + local recvs = {} + recvs[1] = parent.parent + if exParent then + for _, p in ipairs(exParent) do + recvs[#recvs+1] = p.parent + end + end + return recvs +end + +--- 自己作为函数的参数 +local function asArg(source, callback) + local parent = source.parent + if not parent then + return + end + if parent.type == 'callargs' then + local call = parent.parent + local func = call.node + local name = func.special + if name == 'setmetatable' then + if parent[1] == source then + if parent[2] then + vm.eachField(parent[2], function (info) + if info.key == 's|__index' then + vm.eachRef(info.source, callback) + if info.value then + vm.eachRef(info.value, callback) + end + end + end) + end + end + local recvs = getCallRecvs(call) + if recvs and recvs[1] then + vm.eachRef(recvs[1], callback) + end + end + end +end + +local function ofLocal(loc, callback) + -- 方法中的 self 使用了一个虚拟的定义位置 + if loc.tag ~= 'self' then + callback { + source = loc, + mode = 'declare', + } + end + if loc.ref then + for _, ref in ipairs(loc.ref) do + if ref.type == 'getlocal' then + callback { + source = ref, + mode = 'get', + } + asValue(ref, callback) + elseif ref.type == 'setlocal' then + callback { + source = ref, + mode = 'set', + } + if ref.value then + ofValue(ref.value, callback) + end + end + end + end + if loc.tag == 'self' then + ofSelf(loc, callback) + end + if loc.value then + ofValue(loc.value, callback) + end + if loc.tag == '_ENV' and loc.ref then + for _, ref in ipairs(loc.ref) do + if ref.type == 'getlocal' then + local parent = ref.parent + if parent.type == 'getfield' + or parent.type == 'getindex' then + if guide.getKeyName(parent) == '_G' then + callback { + source = parent, + mode = 'get', + } + end + end + elseif ref.type == 'getglobal' then + if guide.getName(ref) == '_G' then + callback { + source = ref, + mode = 'get', + } + end + end + end + end +end + +local function ofGlobal(source, callback) + local key = guide.getKeyName(source) + local node = source.node + if node.tag == '_ENV' then + local uris = files.findGlobals(key) + for _, uri in ipairs(uris) do + local ast = files.getAst(uri) + local globals = vm.getGlobals(ast.ast) + if globals[key] then + for _, info in ipairs(globals[key]) do + callback(info) + if info.value then + ofValue(info.value, callback) + end + end + end + end + else + vm.eachField(node, function (info) + if key == info.key then + callback { + source = info.source, + mode = info.mode, + } + if info.value then + ofValue(info.value, callback) + end + end + end) + end +end + +local function ofField(source, callback) + local parent = source.parent + local key = guide.getKeyName(source) + if parent.type == 'tablefield' + or parent.type == 'tableindex' then + local tbl = parent.parent + vm.eachField(tbl, function (info) + if key == info.key then + callback { + source = info.source, + mode = info.mode, + } + if info.value then + ofValue(info.value, callback) + end + end + end) + else + local node = parent.node + vm.eachField(node, function (info) + if key == info.key then + callback { + source = info.source, + mode = info.mode, + } + if info.value then + ofValue(info.value, callback) + end + end + end) + end +end + +local function ofLiteral(source, callback) + local parent = source.parent + if not parent then + return + end + if parent.type == 'setindex' + or parent.type == 'getindex' + or parent.type == 'tableindex' then + ofField(source, callback) + end +end + +local function ofLabel(source, callback) + callback { + source = source, + mode = 'set', + } + if source.ref then + for _, ref in ipairs(source.ref) do + callback { + source = ref, + mode = 'get', + } + end + end +end + +local function ofGoTo(source, callback) + local name = source[1] + local label = guide.getLabel(source, name) + if label then + ofLabel(label, callback) + end +end + +local function ofMain(source, callback) + callback { + source = source, + mode = 'main', + } +end + +local function eachRef(source, callback) + local stype = source.type + if stype == 'local' then + ofLocal(source, callback) + elseif stype == 'getlocal' + or stype == 'setlocal' then + ofLocal(source.node, callback) + elseif stype == 'setglobal' + or stype == 'getglobal' then + ofGlobal(source, callback) + elseif stype == 'field' + or stype == 'method' then + ofField(source, callback) + elseif stype == 'setfield' + or stype == 'getfield' then + ofField(source.field, callback) + elseif stype == 'setmethod' + or stype == 'getmethod' then + ofField(source.method, callback) + elseif stype == 'number' + or stype == 'boolean' + or stype == 'string' then + ofLiteral(source, callback) + elseif stype == 'goto' then + ofGoTo(source, callback) + elseif stype == 'label' then + ofLabel(source, callback) + elseif stype == 'table' + or stype == 'function' then + ofValue(source, callback) + elseif stype == 'main' then + ofMain(source, callback) + end + asArg(source, callback) +end + +--- 判断2个对象是否拥有相同的引用 +function vm.isSameRef(a, b) + local cache = vm.cache.eachRef[a] + if cache then + -- 相同引用的source共享同一份cache + return cache == vm.cache.eachRef[b] + else + return vm.eachRef(a, function (info) + if info.source == b then + return true + end + end) or false + end +end + +--- 获取所有的引用 +function vm.eachRef(source, callback) + local cache = vm.cache.eachRef[source] + if cache then + for i = 1, #cache do + local res = callback(cache[i]) + if res ~= nil then + return res + end + end + return + end + local unlock = vm.lock('eachRef', source) + if not unlock then + return + end + cache = {} + vm.cache.eachRef[source] = cache + local mark = {} + eachRef(source, function (info) + local src = info.source + if mark[src] then + return + end + mark[src] = true + cache[#cache+1] = info + end) + unlock() + for i = 1, #cache do + local src = cache[i].source + vm.cache.eachRef[src] = cache + end + for i = 1, #cache do + local res = callback(cache[i]) + if res ~= nil then + return res + end + end +end |