summaryrefslogtreecommitdiff
path: root/script/core/searcher.lua
diff options
context:
space:
mode:
Diffstat (limited to 'script/core/searcher.lua')
-rw-r--r--script/core/searcher.lua389
1 files changed, 389 insertions, 0 deletions
diff --git a/script/core/searcher.lua b/script/core/searcher.lua
new file mode 100644
index 00000000..fc1c08d1
--- /dev/null
+++ b/script/core/searcher.lua
@@ -0,0 +1,389 @@
+local linker = require 'core.linker'
+local guide = require 'parser.guide'
+local files = require 'files'
+local generic = require 'core.generic'
+
+local function checkFunctionReturn(source)
+ if source.parent
+ and source.parent.type == 'return' then
+ if source.parent.parent.type == 'main' then
+ return 0
+ elseif source.parent.parent.type == 'function' then
+ for i = 1, #source.parent do
+ if source.parent[i] == source then
+ return i
+ end
+ end
+ end
+ end
+ return nil
+end
+
+local m = {}
+
+---@alias guide.searchmode '"ref"'|'"def"'|'"field"'
+
+---添加结果
+---@param status guide.status
+---@param mode guide.searchmode
+---@param source parser.guide.object
+function m.pushResult(status, mode, source)
+ if not source then
+ return
+ end
+ local results = status.results
+ local parent = source.parent
+ if mode == 'def' then
+ if source.type == 'local'
+ or source.type == 'setlocal'
+ or source.type == 'setglobal'
+ or source.type == 'label'
+ or source.type == 'setfield'
+ or source.type == 'setmethod'
+ or source.type == 'setindex'
+ or source.type == 'tableindex'
+ or source.type == 'tablefield'
+ or source.type == 'function'
+ or source.type == 'table'
+ or source.type == 'doc.class.name'
+ or source.type == 'doc.alias.name'
+ or source.type == 'doc.field.name'
+ or source.type == 'doc.type.function' then
+ results[#results+1] = source
+ return
+ end
+ if source.type == 'call' then
+ if source.node.special == 'rawset' then
+ results[#results+1] = source
+ end
+ end
+ if parent.type == 'return' then
+ if linker.getID(source) ~= status.id then
+ results[#results+1] = source
+ end
+ end
+ elseif mode == 'ref' then
+ if source.type == 'local'
+ or source.type == 'setlocal'
+ or source.type == 'getlocal'
+ or source.type == 'setglobal'
+ or source.type == 'getglobal'
+ or source.type == 'label'
+ or source.type == 'goto'
+ or source.type == 'setfield'
+ or source.type == 'getfield'
+ or source.type == 'setmethod'
+ or source.type == 'getmethod'
+ or source.type == 'setindex'
+ or source.type == 'getindex'
+ or source.type == 'tableindex'
+ or source.type == 'tablefield'
+ or source.type == 'function'
+ or source.type == 'table'
+ or source.type == 'doc.class.name'
+ or source.type == 'doc.type.name'
+ or source.type == 'doc.alias.name'
+ or source.type == 'doc.extends.name'
+ or source.type == 'doc.field.name'
+ or source.type == 'doc.type.function' then
+ results[#results+1] = source
+ return
+ end
+ if source.type == 'call' then
+ if source.node.special == 'rawset'
+ or source.node.special == 'rawget' then
+ results[#results+1] = source
+ end
+ end
+ elseif mode == 'field' then
+ end
+end
+
+---获取uri
+---@param obj parser.guide.object
+---@return uri
+function m.getUri(obj)
+ if obj.uri then
+ return obj.uri
+ end
+ local root = guide.getRoot(obj)
+ if root then
+ return root.uri
+ end
+ return ''
+end
+
+-- TODO
+function m.findGlobals(root)
+ linker.compileLinks(root)
+ -- TODO
+ return {}
+end
+
+-- TODO
+function m.isGlobal(source)
+ return false
+end
+
+---@param obj parser.guide.object
+---@return parser.guide.object?
+function m.getObjectValue(obj)
+ while obj.type == 'paren' do
+ obj = obj.exp
+ if not obj then
+ return nil
+ end
+ end
+ if obj.type == 'boolean'
+ or obj.type == 'number'
+ or obj.type == 'integer'
+ or obj.type == 'string'
+ or obj.type == 'doc.type.table'
+ or obj.type == 'doc.type.arrary' then
+ return obj
+ end
+ if obj.value then
+ return obj.value
+ end
+ if obj.type == 'field'
+ or obj.type == 'method' then
+ return obj.parent and obj.parent.value
+ end
+ if obj.type == 'call' then
+ if obj.node.special == 'rawset' then
+ return obj.args and obj.args[3]
+ else
+ return obj
+ end
+ end
+ if obj.type == 'select' then
+ return obj
+ end
+ return nil
+end
+
+function m.searchRefsByID(status, uri, expect, mode)
+ local ast = files.getAst(uri)
+ if not ast then
+ return
+ end
+ local root = ast.ast
+ local searchStep
+ linker.compileLinks(root)
+
+ status.id = expect
+
+ local mark = status.mark
+
+ local callStack = {}
+
+ local function search(id, field)
+ local fieldLen
+ if field then
+ local _, len = field:gsub(linker.SPLIT_CHAR, '')
+ fieldLen = len
+ else
+ fieldLen = 0
+ end
+ if mark[id] and ((mark[id] < fieldLen) or fieldLen == 0) then
+ return
+ end
+ mark[id] = fieldLen
+ searchStep(id, field)
+ end
+
+ local function checkLastID(id, field)
+ local lastID = linker.getLastID(id)
+ if lastID then
+ local newField = id:sub(#lastID + 1)
+ if field then
+ newField = newField .. field
+ end
+ search(lastID, newField)
+ end
+ end
+
+ local function searchID(id, field)
+ if not id then
+ return
+ end
+ if field then
+ id = id .. field
+ end
+ search(id, nil)
+ end
+
+ local function searchFunction(id)
+ local link = linker.getLinkByID(root, id)
+ if not link or not link.sources then
+ return
+ end
+ local obj = link.sources[1]
+ if not obj or obj.type ~= 'function' then
+ return
+ end
+ local returnIndex = checkFunctionReturn(obj)
+ if not returnIndex then
+ return
+ end
+ local func = guide.getParentFunction(obj)
+ if not func or func.type ~= 'function' then
+ return
+ end
+ local parentID = linker.getID(func)
+ if not parentID then
+ return
+ end
+ search(parentID, linker.SPLIT_CHAR .. linker.RETURN_INDEX_CHAR .. returnIndex)
+ end
+
+ local function isCallID(field)
+ if not field then
+ return false
+ end
+ if field:sub(1, 1) == linker.SPLIT_CHAR
+ and field:sub(2, 2) == linker.RETURN_INDEX_CHAR then
+ return true
+ end
+ return false
+ end
+
+ local function findLastCall()
+ for i = #callStack, 1, -1 do
+ local call = callStack[i]
+ if call then
+ -- 标记此处的call失效,等待在堆栈平衡时弹出
+ callStack[i] = false
+ return call
+ end
+ end
+ return nil
+ end
+
+ local function checkGeneric(source, field)
+ if not source.isGeneric then
+ return
+ end
+ if not isCallID(field) then
+ return
+ end
+ local call = findLastCall()
+ if not call then
+ return
+ end
+ local closure = generic.createClosure(source, call)
+ if not closure then
+ return
+ end
+ local id = linker.getID(closure)
+ searchID(id, field)
+ end
+
+ local stepCount = 0
+ function searchStep(id, field)
+ stepCount = stepCount + 1
+ if stepCount > 1000 then
+ error('too large')
+ end
+ local link = linker.getLinkByID(root, id)
+ if link then
+ if link.call then
+ callStack[#callStack+1] = link.call
+ end
+ if field == nil and link.sources then
+ for _, source in ipairs(link.sources) do
+ m.pushResult(status, mode, source)
+ end
+ end
+ if link.forward then
+ for _, forwardID in ipairs(link.forward) do
+ searchID(forwardID, field)
+ end
+ end
+ if link.backward and (mode == 'ref' or field) then
+ for _, backwardID in ipairs(link.backward) do
+ searchID(backwardID, field)
+ end
+ end
+
+ if link.sources then
+ checkGeneric(link.sources[1], field)
+ end
+
+ if link.call then
+ callStack[#callStack] = nil
+ end
+ end
+ checkLastID(id, field)
+ end
+
+ search(expect)
+ searchFunction(expect)
+end
+
+---搜索对象的引用
+---@param status guide.status
+---@param source parser.guide.object
+---@param mode guide.searchmode
+function m.searchRefs(status, source, mode)
+ if source.type == 'field'
+ or source.type == 'method' then
+ source = source.parent
+ end
+ local root = guide.getRoot(source)
+ linker.compileLinks(root)
+ local uri = guide.getUri(source)
+ local id = linker.getID(source)
+ if not id then
+ return
+ end
+
+ m.searchRefsByID(status, uri, id, mode)
+end
+
+---@class guide.status
+---搜索结果
+---@field results parser.guide.object[]
+
+---创建搜索状态
+---@param parentStatus guide.status
+---@param interface table
+---@param deep integer
+---@return guide.status
+function m.status(parentStatus, interface, deep)
+ local status = {
+ mark = parentStatus and parentStatus.mark or {},
+ results = {},
+ }
+ return status
+end
+
+--- 请求对象的引用
+---@param obj parser.guide.object
+---@param interface table
+---@param deep integer
+---@return parser.guide.object[]
+---@return integer
+function m.requestReference(obj, interface, deep)
+ local status = m.status(nil, interface, deep)
+ -- 根据 field 搜索引用
+ m.searchRefs(status, obj, 'ref')
+
+ return status.results, 0
+end
+
+--- 请求对象的定义
+---@param obj parser.guide.object
+---@param interface table
+---@param deep integer
+---@return parser.guide.object[]
+---@return integer
+function m.requestDefinition(obj, interface, deep)
+ local status = m.status(nil, interface, deep)
+ -- 根据 field 搜索引用
+ m.searchRefs(status, obj, 'def')
+
+ return status.results, 0
+end
+
+return m