diff options
author | 最萌小汐 <sumneko@hotmail.com> | 2019-12-16 21:31:58 +0800 |
---|---|---|
committer | 最萌小汐 <sumneko@hotmail.com> | 2019-12-16 21:31:58 +0800 |
commit | 70ae6b13b620148e954dbdf9b5563fe04d4e52b2 (patch) | |
tree | 08c13b758b2e8aaee3a32f7a61fc7ca7103f5338 | |
parent | d65bc946566756b7627898118e467ae8bd234341 (diff) | |
download | lua-language-server-70ae6b13b620148e954dbdf9b5563fe04d4e52b2.zip |
整理代码
-rw-r--r-- | script-beta/core/rename.lua | 4 | ||||
-rw-r--r-- | script-beta/vm/eachDef.lua | 147 | ||||
-rw-r--r-- | script-beta/vm/eachRef.lua | 87 | ||||
-rw-r--r-- | script-beta/vm/init.lua | 1 | ||||
-rw-r--r-- | script-beta/vm/refOf.lua | 519 | ||||
-rw-r--r-- | script-beta/vm/vm.lua | 1 |
6 files changed, 627 insertions, 132 deletions
diff --git a/script-beta/core/rename.lua b/script-beta/core/rename.lua index 581fd19a..4beb3bbb 100644 --- a/script-beta/core/rename.lua +++ b/script-beta/core/rename.lua @@ -11,8 +11,8 @@ local function askForcing(str) if TEST then return true end - if Forcing == false then - return false + if Forcing ~= nil then + return Forcing end local version = files.globalVersion -- TODO diff --git a/script-beta/vm/eachDef.lua b/script-beta/vm/eachDef.lua index cf865384..66420d5b 100644 --- a/script-beta/vm/eachDef.lua +++ b/script-beta/vm/eachDef.lua @@ -1,102 +1,73 @@ -local vm = require 'vm.vm' -local guide = require 'parser.guide' -local files = require 'files' +local guide = require 'parser.guide' +local files = require 'files' +local vm = require 'vm.vm' +local library = require 'library' +local await = require 'await' -local function checkPath(source, info) - if source.type == 'goto' then - return true - end - local src = info.source - local mode = guide.getPath(source, src) - if not mode then - return true - end - if mode == 'before' then - return false - end - if mode == 'equal' then - if src.type == 'field' - or src.type == 'method' - or src.type == 'local' - or src.type == 'setglobal' then - return true - else - return false - end +local function ofLocal(declare, source, callback) + +end + +local function eachDef(source, callback) + local stype = source.type + if stype == 'local' then + ofLocal(source, source, callback) + elseif stype == 'getlocal' + or stype == 'setlocal' then + ofLocal(source.node, source, callback) end - return true end --- TODO --- 只搜索本文件中的引用 --- 跨文件时,选确定入口(main的return),然后递归搜索本文件中的引用 --- 如果类型为setfield等,要确定tbl相同 -function vm.eachDef(source, callback) - local results = {} - local returns = {} - local infoMap = {} - local sourceUri = guide.getRoot(source).uri - vm.eachRef(source, function (info) - if info.mode == 'declare' - or info.mode == 'set' then - results[#results+1] = info - end - if info.mode == 'return' then - results[#results+1] = info - local root = guide.getParentBlock(info.source) - if root.type == 'main' then - returns[root.uri] = info +--- 获取所有的引用 +function vm.eachDef(source, callback, max) + local cache = vm.cache.eachDef[source] + if cache then + await.delay(function () + return files.globalVersion + end) + if max then + if max > #cache then + max = #cache end + else + max = #cache end - infoMap[info.source] = info - end) - - local function pushDef(info) - local res = callback(info) - if res ~= nil then - return res - end - local value = info.source.value - local vinfo = infoMap[value] - if vinfo then - res = callback(vinfo) + for i = 1, max do + local res = callback(cache[i]) + if res ~= nil then + return res + end end - return res + return end - - local res - local used = {} - for _, info in ipairs(results) do + local unlock = vm.lock('eachDef', source) + if not unlock then + return + end + cache = {} + vm.cache.eachDef[source] = cache + local mark = {} + eachDef(source, function (info) local src = info.source - local destUri - if used[src] then - goto CONTINUE + if mark[src] then + return end - used[src] = true - destUri = guide.getRoot(src).uri - -- 如果是同一个文件,则检查位置关系后放行 - if sourceUri == destUri then - if checkPath(source, info) then - res = pushDef(info) - end - goto CONTINUE - end - -- 如果是global或field,则直接放行(因为无法确定顺序) - if src.type == 'setindex' - or src.type == 'setfield' - or src.type == 'setmethod' - or src.type == 'tablefield' - or src.type == 'tableindex' - or src.type == 'setglobal' then - res = pushDef(info) - goto CONTINUE - end - -- 如果不是同一个文件,则必须在该文件 return 后才放行 - if returns[destUri] then - res = pushDef(info) - goto CONTINUE + mark[src] = true + cache[#cache+1] = info + end) + unlock() + await.delay(function () + return files.globalVersion + end) + if max then + if max > #cache then + max = #cache end - ::CONTINUE:: + else + max = #cache + end + for i = 1, max do + local res = callback(cache[i]) if res ~= nil then return res end diff --git a/script-beta/vm/eachRef.lua b/script-beta/vm/eachRef.lua index 81dedcbe..c629f667 100644 --- a/script-beta/vm/eachRef.lua +++ b/script-beta/vm/eachRef.lua @@ -524,48 +524,7 @@ function vm.isSameRef(a, b) end end ---- 获取所有的引用 -function vm.eachRef(source, callback, max) - local cache = vm.cache.eachRef[source] - if cache then - await.delay(function () - return files.globalVersion - end) - if max then - if max > #cache then - max = #cache - end - else - max = #cache - end - for i = 1, max do - local res = callback(cache[i]) - if res ~= nil then - return res - end - end - return - end - local unlock = vm.lock('eachRef', source) - if not unlock then - return - end - cache = {} - vm.cache.eachRef[source] = cache - local mark = {} - eachRef(source, function (info) - local src = info.source - if mark[src] then - return - end - mark[src] = true - cache[#cache+1] = info - end) - unlock() - for i = 1, #cache do - local src = cache[i].source - vm.cache.eachRef[src] = cache - end +local function applyCache(cache, callback, max) await.delay(function () return files.globalVersion end) @@ -583,3 +542,47 @@ function vm.eachRef(source, callback, max) end end end + +local function eachRef(source, callback) + local list = { source } + local mark = {} + local result = {} + local state = {} + local function found(info) + local src = info.source + if not mark[src] then + list[#list+1] = src + end + mark[src] = info + end + while #list > 0 do + local max = #list + local src = list[max] + list[max] = nil + vm.refOf(state, src, found) + end + for _, info in pairs(mark) do + result[#result+1] = info + end + return result +end + +--- 获取所有的引用 +function vm.eachRef(source, callback, max) + local cache = vm.cache.eachRef[source] + if cache then + applyCache(cache, callback, max) + return + end + local unlock = vm.lock('eachRef', source) + if not unlock then + return + end + cache = eachRef(source, callback) + unlock() + for i = 1, #cache do + local src = cache[i].source + vm.cache.eachRef[src] = cache + end + applyCache(cache, callback, max) +end diff --git a/script-beta/vm/init.lua b/script-beta/vm/init.lua index a4f81d07..3645e77a 100644 --- a/script-beta/vm/init.lua +++ b/script-beta/vm/init.lua @@ -1,5 +1,6 @@ local vm = require 'vm.vm' require 'vm.eachField' +require 'vm.refOf' require 'vm.eachRef' require 'vm.eachDef' require 'vm.getGlobals' diff --git a/script-beta/vm/refOf.lua b/script-beta/vm/refOf.lua new file mode 100644 index 00000000..96daf646 --- /dev/null +++ b/script-beta/vm/refOf.lua @@ -0,0 +1,519 @@ +local vm = require 'vm.vm' +local guide = require 'parser.guide' +local files = require 'files' +local library = require 'library' + +local function ofLocal(state, loc, callback) + if state[loc] then + return + end + state[loc] = true + -- 方法中的 self 使用了一个虚拟的定义位置 + if loc.tag ~= 'self' then + callback { + source = loc, + mode = 'declare', + } + end + local refs = loc.ref + if refs then + for i = 1, #refs do + local ref = refs[i] + if ref.type == 'getlocal' then + callback { + source = ref, + mode = 'get', + } + if loc.tag == '_ENV' then + local parent = ref.parent + if parent.type == 'getfield' + or parent.type == 'getindex' then + if guide.getKeyName(parent) == '_G' then + callback { + source = parent, + mode = 'get', + } + end + end + end + elseif ref.type == 'setlocal' then + callback { + source = ref, + mode = 'set', + } + if loc.tag == '_ENV' then + if guide.getName(ref) == '_G' then + callback { + source = ref, + mode = 'get', + } + end + end + end + end + end +end + +local function ofGlobal(state, source, callback) + if state[source] then + return + end + local key = guide.getKeyName(source) + local node = source.node + if node.tag == '_ENV' then + local uris = files.findGlobals(key) + for i = 1, #uris do + local uri = uris[i] + local ast = files.getAst(uri) + local globals = vm.getGlobals(ast.ast) + if globals and globals[key] then + for _, info in ipairs(globals[key]) do + state[info.source] = true + callback(info) + end + end + end + else + vm.eachField(node, function (info) + if key == info.key then + state[info.source] = true + callback { + source = info.source, + mode = info.mode, + } + end + end) + end +end + +local function ofField(state, source, callback) + if state[source] then + return + end + local parent = source.parent + local key = guide.getKeyName(source) + if parent.type == 'tablefield' + or parent.type == 'tableindex' then + local tbl = parent.parent + vm.eachField(tbl, function (info) + if key == info.key then + state[info.source] = true + callback { + source = info.source, + mode = info.mode, + } + end + end) + else + local node = parent.node + vm.eachField(node, function (info) + if key == info.key then + state[info.source] = true + callback { + source = info.source, + mode = info.mode, + } + end + end) + end +end + +local function ofLabel(state, source, callback) + if state[source] then + return + end + state[source] = true + callback { + source = source, + mode = 'set', + } + if source.ref then + for _, ref in ipairs(source.ref) do + callback { + source = ref, + mode = 'get', + } + end + end +end + +local function ofGoTo(state, source, callback) + local name = source[1] + local label = guide.getLabel(source, name) + if label then + ofLabel(state, label, callback) + end +end + +local function ofValue(state, source, callback) + callback { + source = source, + mode = 'value', + } +end + +local function ofIndex(state, source, callback) + local parent = source.parent + if not parent then + return + end + if parent.type == 'setindex' + or parent.type == 'getindex' + or parent.type == 'tableindex' then + ofField(state, source, callback) + end +end + +local function ofCall(state, func, index, callback, offset) + offset = offset or 0 + vm.eachRef(func, function (info) + local src = info.source + local returns + if src.type == 'main' or src.type == 'function' then + returns = src.returns + end + if returns then + -- 搜索函数第 index 个返回值 + for i = 1, #returns do + local rtn = returns[i] + local val = rtn[index-offset] + if val then + callback { + source = val, + mode = 'return', + } + end + end + end + end) +end + +local function ofSpecialCall(state, call, func, index, callback, offset) + local name = func.special + offset = offset or 0 + if name == 'setmetatable' then + if index == 1 + offset then + local args = call.args + if args[1+offset] then + callback { + source = args[1+offset], + mode = 'get', + } + end + if args[2+offset] then + vm.eachField(args[2+offset], function (info) + if info.key == 's|__index' then + callback(info) + end + end) + end + vm.setMeta(args[1+offset], args[2+offset]) + end + elseif name == 'require' then + if index == 1 + offset then + local result = vm.getLinkUris(call) + if result then + local myUri = guide.getRoot(call).uri + for i = 1, #result do + local uri = result[i] + if not files.eq(uri, myUri) then + local ast = files.getAst(uri) + if ast then + ofCall(state, ast.ast, 1, callback) + end + end + end + end + + local args = call.args + if args[1+offset] then + if args[1+offset].type == 'string' then + local objName = args[1+offset][1] + local lib = library.library[objName] + if lib then + callback { + source = lib, + mode = 'value', + } + end + end + end + end + elseif name == 'pcall' + or name == 'xpcall' then + if index >= 2-offset then + local args = call.args + if args[1+offset] then + vm.eachRef(args[1+offset], function (info) + local src = info.source + if src.type == 'function' then + ofCall(state, src, index, callback, 1+offset) + ofSpecialCall(state, call, src, index, callback, 1+offset) + end + end) + end + end + end +end + +local function ofSelect(state, source, callback) + -- 检查函数返回值 + local call = source.vararg + if call.type == 'call' then + ofCall(state, call.node, source.index, callback) + ofSpecialCall(state, call, call.node, source.index, callback) + end +end + +local function ofMain(state, source, callback) + callback { + source = source, + mode = 'main', + } +end + +local function getCallRecvs(call) + local parent = call.parent + if parent.type ~= 'select' then + return nil + end + local extParent = call.extParent + local recvs = {} + recvs[1] = parent.parent + if extParent then + for i = 1, #extParent do + local p = extParent[i] + recvs[#recvs+1] = p.parent + end + end + return recvs +end + +--- 自己作为函数的参数 +local function checkAsArg(state, source, callback) + local parent = source.parent + if not parent then + return + end + if parent.type == 'callargs' then + local call = parent.parent + local func = call.node + local name = func.special + if name == 'setmetatable' then + if parent[1] == source then + if parent[2] then + vm.eachField(parent[2], function (info) + if info.key == 's|__index' then + callback(info) + end + end) + end + local recvs = getCallRecvs(call) + if recvs and recvs[1] then + callback { + source = recvs[1], + mode = 'return', + } + end + vm.setMeta(source, parent[2]) + end + end + end +end + +local function ofCallSelect(state, call, index, callback) + local slc = call.parent + if slc.index == index then + callback { + source = slc.parent, + mode = 'get', + } + return + end + if call.extParent then + for i = 1, #call.extParent do + slc = call.extParent[i] + if slc.index == index then + callback { + source = slc.parent, + mode = 'get', + } + return + end + end + end +end + +--- 自己作为函数的返回值 +local function checkAsReturn(state, source, callback) + local parent = source.parent + if source.type == 'field' + or source.type == 'method' then + parent = parent.parent + end + if not parent or parent.type ~= 'return' then + return + end + local func = guide.getParentFunction(source) + if func.type == 'main' then + local myUri = func.uri + local uris = files.findLinkTo(myUri) + if not uris then + return + end + for i = 1, #uris do + local uri = uris[i] + local ast = files.getAst(uri) + if ast then + local links = vm.getLinks(ast.ast) + if links then + for linkUri, calls in pairs(links) do + if files.eq(linkUri, myUri) then + for j = 1, #calls do + ofCallSelect(state, calls[j], 1, callback) + end + end + end + end + end + end + else + local index + for i = 1, #parent do + if parent[i] == source then + index = i + break + end + end + if not index then + return + end + vm.eachRef(func, function (info) + local src = info.source + local call = src.parent + if not call or call.type ~= 'call' then + return + end + local recvs = getCallRecvs(call) + if recvs and recvs[index] then + callback { + source = recvs[index], + mode = 'return', + } + elseif index == 1 then + callback { + type = 'call', + source = call, + } + end + end) + end +end + +local function checkAsParen(state, source, callback) + if state[source] then + return + end + state[source] = true + if source.parent and source.parent.type == 'paren' then + vm.refOf(state, source.parent, callback) + end +end + +local function checkValue(state, source, callback) + if source.value then + callback { + source = source.value, + mode = 'value', + } + end +end + +local function checkSetValue(value, callback) + if value.type == 'field' + or value.type == 'method' then + value = value.parent + end + local parent = value.parent + if not parent then + return + end + if parent.type == 'local' + or parent.type == 'setglobal' + or parent.type == 'setlocal' + or parent.type == 'setfield' + or parent.type == 'setmethod' + or parent.type == 'setindex' + or parent.type == 'tablefield' + or parent.type == 'tableindex' then + if parent.value == value then + callback { + source = parent, + mode = 'set', + } + if guide.getName(parent) == '__index' then + if parent.type == 'tablefield' + or parent.type == 'tableindex' then + local t = parent.parent + local args = t.parent + if args[2] == t then + local call = args.parent + local func = call.node + if func.special == 'setmetatable' then + callback { + source = args[1], + mode = 'get', + } + end + end + end + end + end + end +end + +function vm.refOf(state, source, callback) + local stype = source.type + if stype == 'local' then + ofLocal(state, source, callback) + elseif stype == 'getlocal' + or stype == 'setlocal' then + ofLocal(state, source.node, callback) + elseif stype == 'setglobal' + or stype == 'getglobal' then + ofGlobal(state, source, callback) + elseif stype == 'field' + or stype == 'method' then + ofField(state, source, callback) + elseif stype == 'setfield' + or stype == 'getfield' + or stype == 'tablefield' then + ofField(state, source.field, callback) + elseif stype == 'setmethod' + or stype == 'getmethod' then + ofField(state, source.method, callback) + elseif stype == 'goto' then + ofGoTo(state, source, callback) + elseif stype == 'label' then + ofLabel(state, source, callback) + elseif stype == 'number' + or stype == 'boolean' + or stype == 'string' then + ofIndex(state, source, callback) + ofValue(state, source, callback) + elseif stype == 'table' + or stype == 'function' then + ofValue(state, source, callback) + elseif stype == 'select' then + ofSelect(state, source, callback) + elseif stype == 'main' then + ofMain(state, source, callback) + elseif stype == 'paren' then + vm.refOf(state, source.exp, callback) + end + checkValue(state, source, callback) + checkAsArg(state, source, callback) + checkAsReturn(state, source, callback) + checkAsParen(state, source, callback) + checkSetValue(state, source, callback) +end diff --git a/script-beta/vm/vm.lua b/script-beta/vm/vm.lua index 5460c52b..06ffc172 100644 --- a/script-beta/vm/vm.lua +++ b/script-beta/vm/vm.lua @@ -68,6 +68,7 @@ function m.refreshCache() end m.cache = { eachRef = {}, + eachDef = {}, eachField = {}, eachMeta = {}, getGlobals = {}, |