diff options
-rw-r--r-- | server-beta/src/core/definition.lua | 23 | ||||
-rw-r--r-- | server-beta/src/files.lua | 20 | ||||
-rw-r--r-- | server-beta/src/searcher/eachField.lua | 60 | ||||
-rw-r--r-- | server-beta/src/searcher/eachRef.lua | 138 | ||||
-rw-r--r-- | server-beta/src/searcher/init.lua | 193 | ||||
-rw-r--r-- | server-beta/src/searcher/searcher.lua | 119 |
6 files changed, 257 insertions, 296 deletions
diff --git a/server-beta/src/core/definition.lua b/server-beta/src/core/definition.lua index f9ba00d1..30ce5dec 100644 --- a/server-beta/src/core/definition.lua +++ b/server-beta/src/core/definition.lua @@ -1,8 +1,9 @@ local guide = require 'parser.guide' local workspace = require 'workspace' local files = require 'files' +local searcher = require 'searcher' -local function findDef(searcher, source, callback) +local function findDef(sch, source, callback) if source.type ~= 'local' and source.type ~= 'getlocal' and source.type ~= 'setlocal' @@ -16,12 +17,12 @@ local function findDef(searcher, source, callback) and source.type ~= 'goto' then return end - searcher:eachRef(source, function (info) + searcher.eachRef(source, function (info) if info.mode == 'declare' - or info.mode == 'set' + or info.mode == 'set' or info.mode == 'return' then local src = info.source - local uri = info.searcher.uri + local uri = info.uri if src.type == 'setfield' or src.type == 'getfield' or src.type == 'tablefield' then @@ -40,8 +41,8 @@ local function findDef(searcher, source, callback) end) end ----@param searcher engineer -local function checkRequire(searcher, source, offset, callback) +---@param sch searcher +local function checkRequire(sch, source, offset, callback) if source.type ~= 'call' then return end @@ -57,7 +58,7 @@ local function checkRequire(searcher, source, offset, callback) if type(literal) ~= 'string' then return end - local name = searcher:getSpecialName(func) + local name = searcher.getSpecialName(func) if name == 'require' then local result = workspace.findUrisByRequirePath(literal, true) for _, uri in ipairs(result) do @@ -73,12 +74,12 @@ local function checkRequire(searcher, source, offset, callback) end return function (uri, offset) - local results = {} - local searcher = files.getSearcher(uri) - if not searcher then + local ast = files.getAst(uri) + if not ast then return nil end - guide.eachSourceContain(searcher.ast, offset, function (source) + local results = {} + guide.eachSourceContain(ast.ast, offset, function (source) checkRequire(searcher, source, offset, function (uri) results[#results+1] = { uri = files.getOriginUri(uri), diff --git a/server-beta/src/files.lua b/server-beta/src/files.lua index 4e8cd093..c3b33122 100644 --- a/server-beta/src/files.lua +++ b/server-beta/src/files.lua @@ -5,6 +5,7 @@ local config = require 'config' local glob = require 'glob' local furi = require 'file-uri' local parser = require 'parser' +local searcher = require 'searcher.searcher' local m = {} @@ -61,6 +62,7 @@ function m.setText(uri, text) file.searcher = nil file.lines = nil file.ast = nil + searcher.refreshCache() end --- 监听编译完成 @@ -99,6 +101,7 @@ function m.remove(uri) uri = uri:lower() end m.fileMap[uri] = nil + searcher.refreshCache() end --- 移除所有文件 @@ -106,6 +109,7 @@ function m.removeAll() for uri in pairs(m.fileMap) do m.fileMap[uri] = nil end + searcher.refreshCache() end --- 遍历文件 @@ -152,22 +156,6 @@ function m.getLines(uri) return file.lines end ---- 获取搜索器 -function m.getSearcher(uri) - if platform.OS == 'Windows' then - uri = uri:lower() - end - local file = m.fileMap[uri] - if not file then - return nil - end - if not file.searcher then - local searcher = require 'searcher' - file.searcher = searcher.create(uri) - end - return file.searcher -end - --- 获取原始uri function m.getOriginUri(uri) if platform.OS == 'Windows' then diff --git a/server-beta/src/searcher/eachField.lua b/server-beta/src/searcher/eachField.lua index 3db2375f..71b6e45b 100644 --- a/server-beta/src/searcher/eachField.lua +++ b/server-beta/src/searcher/eachField.lua @@ -1,11 +1,11 @@ local guide = require 'parser.guide' +local searcher = require 'searcher.searcher' -local function ofTabel(searcher, value, callback) +local function ofTabel(value, callback) for _, field in ipairs(value) do if field.type == 'tablefield' or field.type == 'tableindex' then callback { - searcher = searcher, source = field, key = guide.getKeyName(field), value = field.value, @@ -15,14 +15,13 @@ local function ofTabel(searcher, value, callback) end end -local function ofENV(searcher, source, callback) +local function ofENV(source, callback) if source.type == 'getlocal' then local parent = source.parent if parent.type == 'getfield' or parent.type == 'getmethod' or parent.type == 'getindex' then callback { - searcher = searcher, source = parent, key = guide.getKeyName(parent), mode = 'get', @@ -30,14 +29,12 @@ local function ofENV(searcher, source, callback) end elseif source.type == 'getglobal' then callback { - searcher = searcher, source = source, key = guide.getKeyName(source), mode = 'get', } elseif source.type == 'setglobal' then callback { - searcher = searcher, source = source, key = guide.getKeyName(source), mode = 'set', @@ -46,7 +43,7 @@ local function ofENV(searcher, source, callback) end end -local function ofSpecialArg(searcher, source, callback) +local function ofSpecialArg(source, callback) local args = source.parent local call = args.parent local func = call.node @@ -54,7 +51,6 @@ local function ofSpecialArg(searcher, source, callback) if name == 'rawset' then if args[1] == source and args[2] then callback { - searcher = searcher, source = call, key = guide.getKeyName(args[2]), value = args[3], @@ -64,7 +60,6 @@ local function ofSpecialArg(searcher, source, callback) elseif name == 'rawget' then if args[1] == source and args[2] then callback { - searcher = searcher, source = call, key = guide.getKeyName(args[2]), mode = 'get', @@ -72,16 +67,16 @@ local function ofSpecialArg(searcher, source, callback) end elseif name == 'setmetatable' then if args[1] == source and args[2] then - searcher:eachField(args[2], function (info) + searcher.eachField(args[2], function (info) if info.key == 's|__index' and info.value then - info.searcher:eachField(info.value, callback) + searcher.eachField(info.value, callback) end end) end end end -local function ofVar(searcher, source, callback) +local function ofVar(source, callback) local parent = source.parent if not parent then return @@ -90,7 +85,6 @@ local function ofVar(searcher, source, callback) or parent.type == 'getmethod' or parent.type == 'getindex' then callback { - searcher = searcher, source = parent, key = guide.getKeyName(parent), mode = 'get', @@ -101,7 +95,6 @@ local function ofVar(searcher, source, callback) or parent.type == 'setmethod' or parent.type == 'setindex' then callback { - searcher = searcher, source = parent, key = guide.getKeyName(parent), value = parent.value, @@ -110,17 +103,17 @@ local function ofVar(searcher, source, callback) return end if parent.type == 'callargs' then - ofSpecialArg(searcher, source, callback) + ofSpecialArg(source, callback) end end -return function (searcher, source, callback) - searcher:eachRef(source, function (info) +local function eachField(source, callback) + searcher.eachRef(source, function (info) local src = info.source if src.tag == '_ENV' then if src.ref then for _, ref in ipairs(src.ref) do - ofENV(info.searcher, ref, callback) + ofENV(ref, callback) end end elseif src.type == 'getlocal' @@ -128,9 +121,36 @@ return function (searcher, source, callback) or src.type == 'getfield' or src.type == 'getmethod' or src.type == 'getindex' then - ofVar(info.searcher, src, callback) + ofVar(src, callback) elseif src.type == 'table' then - ofTabel(info.searcher, src, callback) + ofTabel(src, callback) end end) end + +--- 获取所有的field +function searcher.eachField(source, callback) + local lock <close> = searcher.lock('eachField', source) + if not lock then + return + end + local cache = searcher.cache.eachField[source] + if cache then + for i = 1, #cache do + callback(cache[i]) + end + return + end + cache = {} + searcher.cache.eachField[source] = cache + local mark = {} + eachField(source, function (info) + local src = info.source + if mark[src] then + return + end + mark[src] = true + cache[#cache+1] = info + callback(info) + end) +end diff --git a/server-beta/src/searcher/eachRef.lua b/server-beta/src/searcher/eachRef.lua index 05788eb0..b7c108d8 100644 --- a/server-beta/src/searcher/eachRef.lua +++ b/server-beta/src/searcher/eachRef.lua @@ -1,9 +1,10 @@ local guide = require 'parser.guide' local files = require 'files' local workspace = require 'workspace' +local searcher = require 'searcher.searcher' -local function ofCall(searcher, func, index, callback) - searcher:eachRef(func, function (info) +local function ofCall(func, index, callback) + searcher.eachRef(func, function (info) local src = info.source local funcDef = src.value if funcDef and funcDef.returns then @@ -12,31 +13,30 @@ local function ofCall(searcher, func, index, callback) local val = rtn[index] if val then callback { - searcher = info.searcher, source = val, mode = 'return', } - info.searcher:eachRef(val, callback) + searcher.eachRef(val, callback) end end end end) end -local function ofSpecialCall(searcher, call, func, index, callback) - local name = searcher:getSpecialName(func) +local function ofSpecialCall(call, func, index, callback) + local name = searcher.getSpecialName(func) if name == 'setmetatable' then if index == 1 then local args = call.args if args[1] then - searcher:eachRef(args[1], callback) + searcher.eachRef(args[1], callback) end if args[2] then - searcher:eachField(args[2], function (info) + searcher.eachField(args[2], function (info) if info.key == 's|__index' then - info.searcher:eachRef(info.source, callback) + searcher.eachRef(info.source, callback) if info.value then - info.searcher:eachRef(info.value, callback) + searcher.eachRef(info.value, callback) end end end) @@ -61,34 +61,33 @@ local function ofSpecialCall(searcher, call, func, index, callback) end end -local function ofValue(searcher, value, callback) +local function ofValue(value, callback) if value.type == 'select' then -- 检查函数返回值 local call = value.vararg if call.type == 'call' then - ofCall(searcher, call.node, value.index, callback) - ofSpecialCall(searcher, call, call.node, value.index, callback) + ofCall(call.node, value.index, callback) + ofSpecialCall(call, call.node, value.index, callback) end return end callback { - searcher = searcher, source = value, mode = 'value', } end -local function ofSelf(searcher, loc, callback) +local function ofSelf(loc, callback) -- self 的2个特殊引用位置: -- 1. 当前方法定义时的对象(mt) local method = loc.method local node = method.node - searcher:eachRef(node, callback) + searcher.eachRef(node, callback) -- 2. 调用该方法时传入的对象 end --- 自己作为赋值的值 -local function asValue(searcher, source, callback) +local function asValue(source, callback) local parent = source.parent if parent and parent.value == source then if guide.getKeyName(parent) == 's|__index' then @@ -99,8 +98,8 @@ local function asValue(searcher, source, callback) if args[2] == t then local call = args.parent local func = call.node - if searcher:getSpecialName(func) == 'setmetatable' then - searcher:eachRef(args[1], callback) + if searcher.getSpecialName(func) == 'setmetatable' then + searcher.eachRef(args[1], callback) end end end @@ -125,7 +124,7 @@ local function getCallRecvs(call) end --- 自己作为函数的参数 -local function asArg(searcher, source, callback) +local function asArg(source, callback) local parent = source.parent if not parent then return @@ -133,15 +132,15 @@ local function asArg(searcher, source, callback) if parent.type == 'callargs' then local call = parent.parent local func = call.node - local name = searcher:getSpecialName(func) + local name = searcher.getSpecialName(func) if name == 'setmetatable' then if parent[1] == source then if parent[2] then - searcher:eachField(parent[2], function (info) + searcher.eachField(parent[2], function (info) if info.key == 's|__index' then - info.searcher:eachRef(info.source, callback) + searcher.eachRef(info.source, callback) if info.value then - info.searcher:eachRef(info.value, callback) + searcher.eachRef(info.value, callback) end end end) @@ -149,17 +148,16 @@ local function asArg(searcher, source, callback) end local recvs = getCallRecvs(call) if recvs and recvs[1] then - searcher:eachRef(recvs[1], callback) + searcher.eachRef(recvs[1], callback) end end end end -local function ofLocal(searcher, loc, callback) +local function ofLocal(loc, callback) -- 方法中的 self 使用了一个虚拟的定义位置 if loc.tag ~= 'self' then callback { - searcher = searcher, source = loc, mode = 'declare', } @@ -168,28 +166,26 @@ local function ofLocal(searcher, loc, callback) for _, ref in ipairs(loc.ref) do if ref.type == 'getlocal' then callback { - searcher = searcher, source = ref, mode = 'get', } - asValue(searcher, ref, callback) + asValue(ref, callback) elseif ref.type == 'setlocal' then callback { - searcher = searcher, source = ref, mode = 'set', } if ref.value then - ofValue(searcher, ref.value, callback) + ofValue(ref.value, callback) end end end end if loc.tag == 'self' then - ofSelf(searcher, loc, callback) + ofSelf(loc, callback) end if loc.value then - ofValue(searcher, loc.value, callback) + ofValue(loc.value, callback) end if loc.tag == '_ENV' then for _, ref in ipairs(loc.ref) do @@ -199,7 +195,6 @@ local function ofLocal(searcher, loc, callback) or parent.type == 'getindex' then if guide.getKeyName(parent) == 's|_G' then callback { - searcher = searcher, source = parent, mode = 'get', } @@ -208,7 +203,6 @@ local function ofLocal(searcher, loc, callback) elseif ref.type == 'getglobal' then if guide.getKeyName(ref) == 's|_G' then callback { - searcher = searcher, source = ref, mode = 'get', } @@ -218,101 +212,125 @@ local function ofLocal(searcher, loc, callback) end end -local function ofGlobal(searcher, source, callback) +local function ofGlobal(source, callback) local node = source.node local key = guide.getKeyName(source) - searcher:eachField(node, function (info) + searcher.eachField(node, function (info) if key == info.key then callback { - searcher = info.searcher, source = info.source, mode = info.mode, } if info.value then - ofValue(info.searcher, info.value, callback) + ofValue(info.value, callback) end end end) end -local function ofField(searcher, source, callback) +local function ofField(source, callback) local parent = source.parent local node = parent.node local key = guide.getKeyName(source) - searcher:eachField(node, function (info) + searcher.eachField(node, function (info) if key == info.key then callback { - searcher = info.searcher, source = info.source, mode = info.mode, } if info.value then - ofValue(info.searcher, info.value, callback) + ofValue(info.value, callback) end end end) end -local function ofLiteral(searcher, source, callback) +local function ofLiteral(source, callback) local parent = source.parent if not parent then return end if parent.type == 'setindex' or parent.type == 'getindex' then - ofField(searcher, source, callback) + ofField(source, callback) end end -local function ofGoTo(searcher, source, callback) +local function ofGoTo(source, callback) local name = source[1] local label = guide.getLabel(source, name) if label then callback { - searcher = searcher, source = label, mode = 'set', } end end -local function ofLabel(searcher, source, callback) +local function ofLabel(source, callback) end -return function (searcher, source, callback) +local function eachRef(source, callback) local stype = source.type if stype == 'local' then - ofLocal(searcher, source, callback) + ofLocal(source, callback) elseif stype == 'getlocal' or stype == 'setlocal' then - ofLocal(searcher, source.node, callback) + ofLocal(source.node, callback) elseif stype == 'setglobal' or stype == 'getglobal' then - ofGlobal(searcher, source, callback) + ofGlobal(source, callback) elseif stype == 'field' or stype == 'method' or stype == 'index' then - ofField(searcher, source, callback) + ofField(source, callback) elseif stype == 'setfield' or stype == 'getfield' then - ofField(searcher, source.field, callback) + ofField(source.field, callback) elseif stype == 'setmethod' or stype == 'getmethod' then - ofField(searcher, source.method, callback) + ofField(source.method, callback) elseif stype == 'number' or stype == 'boolean' or stype == 'string' then - ofLiteral(searcher, source, callback) + ofLiteral(source, callback) elseif stype == 'goto' then - ofGoTo(searcher, source, callback) + ofGoTo(source, callback) elseif stype == 'label' then - ofLabel(searcher, source, callback) + ofLabel(source, callback) else callback { - searcher = searcher, source = source, + mode = 'value', } end - asArg(searcher, source, callback) + asArg(source, callback) +end + +--- 获取所有的引用 +function searcher.eachRef(source, callback) + local lock <close> = searcher.lock('eachRef', source) + if not lock then + return + end + local cache = searcher.cache.eachRef[source] + if cache then + for i = 1, #cache do + callback(cache[i]) + end + return + end + cache = {} + searcher.cache.eachRef[source] = cache + local mark = {} + eachRef(source, function (info) + local src = info.source + if mark[src] then + return + end + mark[src] = true + cache[#cache+1] = info + callback(info) + end) end diff --git a/server-beta/src/searcher/init.lua b/server-beta/src/searcher/init.lua index 66a15f54..770aed99 100644 --- a/server-beta/src/searcher/init.lua +++ b/server-beta/src/searcher/init.lua @@ -1,189 +1,4 @@ -local guide = require 'parser.guide' -local files = require 'files' -local util = require 'utility' -local eachRef = require 'searcher.eachRef' -local eachField = require 'searcher.eachField' - -local setmetatable = setmetatable -local assert = assert - -_ENV = nil - -local specials = { - ['_G'] = true, - ['rawset'] = true, - ['rawget'] = true, - ['setmetatable'] = true, - ['require'] = true, - ['dofile'] = true, - ['loadfile'] = true, -} - ----@class searcher -local mt = {} -mt.__index = mt -mt.__name = 'searcher' - -function mt:lock(tp, source) - if self.locked[tp][source] then - return nil - end - self.locked[tp][source] = true - return util.defer(function () - self.locked[tp][source] = nil - end) -end - ---- 获取所有的引用 -function mt:eachRef(source, callback) - local lock <close> = self:lock('eachRef', source) - if not lock then - return - end - local cache = self.cache.eachRef[source] - if cache then - for i = 1, #cache do - callback(cache[i]) - end - return - end - cache = {} - self.cache.eachRef[source] = cache - local mark = {} - eachRef(self, source, function (info) - local src = info.source - if mark[src] then - return - end - mark[src] = true - cache[#cache+1] = info - callback(info) - end) -end - ---- 获取所有的field -function mt:eachField(source, callback) - local lock <close> = self:lock('eachField', source) - if not lock then - return - end - local cache = self.cache.eachField[source] - if cache then - for i = 1, #cache do - callback(cache[i]) - end - return - end - cache = {} - self.cache.eachField[source] = cache - local mark = {} - eachField(self, source, function (info) - local src = info.source - if mark[src] then - return - end - mark[src] = true - cache[#cache+1] = info - callback(info) - end) -end - ---- 获取特殊对象的名字 -function mt:getSpecialName(source) - local spName = self.cache.specialName[source] - if spName ~= nil then - if spName then - return spName - end - return nil - end - local function getName(src) - if src.type == 'getglobal' then - local node = src.node - if node.tag ~= '_ENV' then - return nil - end - local name = guide.getKeyName(src) - if name:sub(1, 2) ~= 's|' then - return nil - end - spName = name:sub(3) - if not specials[spName] then - spName = nil - end - elseif src.type == 'local' then - if src.tag == '_ENV' then - spName = '_G' - end - elseif src.type == 'getlocal' then - local loc = src.loc - if loc.tag == '_ENV' then - spName = '_G' - end - end - end - getName(source) - if not spName then - self:eachRef(source, function (info) - getName(info.source) - end) - end - self.cache.specialName[source] = spName or false - return spName -end - ---- 遍历特殊对象 ----@param callback fun(name:string, source:table) -function mt:eachSpecial(callback) - local cache = self.cache.specials - if cache then - for i = 1, #cache do - callback(cache[i][1], cache[i][2]) - end - return - end - cache = {} - self.cache.specials = cache - guide.eachSource(self.ast, function (source) - if source.type == 'getlocal' - or source.type == 'getglobal' - or source.type == 'local' - or source.type == 'field' - or source.type == 'string' then - local name = self:getSpecialName(source) - if name then - cache[#cache+1] = { name, source } - end - end - end) - for i = 1, #cache do - callback(cache[i][1], cache[i][2]) - end -end - ----@class engineer -local m = {} - ---- 新建搜索器 ----@param uri string ----@return searcher -function m.create(uri) - local ast = files.getAst(uri) - local searcher = setmetatable({ - ast = ast.ast, - uri = uri, - cache = { - eachRef = {}, - eachField = {}, - specialName = {}, - specials = nil, - }, - locked = { - eachRef = {}, - eachField = {}, - } - }, mt) - return searcher -end - -return m +local searcher = require 'searcher.searcher' +require 'searcher.eachField' +require 'searcher.eachRef' +return searcher diff --git a/server-beta/src/searcher/searcher.lua b/server-beta/src/searcher/searcher.lua new file mode 100644 index 00000000..7fdcd39d --- /dev/null +++ b/server-beta/src/searcher/searcher.lua @@ -0,0 +1,119 @@ +local guide = require 'parser.guide' +local util = require 'utility' + +local setmetatable = setmetatable +local assert = assert + +_ENV = nil + +local specials = { + ['_G'] = true, + ['rawset'] = true, + ['rawget'] = true, + ['setmetatable'] = true, + ['require'] = true, + ['dofile'] = true, + ['loadfile'] = true, +} + +---@class searcher +local m = {} + +function m.lock(tp, source) + if m.locked[tp][source] then + return nil + end + m.locked[tp][source] = true + return util.defer(function () + m.locked[tp][source] = nil + end) +end + +--- 获取特殊对象的名字 +function m.getSpecialName(source) + local spName = m.cache.specialName[source] + if spName ~= nil then + if spName then + return spName + end + return nil + end + local function getName(src) + if src.type == 'getglobal' then + local node = src.node + if node.tag ~= '_ENV' then + return nil + end + local name = guide.getKeyName(src) + if name:sub(1, 2) ~= 's|' then + return nil + end + spName = name:sub(3) + if not specials[spName] then + spName = nil + end + elseif src.type == 'local' then + if src.tag == '_ENV' then + spName = '_G' + end + elseif src.type == 'getlocal' then + local loc = src.loc + if loc.tag == '_ENV' then + spName = '_G' + end + end + end + getName(source) + if not spName then + m.eachRef(source, function (info) + getName(info.source) + end) + end + m.cache.specialName[source] = spName or false + return spName +end + +--- 遍历特殊对象 +---@param callback fun(name:string, source:table) +function m.eachSpecial(callback) + local cache = m.cache.specials + if cache then + for i = 1, #cache do + callback(cache[i][1], cache[i][2]) + end + return + end + cache = {} + m.cache.specials = cache + guide.eachSource(m.ast, function (source) + if source.type == 'getlocal' + or source.type == 'getglobal' + or source.type == 'local' + or source.type == 'field' + or source.type == 'string' then + local name = m.getSpecialName(source) + if name then + cache[#cache+1] = { name, source } + end + end + end) + for i = 1, #cache do + callback(cache[i][1], cache[i][2]) + end +end + +--- 刷新缓存 +function m.refreshCache() + m.cache = { + eachRef = {}, + eachField = {}, + specialName = {}, + specials = nil, + } + m.locked = { + eachRef = {}, + eachField = {}, + } +end + +return m |