diff options
author | 最萌小汐 <sumneko@hotmail.com> | 2019-12-17 14:59:39 +0800 |
---|---|---|
committer | 最萌小汐 <sumneko@hotmail.com> | 2019-12-17 14:59:39 +0800 |
commit | 5e3b44234899bf57f2688d48ef617357ef706669 (patch) | |
tree | aa41e119430a773ab5c300a3e3387edd8a3dcc8d | |
parent | 41d83a00788d2232ef6781ee337acc3d0c380bf2 (diff) | |
download | lua-language-server-5e3b44234899bf57f2688d48ef617357ef706669.zip |
eachDef
-rw-r--r-- | script-beta/vm/eachDef.lua | 401 | ||||
-rw-r--r-- | script-beta/vm/eachRef.lua | 9 | ||||
-rw-r--r-- | test-beta/definition/call.lua | 2 |
3 files changed, 369 insertions, 43 deletions
diff --git a/script-beta/vm/eachDef.lua b/script-beta/vm/eachDef.lua index 66420d5b..f2ddc73b 100644 --- a/script-beta/vm/eachDef.lua +++ b/script-beta/vm/eachDef.lua @@ -4,58 +4,367 @@ local vm = require 'vm.vm' local library = require 'library' local await = require 'await' -local function ofLocal(declare, source, callback) - +local function ofSelf(state, loc, callback) + -- self 的2个特殊引用位置: + -- 1. 当前方法定义时的对象(mt) + local method = loc.method + local node = method.node + callback(node) + -- 2. 调用该方法时传入的对象 end -local function eachDef(source, callback) - local stype = source.type - if stype == 'local' then - ofLocal(source, source, callback) - elseif stype == 'getlocal' - or stype == 'setlocal' then - ofLocal(source.node, source, callback) +local function ofLocal(state, loc, source, callback) + if state[loc] then + return + end + state[loc] = true + -- 方法中的 self 使用了一个虚拟的定义位置 + if loc.tag ~= 'self' then + callback(loc, 'declare') + end + local refs = loc.ref + if refs then + for i = 1, #refs do + local ref = refs[i] + if ref == source then + break + end + if ref.type == 'getlocal' then + if loc.tag == '_ENV' then + local parent = ref.parent + if parent.type == 'getfield' + or parent.type == 'getindex' then + if guide.getKeyName(parent) == '_G' then + callback(parent, 'declare') + end + end + end + elseif ref.type == 'setlocal' then + callback(ref, 'set') + end + end + end + if loc.tag == 'self' then + ofSelf(state, loc, callback) end end ---- 获取所有的引用 -function vm.eachDef(source, callback, max) - local cache = vm.cache.eachDef[source] - if cache then - await.delay(function () - return files.globalVersion - end) - if max then - if max > #cache then - max = #cache +local function ofGlobal(state, source, callback) + if state[source] then + return + end + local key = guide.getKeyName(source) + local node = source.node + if node.tag == '_ENV' then + local uris = files.findGlobals(key) + for i = 1, #uris do + local uri = uris[i] + local ast = files.getAst(uri) + local globals = vm.getGlobals(ast.ast) + if globals and globals[key] then + for _, info in ipairs(globals[key]) do + state[info.source] = true + if info.mode == 'set' then + callback(info) + end + end end - else - max = #cache end - for i = 1, max do - local res = callback(cache[i]) - if res ~= nil then - return res + else + vm.eachField(node, function (info) + if key == info.key then + state[info.source] = true + if info.mode == 'set' then + callback(info) + end end - end + end) + end +end + +local function ofField(state, source, callback) + if state[source] then return end - local unlock = vm.lock('eachDef', source) - if not unlock then + local parent = source.parent + local key = guide.getKeyName(source) + if parent.type == 'tablefield' + or parent.type == 'tableindex' then + local tbl = parent.parent + vm.eachField(tbl, function (info) + if key == info.key then + state[info.source] = true + if info.mode == 'set' then + callback(info) + end + end + end) + else + local node = parent.node + vm.eachField(node, function (info) + if key == info.key then + state[info.source] = true + if info.mode == 'set' then + callback(info) + end + end + end) + end +end + +local function ofLabel(state, source, callback) + if state[source] then return end - cache = {} - vm.cache.eachDef[source] = cache - local mark = {} - eachDef(source, function (info) + state[source] = true + callback(source, 'set') +end + +local function ofGoTo(state, source, callback) + local name = source[1] + local label = guide.getLabel(source, name) + if label then + ofLabel(state, label, callback) + end +end + +local function ofValue(state, source, callback) + callback(source, 'value') +end + +local function ofIndex(state, source, callback) + local parent = source.parent + if not parent then + return + end + if parent.type == 'setindex' + or parent.type == 'getindex' + or parent.type == 'tableindex' then + ofField(state, source, callback) + end +end + +local function ofCall(state, func, index, callback, offset) + offset = offset or 0 + vm.eachRef(func, function (info) local src = info.source - if mark[src] then - return + local returns + if src.type == 'main' or src.type == 'function' then + returns = src.returns + end + if returns then + -- 搜索函数第 index 个返回值 + for i = 1, #returns do + local rtn = returns[i] + local val = rtn[index-offset] + if val then + callback(val) + end + end end - mark[src] = true - cache[#cache+1] = info end) - unlock() +end + +local function ofSpecialCall(state, call, func, index, callback, offset) + local name = func.special + offset = offset or 0 + if name == 'setmetatable' then + if index == 1 + offset then + local args = call.args + if args[1+offset] then + callback(args[1+offset]) + end + if args[2+offset] then + vm.eachField(args[2+offset], function (info) + if info.key == 's|__index' then + callback(info.source) + end + end) + end + vm.setMeta(args[1+offset], args[2+offset]) + end + elseif name == 'require' then + if index == 1 + offset then + local result = vm.getLinkUris(call) + if result then + local myUri = guide.getRoot(call).uri + for i = 1, #result do + local uri = result[i] + if not files.eq(uri, myUri) then + local ast = files.getAst(uri) + if ast then + ofCall(state, ast.ast, 1, callback) + end + end + end + end + + local args = call.args + if args[1+offset] then + if args[1+offset].type == 'string' then + local objName = args[1+offset][1] + local lib = library.library[objName] + if lib then + callback(lib) + end + end + end + end + elseif name == 'pcall' + or name == 'xpcall' then + if index >= 2-offset then + local args = call.args + if args[1+offset] then + vm.eachRef(args[1+offset], function (info) + local src = info.source + if src.type == 'function' then + ofCall(state, src, index, callback, 1+offset) + ofSpecialCall(state, call, src, index, callback, 1+offset) + end + end) + end + end + end +end + +local function ofSelect(state, source, callback) + -- 检查函数返回值 + local call = source.vararg + if call.type == 'call' then + ofCall(state, call.node, source.index, callback) + ofSpecialCall(state, call, call.node, source.index, callback) + end +end + +local function ofMain(state, source, callback) + callback(source, 'main') +end + +local function getCallRecvs(call) + local parent = call.parent + if parent.type ~= 'select' then + return nil + end + local extParent = call.extParent + local recvs = {} + recvs[1] = parent.parent + if extParent then + for i = 1, #extParent do + local p = extParent[i] + recvs[#recvs+1] = p.parent + end + end + return recvs +end + +local function checkValue(state, source, callback) + if source.value then + callback(source.value) + end +end + +function vm.defCheck(state, source, callback) + checkValue(state, source, callback) +end + +function vm.defOf(state, source, callback) + local stype = source.type + if stype == 'local' then + ofLocal(state, source, source, callback) + elseif stype == 'getlocal' + or stype == 'setlocal' then + ofLocal(state, source.node, source, callback) + elseif stype == 'setglobal' + or stype == 'getglobal' then + ofGlobal(state, source, callback) + elseif stype == 'field' + or stype == 'method' then + ofField(state, source, callback) + elseif stype == 'setfield' + or stype == 'getfield' + or stype == 'tablefield' then + ofField(state, source.field, callback) + elseif stype == 'setmethod' + or stype == 'getmethod' then + ofField(state, source.method, callback) + elseif stype == 'goto' then + ofGoTo(state, source, callback) + elseif stype == 'label' then + ofLabel(state, source, callback) + elseif stype == 'number' + or stype == 'boolean' + or stype == 'string' then + ofIndex(state, source, callback) + ofValue(state, source, callback) + elseif stype == 'table' + or stype == 'function' then + ofValue(state, source, callback) + elseif stype == 'select' then + ofSelect(state, source, callback) + elseif stype == 'call' then + ofCall(state, source.node, 1, callback) + ofSpecialCall(state, source, source.node, 1, callback) + elseif stype == 'main' then + ofMain(state, source, callback) + elseif stype == 'paren' then + vm.defOf(state, source.exp, callback) + end +end + +local function eachDef(source, result) + local list = { source } + local mark = {} + local state = {} + local hasOf = {} + local hasCheck = {} + local function found(src, mode) + local info + if src.mode then + info = src + src = info.source + end + if mark[src] == nil then + list[#list+1] = src + end + if info then + mark[src] = info + elseif mode then + mark[src] = { + source = src, + mode = mode, + } + else + mark[src] = mark[src] or false + end + end + for _ = 1, 10000 do + if _ == 10000 then + error('stack overflow!') + end + local max = #list + if max == 0 then + break + end + local src = list[max] + list[max] = nil + if not hasOf[src] then + hasOf[src] = true + vm.defOf(state, src, found) + end + if not hasCheck[src] then + hasCheck[src] = true + vm.defCheck(state, src, found) + end + end + for _, info in pairs(mark) do + if info then + result[#result+1] = info + end + end + return result +end + +local function applyCache(cache, callback, max) await.delay(function () return files.globalVersion end) @@ -73,3 +382,21 @@ function vm.eachDef(source, callback, max) end end end + +--- 获取所有的引用 +function vm.eachDef(source, callback, max) + local cache = vm.cache.eachDef[source] + if cache then + applyCache(cache, callback, max) + return + end + local unlock = vm.lock('eachDef', source) + if not unlock then + return + end + cache = {} + vm.cache.eachDef[source] = cache + eachDef(source, cache) + unlock() + applyCache(cache, callback, max) +end diff --git a/script-beta/vm/eachRef.lua b/script-beta/vm/eachRef.lua index 456e59e1..94952568 100644 --- a/script-beta/vm/eachRef.lua +++ b/script-beta/vm/eachRef.lua @@ -3,10 +3,6 @@ local files = require 'files' local vm = require 'vm.vm' local library = require 'library' local await = require 'await' -local vm = require 'vm.vm' -local guide = require 'parser.guide' -local files = require 'files' -local library = require 'library' local function ofSelf(state, loc, callback) -- self 的2个特殊引用位置: @@ -466,7 +462,10 @@ local function eachRef(source, result) mark[src] = mark[src] or false end end - for _ = 1, 1000 do + for _ = 1, 10000 do + if _ == 10000 then + error('stack overflow!') + end local max = #list if max == 0 then break diff --git a/test-beta/definition/call.lua b/test-beta/definition/call.lua index 42502f40..15364396 100644 --- a/test-beta/definition/call.lua +++ b/test-beta/definition/call.lua @@ -1,7 +1,7 @@ TEST [[ function f() local <!x!> - return <!x!> + return x end local <!y!> = f() print(<?y?>) |