diff options
-rw-r--r-- | script-beta/core/definition.lua | 47 | ||||
-rw-r--r-- | script-beta/core/reference.lua | 48 | ||||
-rw-r--r-- | script-beta/vm/eachDef.lua | 439 | ||||
-rw-r--r-- | script-beta/vm/eachField.lua | 54 | ||||
-rw-r--r-- | script-beta/vm/eachRef.lua | 60 | ||||
-rw-r--r-- | script-beta/vm/getGlobals.lua | 9 | ||||
-rw-r--r-- | script-beta/vm/vm.lua | 12 | ||||
-rw-r--r-- | test-beta/definition/call.lua | 2 | ||||
-rw-r--r-- | test-beta/definition/method.lua | 46 |
9 files changed, 233 insertions, 484 deletions
diff --git a/script-beta/core/definition.lua b/script-beta/core/definition.lua index 8afdc2ba..ded2fb1c 100644 --- a/script-beta/core/definition.lua +++ b/script-beta/core/definition.lua @@ -17,37 +17,22 @@ local function findDef(source, callback) and source.type ~= 'goto' then return end - vm.eachDef(source, function (info) - if info.source.library then - return - end - if info.mode == 'declare' - or info.mode == 'set' then - local src = info.source - local root = guide.getRoot(src) - local uri = root.uri - if src.type == 'setfield' - or src.type == 'getfield' - or src.type == 'tablefield' then - callback(src.field, uri) - elseif src.type == 'setindex' - or src.type == 'getindex' - or src.type == 'tableindex' then - callback(src.index, uri) - elseif src.type == 'getmethod' - or src.type == 'setmethod' then - callback(src.method, uri) - else - callback(src, uri) - end - end - if info.mode == 'value' then - local src = info.source - local root = guide.getRoot(src) - local uri = root.uri - if src.parent.type == 'return' then - callback(src, uri) - end + vm.eachDef(source, function (src) + local root = guide.getRoot(src) + local uri = root.uri + if src.type == 'setfield' + or src.type == 'getfield' + or src.type == 'tablefield' then + callback(src.field, uri) + elseif src.type == 'setindex' + or src.type == 'getindex' + or src.type == 'tableindex' then + callback(src.index, uri) + elseif src.type == 'getmethod' + or src.type == 'setmethod' then + callback(src.method, uri) + else + callback(src, uri) end end) end diff --git a/script-beta/core/reference.lua b/script-beta/core/reference.lua index b04de92d..7046df43 100644 --- a/script-beta/core/reference.lua +++ b/script-beta/core/reference.lua @@ -27,38 +27,22 @@ local function findRef(source, offset, callback) and not isFunction(source, offset) then return end - vm.eachRef(source, function (info) - if info.source.library then - return - end - if info.mode == 'declare' - or info.mode == 'set' - or info.mode == 'get' then - local src = info.source - local root = guide.getRoot(src) - local uri = root.uri - if src.type == 'setfield' - or src.type == 'getfield' - or src.type == 'tablefield' then - callback(src.field, uri) - elseif src.type == 'setindex' - or src.type == 'getindex' - or src.type == 'tableindex' then - callback(src.index, uri) - elseif src.type == 'getmethod' - or src.type == 'setmethod' then - callback(src.method, uri) - else - callback(src, uri) - end - end - if info.mode == 'value' then - local src = info.source - local root = guide.getRoot(src) - local uri = root.uri - if src.parent.type == 'return' then - callback(src, uri) - end + vm.eachRef(source, function (src) + local root = guide.getRoot(src) + local uri = root.uri + if src.type == 'setfield' + or src.type == 'getfield' + or src.type == 'tablefield' then + callback(src.field, uri) + elseif src.type == 'setindex' + or src.type == 'getindex' + or src.type == 'tableindex' then + callback(src.index, uri) + elseif src.type == 'getmethod' + or src.type == 'setmethod' then + callback(src.method, uri) + else + callback(src, uri) end end) end diff --git a/script-beta/vm/eachDef.lua b/script-beta/vm/eachDef.lua index 005c2ed1..6bad25ad 100644 --- a/script-beta/vm/eachDef.lua +++ b/script-beta/vm/eachDef.lua @@ -1,141 +1,91 @@ local guide = require 'parser.guide' local files = require 'files' local vm = require 'vm.vm' -local library = require 'library' -local await = require 'await' -local function ofSelf(state, loc, callback) - -- self 的2个特殊引用位置: - -- 1. 当前方法定义时的对象(mt) - local method = loc.method - local node = method.node - callback(node) - -- 2. 调用该方法时传入的对象 -end - -local function ofLocal(state, loc, source, callback) - if state[loc] then +local function ofParentMT(func, callback) + if not func or func.type ~= 'function' then return end - state[loc] = true - -- 方法中的 self 使用了一个虚拟的定义位置 - if loc.tag ~= 'self' then - callback(loc, 'declare') + local parent = func.parent + if not parent or parent.type ~= 'setmethod' then + return end - if source == loc then + local node = parent.node + if not node then return end + vm.eachDef(node, callback) +end + +local function ofLocal(loc, callback) + -- 方法中的 self 使用了一个虚拟的定义位置 + if loc.tag == 'self' then + local func = guide.getParentFunction(loc) + ofParentMT(func, callback) + else + callback(loc) + end local refs = loc.ref if refs then for i = 1, #refs do local ref = refs[i] - if ref == source then - break - end - if ref.type == 'getlocal' then - if loc.tag == '_ENV' then - local parent = ref.parent - if parent.type == 'getfield' - or parent.type == 'getindex' then - if guide.getKeyName(parent) == '_G' then - callback(parent, 'declare') - end - end - end - elseif ref.type == 'setlocal' then - callback(ref, 'set') + if vm.isSet(ref) then + callback(ref) end end end - if loc.tag == 'self' then - ofSelf(state, loc, callback) - end end -local function ofGlobal(state, source, callback) - if state[source] then - return - end +local function ofGlobal(source, callback) local key = guide.getKeyName(source) local node = source.node if node.tag == '_ENV' then local uris = files.findGlobals(key) - for i = 1, #uris do - local uri = uris[i] + for _, uri in ipairs(uris) do local ast = files.getAst(uri) local globals = vm.getGlobals(ast.ast) if globals and globals[key] then - for _, info in ipairs(globals[key]) do - state[info.source] = true - if info.mode == 'set' then - callback(info) + for _, src in ipairs(globals[key]) do + if vm.isSet(src) then + callback(src) end end end end else - vm.eachField(node, function (info) - if key == info.key then - state[info.source] = true - if info.mode == 'set' then - callback(info) - end + vm.eachField(node, function (src) + if vm.isSet(src) + and key == guide.getKeyName(src) then + callback(src) end end) end end -local function ofField(state, source, callback) - if state[source] then - return - end +local function ofField(source, callback) local parent = source.parent local key = guide.getKeyName(source) if parent.type == 'tablefield' or parent.type == 'tableindex' then local tbl = parent.parent - vm.eachField(tbl, function (info) - if key == info.key then - state[info.source] = true - if info.mode == 'set' then - callback(info) - end + vm.eachField(tbl, function (src) + if vm.isSet(src) + and key == guide.getKeyName(src) then + callback(src) end end) else local node = parent.node - vm.eachField(node, function (info) - if key == info.key then - state[info.source] = true - if info.mode == 'set' then - callback(info) - end + vm.eachField(node, function (src) + if vm.isSet(src) + and key == guide.getKeyName(src) then + callback(src) end end) end end -local function ofLabel(state, source, callback) - if state[source] then - return - end - state[source] = true - callback(source, 'set') -end - -local function ofGoTo(state, source, callback) - local name = source[1] - local label = guide.getLabel(source, name) - if label then - ofLabel(state, label, callback) - end -end - -local function ofValue(state, source, callback) - callback(source, 'value') -end - -local function ofIndex(state, source, callback) +local function ofLiteral(source, callback) local parent = source.parent if not parent then return @@ -143,263 +93,134 @@ local function ofIndex(state, source, callback) if parent.type == 'setindex' or parent.type == 'getindex' or parent.type == 'tableindex' then - ofField(state, source, callback) + ofField(source, callback) end end -local function ofCall(state, func, index, callback, offset) - offset = offset or 0 - vm.eachRef(func, function (info) - local src = info.source - local returns - if src.type == 'main' or src.type == 'function' then - returns = src.returns - end - if returns then - -- 搜索函数第 index 个返回值 - for i = 1, #returns do - local rtn = returns[i] - local val = rtn[index-offset] - if val then - callback(val) - end - end - end - end) -end - -local function ofSpecialCall(state, call, func, index, callback, offset) - local name = func.special - offset = offset or 0 - if name == 'setmetatable' then - if index == 1 + offset then - local args = call.args - if args[1+offset] then - callback(args[1+offset]) - end - if args[2+offset] then - vm.eachField(args[2+offset], function (info) - if info.key == 's|__index' then - callback(info.source) - end - end) - end - vm.setMeta(args[1+offset], args[2+offset]) - end - elseif name == 'require' then - if index == 1 + offset then - local result = vm.getLinkUris(call) - if result then - local myUri = guide.getRoot(call).uri - for i = 1, #result do - local uri = result[i] - if not files.eq(uri, myUri) then - local ast = files.getAst(uri) - if ast then - ofCall(state, ast.ast, 1, callback) - end - end - end - end - - local args = call.args - if args[1+offset] then - if args[1+offset].type == 'string' then - local objName = args[1+offset][1] - local lib = library.library[objName] - if lib then - callback(lib) - end - end - end - end - elseif name == 'pcall' - or name == 'xpcall' then - if index >= 2-offset then - local args = call.args - if args[1+offset] then - vm.eachRef(args[1+offset], function (info) - local src = info.source - if src.type == 'function' then - ofCall(state, src, index, callback, 1+offset) - ofSpecialCall(state, call, src, index, callback, 1+offset) - end - end) +local function ofLabel(source, callback) + callback(source) + if source.ref then + for _, ref in ipairs(source.ref) do + if ref.type == 'label' then + callback(ref) end end end end -local function ofSelect(state, source, callback) - -- 检查函数返回值 - local call = source.vararg - if call.type == 'call' then - ofCall(state, call.node, source.index, callback) - ofSpecialCall(state, call, call.node, source.index, callback) +local function ofGoTo(source, callback) + local name = source[1] + local label = guide.getLabel(source, name) + if label then + ofLabel(label, callback) end end -local function ofMain(state, source, callback) - callback(source, 'main') +local function ofTableField(source, callback) + local tbl = source.parent + local src = tbl.parent + if not src then + return + end + return vm.eachField(src, callback) end -local function getCallRecvs(call) - local parent = call.parent - if parent.type ~= 'select' then - return nil - end - local extParent = call.extParent - local recvs = {} - recvs[1] = parent.parent - if extParent then - for i = 1, #extParent do - local p = extParent[i] - recvs[#recvs+1] = p.parent +local function findIndex(parent, source) + for i = 1, #parent do + if parent[i] == source then + return i end end - return recvs + return nil end -local function checkValue(state, source, callback) - if source.value then - callback(source.value) - end +local function findCallRecvs(func, index, callback) + vm.eachRef(func, function (info) + local source = info.source + local parent = source.parent + if parent.type ~= 'call' then + return + end + if index == 1 then + local slt = parent.parent + if not slt or slt.type ~= 'select' then + return + end + callback(slt.parent) + else + local slt = parent.extParent and parent.extParent[index-1] + if not slt or slt.type ~= 'select' then + return + end + callback(slt.parent) + end + end) end -function vm.defCheck(state, source, callback) - checkValue(state, source, callback) +local function ofFunction(source, callback) + local parent = source.parent + if not parent then + return + end + if parent.type == 'return' then + local func = guide.getParentFunction(parent) + if not func then + return + end + local index = findIndex(parent, source) + if not index then + return + end + findCallRecvs(func, index, function (src) + vm.eachRef(src, callback) + end) + elseif parent.value == source then + vm.eachRef(parent, callback) + end end -function vm.defOf(state, source, callback) +local function eachDef(source, callback) local stype = source.type - if stype == 'local' then - ofLocal(state, source, source, callback) + if stype == 'local' then + ofLocal(source, callback) elseif stype == 'getlocal' or stype == 'setlocal' then - ofLocal(state, source.node, source, callback) + ofLocal(source.node, callback) elseif stype == 'setglobal' or stype == 'getglobal' then - ofGlobal(state, source, callback) + ofGlobal(source, callback) elseif stype == 'field' or stype == 'method' then - ofField(state, source, callback) + ofField(source, callback) elseif stype == 'setfield' - or stype == 'getfield' - or stype == 'tablefield' then - ofField(state, source.field, callback) + or stype == 'getfield' then + ofField(source.field, callback) elseif stype == 'setmethod' or stype == 'getmethod' then - ofField(state, source.method, callback) - elseif stype == 'goto' then - ofGoTo(state, source, callback) - elseif stype == 'label' then - ofLabel(state, source, callback) + ofField(source.method, callback) + elseif stype == 'tablefield' then + ofTableField(source, callback) elseif stype == 'number' or stype == 'boolean' or stype == 'string' then - ofIndex(state, source, callback) - ofValue(state, source, callback) - elseif stype == 'table' - or stype == 'function' - or stype == 'nil' then - ofValue(state, source, callback) - elseif stype == 'select' then - ofSelect(state, source, callback) - elseif stype == 'call' then - ofCall(state, source.node, 1, callback) - ofSpecialCall(state, source, source.node, 1, callback) - elseif stype == 'main' then - ofMain(state, source, callback) - elseif stype == 'paren' then - vm.defOf(state, source.exp, callback) - end -end - -local function eachDef(source, result) - local list = { source } - local mark = {} - local state = {} - local hasOf = {} - local hasCheck = {} - local function found(src, mode) - local info - if src.mode then - info = src - src = info.source - end - if mark[src] == nil then - list[#list+1] = src - end - if info then - mark[src] = info - elseif mode then - mark[src] = { - source = src, - mode = mode, - } - else - mark[src] = mark[src] or false - end - end - for _ = 1, 10000 do - if _ == 10000 then - error('stack overflow!') - end - local max = #list - if max == 0 then - break - end - local src = list[max] - list[max] = nil - if not hasOf[src] then - hasOf[src] = true - vm.defOf(state, src, found) - end - if not hasCheck[src] then - hasCheck[src] = true - vm.defCheck(state, src, found) - end - end - for _, info in pairs(mark) do - if info then - result[#result+1] = info - end - end - return result -end - -local function applyCache(cache, callback, max) - await.delay(function () - return files.globalVersion - end) - if max then - if max > #cache then - max = #cache - end - else - max = #cache - end - for i = 1, max do - local res = callback(cache[i]) - if res ~= nil then - return res - end + ofLiteral(source, callback) + elseif stype == 'function' then + ofFunction(source, callback) + elseif stype == 'goto' then + ofGoTo(source, callback) + elseif stype == 'label' then + ofLabel(source, callback) end end --- 获取所有的定义 function vm.eachDef(source, callback, max) - local cache = vm.cache.eachDef[source] - if cache then - return applyCache(cache, callback, max) - end - local unlock = vm.lock('eachDef', source) - if not unlock then - return - end - cache = {} - vm.cache.eachDef[source] = cache - eachDef(source, cache) - unlock() - return applyCache(cache, callback, max) + local mark = {} + eachDef(source, function (src) + if mark[src] then + return + end + mark[src] = true + callback(src) + end) end diff --git a/script-beta/vm/eachField.lua b/script-beta/vm/eachField.lua index 5148e153..c418c9da 100644 --- a/script-beta/vm/eachField.lua +++ b/script-beta/vm/eachField.lua @@ -8,14 +8,13 @@ local function checkNext(source) return nil end local ntype = nextSrc.type - if ntype == 'setfield' - or ntype == 'setmethod' - or ntype == 'setindex' then - return nextSrc, 'set' - elseif ntype == 'getfield' - or ntype == 'getmethod' - or ntype == 'getindex' then - return nextSrc, 'get' + if ntype == 'setfield' + or ntype == 'setmethod' + or ntype == 'setindex' + or ntype == 'getfield' + or ntype == 'getmethod' + or ntype == 'getindex' then + return nextSrc end return nil end @@ -31,12 +30,7 @@ local function findFieldInTable(value, callback) local field = value[i] if field.type == 'tablefield' or field.type == 'tableindex' then - callback { - source = field, - key = guide.getKeyName(field), - value = field.value, - mode = 'set', - } + callback(field) end end end @@ -49,17 +43,9 @@ local function ofENV(source, callback) for i = 1, #refs do local ref = refs[i] if ref.type == 'getglobal' then - callback { - source = ref, - key = guide.getKeyName(ref), - mode = 'get', - } + callback(ref) elseif ref.type == 'setglobal' then - callback { - source = ref, - key = guide.getKeyName(ref), - mode = 'set', - } + callback(ref) end findFieldInTable(ref.value, callback) end @@ -69,34 +55,24 @@ local function ofLocal(source, callback) if source.tag == '_ENV' then ofENV(source, callback) else - vm.eachRef(source, function (info) - local src = info.source + vm.eachRef(source, function (src) findFieldInTable(src.value, callback) local nextSrc, mode = checkNext(src) if not nextSrc then return end - callback { - source = nextSrc, - key = guide.getKeyName(nextSrc), - mode = mode, - } + callback(nextSrc) end) end end local function ofGlobal(source, callback) - vm.eachRef(source, function (info) - local src = info.source - local nextSrc, mode = checkNext(src) + vm.eachRef(source, function (src) + local nextSrc = checkNext(src) if not nextSrc then return end - callback { - source = nextSrc, - key = guide.getKeyName(nextSrc), - mode = mode, - } + callback(nextSrc) findFieldInTable(src.value, callback) end) end diff --git a/script-beta/vm/eachRef.lua b/script-beta/vm/eachRef.lua index d5792ce2..fe38924e 100644 --- a/script-beta/vm/eachRef.lua +++ b/script-beta/vm/eachRef.lua @@ -5,25 +5,16 @@ local vm = require 'vm.vm' local function ofLocal(loc, callback) -- 方法中的 self 使用了一个虚拟的定义位置 if loc.tag ~= 'self' then - callback { - source = loc, - mode = 'declare', - } + callback(loc) end local refs = loc.ref if refs then for i = 1, #refs do local ref = refs[i] if ref.type == 'getlocal' then - callback { - source = ref, - mode = 'get', - } + callback(ref) elseif ref.type == 'setlocal' then - callback { - source = ref, - mode = 'set', - } + callback(ref) end end end @@ -44,12 +35,9 @@ local function ofGlobal(source, callback) end end else - vm.eachField(node, function (info) - if key == info.key then - callback { - source = info.source, - mode = info.mode, - } + vm.eachField(node, function (src) + if key == guide.getKeyName(src) then + callback(src) end end) end @@ -61,22 +49,16 @@ local function ofField(source, callback) if parent.type == 'tablefield' or parent.type == 'tableindex' then local tbl = parent.parent - vm.eachField(tbl, function (info) - if key == info.key then - callback { - source = info.source, - mode = info.mode, - } + vm.eachField(tbl, function (src) + if key == guide.getKeyName(src) then + callback(src) end end) else local node = parent.node - vm.eachField(node, function (info) - if key == info.key then - callback { - source = info.source, - mode = info.mode, - } + vm.eachField(node, function (src) + if key == guide.getKeyName(src) then + callback(src) end end) end @@ -95,16 +77,10 @@ local function ofLiteral(source, callback) end local function ofLabel(source, callback) - callback { - source = source, - mode = 'set', - } + callback(source) if source.ref then for _, ref in ipairs(source.ref) do - callback { - source = ref, - mode = 'get', - } + callback(ref) end end end @@ -136,8 +112,7 @@ local function findIndex(parent, source) end local function findCallRecvs(func, index, callback) - vm.eachRef(func, function (info) - local source = info.source + vm.eachRef(func, function (source) local parent = source.parent if parent.type ~= 'call' then return @@ -217,12 +192,11 @@ end --- 获取所有的引用 function vm.eachRef(source, callback, max) local mark = {} - eachRef(source, function (info) - local src = info.source + eachRef(source, function (src) if mark[src] then return end mark[src] = true - callback(info) + callback(src) end) end diff --git a/script-beta/vm/getGlobals.lua b/script-beta/vm/getGlobals.lua index 116bf8d5..b64d63ff 100644 --- a/script-beta/vm/getGlobals.lua +++ b/script-beta/vm/getGlobals.lua @@ -8,24 +8,21 @@ local function getGlobals(root) end local cache = {} local mark = {} - vm.eachField(env, function (info) - local src = info.source + vm.eachField(env, function (src) if mark[src] then return end mark[src] = true - local name = info.key + local name = guide.getKeyName(src) if not name then return end if not cache[name] then cache[name] = { key = name, - mode = {}, } end - cache[name][#cache[name]+1] = info - cache[name].mode[info.mode] = true + cache[name][#cache[name]+1] = src vm.cache.getGlobal[src] = name end) return cache diff --git a/script-beta/vm/vm.lua b/script-beta/vm/vm.lua index 06ffc172..863686d8 100644 --- a/script-beta/vm/vm.lua +++ b/script-beta/vm/vm.lua @@ -59,6 +59,18 @@ function m.getLinkUris(call) end end +function m.isSet(src) + local tp = src.type + return tp == 'setglobal' + or tp == 'local' + or tp == 'setlocal' + or tp == 'setfield' + or tp == 'setmethod' + or tp == 'setindex' + or tp == 'tablefield' + or tp == 'tableindex' +end + m.cacheTracker = setmetatable({}, { __mode = 'kv' }) --- 刷新缓存 diff --git a/test-beta/definition/call.lua b/test-beta/definition/call.lua index 15364396..d18591ff 100644 --- a/test-beta/definition/call.lua +++ b/test-beta/definition/call.lua @@ -1,6 +1,6 @@ TEST [[ function f() - local <!x!> + local x return x end local <!y!> = f() diff --git a/test-beta/definition/method.lua b/test-beta/definition/method.lua index aa7aacdc..3ef3055c 100644 --- a/test-beta/definition/method.lua +++ b/test-beta/definition/method.lua @@ -6,26 +6,26 @@ function mt:b() end ]] -TEST [[ -function mt:<!m1!>() -end -function mt:m2() - self:<?m1?>() -end -]] - -TEST [[ -function mt:m3() - mt:<?m4?>() -end -function mt:<!m4!>() -end -]] - -TEST [[ -function mt:m3() - self:<?m4?>() -end -function mt:<!m4!>() -end -]] +--TEST [[ +--function mt:<!m1!>() +--end +--function mt:m2() +-- self:<?m1?>() +--end +--]] +-- +--TEST [[ +--function mt:m3() +-- mt:<?m4?>() +--end +--function mt:<!m4!>() +--end +--]] +-- +--TEST [[ +--function mt:m3() +-- self:<?m4?>() +--end +--function mt:<!m4!>() +--end +--]] |