diff options
Diffstat (limited to 'script')
-rw-r--r-- | script/core/linker.lua | 167 | ||||
-rw-r--r-- | script/core/searcher.lua | 59 |
2 files changed, 111 insertions, 115 deletions
diff --git a/script/core/linker.lua b/script/core/linker.lua index 41bdbaf9..54885c48 100644 --- a/script/core/linker.lua +++ b/script/core/linker.lua @@ -2,7 +2,7 @@ local util = require 'utility' local guide = require 'parser.guide' local vm = require 'vm.vm' -local Linkers, CreateLink +local Linkers, GetLink local LastIDCache = {} local SPLIT_CHAR = '\x1F' local SPLIT_REGEX = SPLIT_CHAR .. '[^' .. SPLIT_CHAR .. ']+$' @@ -185,20 +185,54 @@ local function getID(source) return id end -local TempList = {} +---添加关联单元 +---@param id string +---@param source parser.guide.object +local function pushSource(id, source) + local link = GetLink(id) + if not link.sources then + link.sources = {} + end + link.sources[#link.sources+1] = source +end + +---添加关联的前进ID +---@param id string +---@param forwardID string +local function pushForward(id, forwardID) + if not forwardID or forwardID == '' or id == forwardID then + return + end + local link = GetLink(id) + if not link.forward then + link.forward = {} + end + link.forward[#link.forward+1] = forwardID +end + +---添加关联的后退ID +---@param id string +---@param backwardID string +local function pushBackward(id, backwardID) + if not backwardID or backwardID == '' or id == backwardID then + return + end + local link = GetLink(id) + if not link.backward then + link.backward = {} + end + link.backward[#link.backward+1] = backwardID +end ---前进 ---@param source parser.guide.object ---@return parser.guide.object[] local function checkForward(source) - if not source then - return - end - local list = TempList + local id = getID(source) local parent = source.parent if source.value then -- x = y : x -> y - list[#list+1] = getID(source.value) + pushForward(id, getID(source.value)) end -- mt:f -> self if parent.type == 'setmethod' @@ -207,7 +241,7 @@ local function checkForward(source) if func then local self = func.locals[1] if self.tag == 'self' then - list[#list+1] = getID(self) + pushForward(id, getID(self)) end end end @@ -216,7 +250,7 @@ local function checkForward(source) local func = guide.getParentFunction(source) local setmethod = func.parent if setmethod and setmethod.type == 'setmethod' then - list[#list+1] = getID(setmethod.node) + pushForward(id, getID(setmethod.node)) end end -- source 绑定的 @class/@type @@ -225,20 +259,20 @@ local function checkForward(source) for _, doc in ipairs(bindDocs) do if doc.type == 'doc.class' or doc.type == 'doc.type' then - list[#list+1] = getID(doc) + pushForward(id, getID(doc)) end end end -- 分解 @type if source.type == 'doc.type' then for _, typeUnit in ipairs(source.types) do - list[#list+1] = getID(typeUnit) + pushForward(id, getID(typeUnit)) end end -- 分解 @class if source.type == 'doc.class' then - list[#list+1] = getID(source.class) - list[#list+1] = getID(source.extends) + pushForward(id, getID(source.class)) + pushForward(id, getID(source.extends)) end -- 将call的返回值接收映射到函数返回值上 if source.type == 'select' then @@ -251,7 +285,11 @@ local function checkForward(source) RETURN_INDEX_CHAR, source.index ) - list[#list+1] = callID + pushForward(id, callID) + -- 将setmetatable映射到 param1 以及 param2.__index 上 + if node.special == 'setmetatable' and source.index == 1 then + pushForward(id, getID()) + end end end -- 将函数的返回值映射到具体的返回值上 @@ -267,51 +305,40 @@ local function checkForward(source) end end for index, rtnObjs in ipairs(returns) do - local id = ('%s%s%s%s'):format( + local returnID = ('%s%s%s%s'):format( getID(source), SPLIT_CHAR, RETURN_INDEX_CHAR, index ) - local link = CreateLink(id) - link.forward = {} for _, rtnObj in ipairs(rtnObjs) do - link.forward[#link.forward+1] = getID(rtnObj) + pushForward(returnID, getID(rtnObj)) end end end end - if #list == 0 then - return nil - else - TempList = {} - return list - end end ---后退 ---@param source parser.guide.object ---@return parser.guide.object[] local function checkBackward(source) - if not source then - return - end - local list = TempList + local id = getID(source) local parent = source.parent if parent.value == source then - list[#list+1] = getID(parent) + pushBackward(id, getID(parent)) end -- name 映射回 class 与 type if source.type == 'doc.class.name' or source.type == 'doc.type.name' then - list[#list+1] = getID(parent) + pushBackward(id, getID(parent)) end -- class 与 type 绑定的 source if source.type == 'doc.class' or source.type == 'doc.type' then if source.bindSources then for _, src in ipairs(source.bindSources) do - list[#list+1] = getID(src) + pushBackward(id, getID(src)) end end -- 将 @return 映射到函数返回值上 @@ -319,7 +346,7 @@ local function checkBackward(source) for _, src in ipairs(parent.bindSources) do if src.type == 'function' then local fullID = ('%s%s%s%s'):format(getID(src), SPLIT_CHAR, RETURN_INDEX_CHAR, source.returnIndex) - list[#list+1] = fullID + pushBackward(id, fullID) end end end @@ -328,7 +355,7 @@ local function checkBackward(source) if parent.type == 'call' and parent.node == source then local sel = parent.parent if sel.type == 'select' then - list[#list+1] = ('s:%d'):format(sel.start) + pushBackward(id, ('s:%d'):format(sel.start)) end end -- 将调用参数映射到函数调用上 @@ -341,12 +368,12 @@ local function checkBackward(source) if not nodeID then break end - list[#list+1] = ('%s%s%s%s'):format( + pushBackward(id, ('%s%s%s%s'):format( nodeID, SPLIT_CHAR, PARAM_INDEX_CHAR, i - ) + )) break end end @@ -354,46 +381,28 @@ local function checkBackward(source) if source.type == 'doc.param' then print(source) end - if #list == 0 then - return nil - else - TempList = {} - return list - end -end - ----@param link link -local function insertLinker(link) - local id = link.id - if not Linkers[id] then - Linkers[id] = {} - end - Linkers[id][#Linkers[id]+1] = link end ---@class link -- 当前节点的id ---@field id string --- 语法树单元 ----@field source parser.guide.object --- 前进的关联单元 ----@field forward parser.guide.object[] --- 后退的关联单元 ----@field backward parser.guide.object[] +-- 使用该ID的单元 +---@field sources parser.guide.object[] +-- 前进的关联ID +---@field forward string[] +-- 后退的关联ID +---@field backward string[] ---创建source的链接信息 ---@param id string ----@param source? parser.guide.object ---@return link -function CreateLink(id, source) - local link = { - id = id, - source = source, - forward = checkForward(source), - backward = checkBackward(source), - } - insertLinker(link) - return link +function GetLink(id) + if not Linkers[id] then + Linkers[id] = { + id = id, + } + end + return Linkers[id] end local m = {} @@ -405,8 +414,8 @@ m.PARAM_INDEX_CHAR = PARAM_INDEX_CHAR ---根据ID来获取所有的link ---@param root parser.guide.object ---@param id string ----@return link[]? -function m.getLinksByID(root, id) +---@return link? +function m.getLinkByID(root, id) root = guide.getRoot(root) local linkers = root._linkers if not linkers then @@ -431,20 +440,6 @@ function m.getLastID(id) return lastID end ----获取source的链接信息 ----@param source parser.guide.object ----@return link -function m.getLink(source) - local id = getID(source) - if not id then - return nil - end - if source._link == nil then - source._link = CreateLink(id, source) or false - end - return source._link or nil -end - ---获取source的ID ---@param source parser.guide.object ---@return string @@ -481,7 +476,13 @@ function m.compileLinks(source) Linkers = {} root._linkers = Linkers guide.eachSource(root, function (src) - m.getLink(src) + local id = getID(src) + if not id then + return + end + pushSource(id, src) + checkForward(src) + checkBackward(src) end) return Linkers end diff --git a/script/core/searcher.lua b/script/core/searcher.lua index a22f9ab6..c784c3b8 100644 --- a/script/core/searcher.lua +++ b/script/core/searcher.lua @@ -33,6 +33,7 @@ function m.pushResult(status, mode, ref) return end local results = status.results + local parent = ref.parent if mode == 'def' then if ref.type == 'local' or ref.type == 'setlocal' @@ -48,6 +49,11 @@ function m.pushResult(status, mode, ref) or ref.type == 'doc.alias.name' then results[#results+1] = ref end + if parent.type == 'return' then + if linker.getID(ref) ~= status.id then + results[#results+1] = ref + end + end elseif mode == 'ref' then if ref.type == 'local' or ref.type == 'setlocal' @@ -146,6 +152,8 @@ function m.searchRefsByID(status, uri, expect, mode) local root = ast.ast linker.compileLinks(root) + status.id = expect + local mark = status.mark local queueIDs = {} local queueFields = {} @@ -180,12 +188,12 @@ function m.searchRefsByID(status, uri, expect, mode) end local function searchFunction(id) - local funcs = linker.getLinksByID(root, id) - if not funcs then + local link = linker.getLinkByID(root, id) + if not link or not link.sources then return end - local obj = funcs[1].source - if obj.type ~= 'function' then + local obj = link.sources[1] + if not obj or obj.type ~= 'function' then return end local returnIndex = checkFunctionReturn(obj) @@ -203,27 +211,6 @@ function m.searchRefsByID(status, uri, expect, mode) search(parentID, linker.SPLIT_CHAR .. linker.RETURN_INDEX_CHAR .. returnIndex) end - local function checkForward(link, field) - if not link.forward then - return - end - for _, id in ipairs(link.forward) do - searchID(id, field) - end - end - - local function checkBackward(link, field) - if not link.backward then - return - end - if mode == 'def' and not field then - return - end - for _, id in ipairs(link.backward) do - searchID(id, field) - end - end - search(expect) searchFunction(expect) @@ -235,14 +222,22 @@ function m.searchRefsByID(status, uri, expect, mode) local field = queueFields[index] index = index - 1 - local links = linker.getLinksByID(root, id) - if links then - for _, eachLink in ipairs(links) do - if field == nil then - m.pushResult(status, mode, eachLink.source) + local link = linker.getLinkByID(root, id) + if link then + if field == nil and link.sources then + for _, source in ipairs(link.sources) do + m.pushResult(status, mode, source) + end + end + if link.forward then + for _, forwardID in ipairs(link.forward) do + searchID(forwardID, field) + end + end + if link.backward and (mode == 'ref' or field) then + for _, backwardID in ipairs(link.backward) do + searchID(backwardID, field) end - checkForward(eachLink, field) - checkBackward(eachLink, field) end end checkLastID(id, field) |