diff options
Diffstat (limited to 'script/core/searcher.lua')
-rw-r--r-- | script/core/searcher.lua | 389 |
1 files changed, 389 insertions, 0 deletions
diff --git a/script/core/searcher.lua b/script/core/searcher.lua new file mode 100644 index 00000000..fc1c08d1 --- /dev/null +++ b/script/core/searcher.lua @@ -0,0 +1,389 @@ +local linker = require 'core.linker' +local guide = require 'parser.guide' +local files = require 'files' +local generic = require 'core.generic' + +local function checkFunctionReturn(source) + if source.parent + and source.parent.type == 'return' then + if source.parent.parent.type == 'main' then + return 0 + elseif source.parent.parent.type == 'function' then + for i = 1, #source.parent do + if source.parent[i] == source then + return i + end + end + end + end + return nil +end + +local m = {} + +---@alias guide.searchmode '"ref"'|'"def"'|'"field"' + +---添加结果 +---@param status guide.status +---@param mode guide.searchmode +---@param source parser.guide.object +function m.pushResult(status, mode, source) + if not source then + return + end + local results = status.results + local parent = source.parent + if mode == 'def' then + if source.type == 'local' + or source.type == 'setlocal' + or source.type == 'setglobal' + or source.type == 'label' + or source.type == 'setfield' + or source.type == 'setmethod' + or source.type == 'setindex' + or source.type == 'tableindex' + or source.type == 'tablefield' + or source.type == 'function' + or source.type == 'table' + or source.type == 'doc.class.name' + or source.type == 'doc.alias.name' + or source.type == 'doc.field.name' + or source.type == 'doc.type.function' then + results[#results+1] = source + return + end + if source.type == 'call' then + if source.node.special == 'rawset' then + results[#results+1] = source + end + end + if parent.type == 'return' then + if linker.getID(source) ~= status.id then + results[#results+1] = source + end + end + elseif mode == 'ref' then + if source.type == 'local' + or source.type == 'setlocal' + or source.type == 'getlocal' + or source.type == 'setglobal' + or source.type == 'getglobal' + or source.type == 'label' + or source.type == 'goto' + or source.type == 'setfield' + or source.type == 'getfield' + or source.type == 'setmethod' + or source.type == 'getmethod' + or source.type == 'setindex' + or source.type == 'getindex' + or source.type == 'tableindex' + or source.type == 'tablefield' + or source.type == 'function' + or source.type == 'table' + or source.type == 'doc.class.name' + or source.type == 'doc.type.name' + or source.type == 'doc.alias.name' + or source.type == 'doc.extends.name' + or source.type == 'doc.field.name' + or source.type == 'doc.type.function' then + results[#results+1] = source + return + end + if source.type == 'call' then + if source.node.special == 'rawset' + or source.node.special == 'rawget' then + results[#results+1] = source + end + end + elseif mode == 'field' then + end +end + +---获取uri +---@param obj parser.guide.object +---@return uri +function m.getUri(obj) + if obj.uri then + return obj.uri + end + local root = guide.getRoot(obj) + if root then + return root.uri + end + return '' +end + +-- TODO +function m.findGlobals(root) + linker.compileLinks(root) + -- TODO + return {} +end + +-- TODO +function m.isGlobal(source) + return false +end + +---@param obj parser.guide.object +---@return parser.guide.object? +function m.getObjectValue(obj) + while obj.type == 'paren' do + obj = obj.exp + if not obj then + return nil + end + end + if obj.type == 'boolean' + or obj.type == 'number' + or obj.type == 'integer' + or obj.type == 'string' + or obj.type == 'doc.type.table' + or obj.type == 'doc.type.arrary' then + return obj + end + if obj.value then + return obj.value + end + if obj.type == 'field' + or obj.type == 'method' then + return obj.parent and obj.parent.value + end + if obj.type == 'call' then + if obj.node.special == 'rawset' then + return obj.args and obj.args[3] + else + return obj + end + end + if obj.type == 'select' then + return obj + end + return nil +end + +function m.searchRefsByID(status, uri, expect, mode) + local ast = files.getAst(uri) + if not ast then + return + end + local root = ast.ast + local searchStep + linker.compileLinks(root) + + status.id = expect + + local mark = status.mark + + local callStack = {} + + local function search(id, field) + local fieldLen + if field then + local _, len = field:gsub(linker.SPLIT_CHAR, '') + fieldLen = len + else + fieldLen = 0 + end + if mark[id] and ((mark[id] < fieldLen) or fieldLen == 0) then + return + end + mark[id] = fieldLen + searchStep(id, field) + end + + local function checkLastID(id, field) + local lastID = linker.getLastID(id) + if lastID then + local newField = id:sub(#lastID + 1) + if field then + newField = newField .. field + end + search(lastID, newField) + end + end + + local function searchID(id, field) + if not id then + return + end + if field then + id = id .. field + end + search(id, nil) + end + + local function searchFunction(id) + local link = linker.getLinkByID(root, id) + if not link or not link.sources then + return + end + local obj = link.sources[1] + if not obj or obj.type ~= 'function' then + return + end + local returnIndex = checkFunctionReturn(obj) + if not returnIndex then + return + end + local func = guide.getParentFunction(obj) + if not func or func.type ~= 'function' then + return + end + local parentID = linker.getID(func) + if not parentID then + return + end + search(parentID, linker.SPLIT_CHAR .. linker.RETURN_INDEX_CHAR .. returnIndex) + end + + local function isCallID(field) + if not field then + return false + end + if field:sub(1, 1) == linker.SPLIT_CHAR + and field:sub(2, 2) == linker.RETURN_INDEX_CHAR then + return true + end + return false + end + + local function findLastCall() + for i = #callStack, 1, -1 do + local call = callStack[i] + if call then + -- 标记此处的call失效,等待在堆栈平衡时弹出 + callStack[i] = false + return call + end + end + return nil + end + + local function checkGeneric(source, field) + if not source.isGeneric then + return + end + if not isCallID(field) then + return + end + local call = findLastCall() + if not call then + return + end + local closure = generic.createClosure(source, call) + if not closure then + return + end + local id = linker.getID(closure) + searchID(id, field) + end + + local stepCount = 0 + function searchStep(id, field) + stepCount = stepCount + 1 + if stepCount > 1000 then + error('too large') + end + local link = linker.getLinkByID(root, id) + if link then + if link.call then + callStack[#callStack+1] = link.call + end + if field == nil and link.sources then + for _, source in ipairs(link.sources) do + m.pushResult(status, mode, source) + end + end + if link.forward then + for _, forwardID in ipairs(link.forward) do + searchID(forwardID, field) + end + end + if link.backward and (mode == 'ref' or field) then + for _, backwardID in ipairs(link.backward) do + searchID(backwardID, field) + end + end + + if link.sources then + checkGeneric(link.sources[1], field) + end + + if link.call then + callStack[#callStack] = nil + end + end + checkLastID(id, field) + end + + search(expect) + searchFunction(expect) +end + +---搜索对象的引用 +---@param status guide.status +---@param source parser.guide.object +---@param mode guide.searchmode +function m.searchRefs(status, source, mode) + if source.type == 'field' + or source.type == 'method' then + source = source.parent + end + local root = guide.getRoot(source) + linker.compileLinks(root) + local uri = guide.getUri(source) + local id = linker.getID(source) + if not id then + return + end + + m.searchRefsByID(status, uri, id, mode) +end + +---@class guide.status +---搜索结果 +---@field results parser.guide.object[] + +---创建搜索状态 +---@param parentStatus guide.status +---@param interface table +---@param deep integer +---@return guide.status +function m.status(parentStatus, interface, deep) + local status = { + mark = parentStatus and parentStatus.mark or {}, + results = {}, + } + return status +end + +--- 请求对象的引用 +---@param obj parser.guide.object +---@param interface table +---@param deep integer +---@return parser.guide.object[] +---@return integer +function m.requestReference(obj, interface, deep) + local status = m.status(nil, interface, deep) + -- 根据 field 搜索引用 + m.searchRefs(status, obj, 'ref') + + return status.results, 0 +end + +--- 请求对象的定义 +---@param obj parser.guide.object +---@param interface table +---@param deep integer +---@return parser.guide.object[] +---@return integer +function m.requestDefinition(obj, interface, deep) + local status = m.status(nil, interface, deep) + -- 根据 field 搜索引用 + m.searchRefs(status, obj, 'def') + + return status.results, 0 +end + +return m |