summaryrefslogtreecommitdiff
path: root/script
diff options
context:
space:
mode:
Diffstat (limited to 'script')
-rw-r--r--script/core/linker.lua167
-rw-r--r--script/core/searcher.lua59
2 files changed, 111 insertions, 115 deletions
diff --git a/script/core/linker.lua b/script/core/linker.lua
index 41bdbaf9..54885c48 100644
--- a/script/core/linker.lua
+++ b/script/core/linker.lua
@@ -2,7 +2,7 @@ local util = require 'utility'
local guide = require 'parser.guide'
local vm = require 'vm.vm'
-local Linkers, CreateLink
+local Linkers, GetLink
local LastIDCache = {}
local SPLIT_CHAR = '\x1F'
local SPLIT_REGEX = SPLIT_CHAR .. '[^' .. SPLIT_CHAR .. ']+$'
@@ -185,20 +185,54 @@ local function getID(source)
return id
end
-local TempList = {}
+---添加关联单元
+---@param id string
+---@param source parser.guide.object
+local function pushSource(id, source)
+ local link = GetLink(id)
+ if not link.sources then
+ link.sources = {}
+ end
+ link.sources[#link.sources+1] = source
+end
+
+---添加关联的前进ID
+---@param id string
+---@param forwardID string
+local function pushForward(id, forwardID)
+ if not forwardID or forwardID == '' or id == forwardID then
+ return
+ end
+ local link = GetLink(id)
+ if not link.forward then
+ link.forward = {}
+ end
+ link.forward[#link.forward+1] = forwardID
+end
+
+---添加关联的后退ID
+---@param id string
+---@param backwardID string
+local function pushBackward(id, backwardID)
+ if not backwardID or backwardID == '' or id == backwardID then
+ return
+ end
+ local link = GetLink(id)
+ if not link.backward then
+ link.backward = {}
+ end
+ link.backward[#link.backward+1] = backwardID
+end
---前进
---@param source parser.guide.object
---@return parser.guide.object[]
local function checkForward(source)
- if not source then
- return
- end
- local list = TempList
+ local id = getID(source)
local parent = source.parent
if source.value then
-- x = y : x -> y
- list[#list+1] = getID(source.value)
+ pushForward(id, getID(source.value))
end
-- mt:f -> self
if parent.type == 'setmethod'
@@ -207,7 +241,7 @@ local function checkForward(source)
if func then
local self = func.locals[1]
if self.tag == 'self' then
- list[#list+1] = getID(self)
+ pushForward(id, getID(self))
end
end
end
@@ -216,7 +250,7 @@ local function checkForward(source)
local func = guide.getParentFunction(source)
local setmethod = func.parent
if setmethod and setmethod.type == 'setmethod' then
- list[#list+1] = getID(setmethod.node)
+ pushForward(id, getID(setmethod.node))
end
end
-- source 绑定的 @class/@type
@@ -225,20 +259,20 @@ local function checkForward(source)
for _, doc in ipairs(bindDocs) do
if doc.type == 'doc.class'
or doc.type == 'doc.type' then
- list[#list+1] = getID(doc)
+ pushForward(id, getID(doc))
end
end
end
-- 分解 @type
if source.type == 'doc.type' then
for _, typeUnit in ipairs(source.types) do
- list[#list+1] = getID(typeUnit)
+ pushForward(id, getID(typeUnit))
end
end
-- 分解 @class
if source.type == 'doc.class' then
- list[#list+1] = getID(source.class)
- list[#list+1] = getID(source.extends)
+ pushForward(id, getID(source.class))
+ pushForward(id, getID(source.extends))
end
-- 将call的返回值接收映射到函数返回值上
if source.type == 'select' then
@@ -251,7 +285,11 @@ local function checkForward(source)
RETURN_INDEX_CHAR,
source.index
)
- list[#list+1] = callID
+ pushForward(id, callID)
+ -- 将setmetatable映射到 param1 以及 param2.__index 上
+ if node.special == 'setmetatable' and source.index == 1 then
+ pushForward(id, getID())
+ end
end
end
-- 将函数的返回值映射到具体的返回值上
@@ -267,51 +305,40 @@ local function checkForward(source)
end
end
for index, rtnObjs in ipairs(returns) do
- local id = ('%s%s%s%s'):format(
+ local returnID = ('%s%s%s%s'):format(
getID(source),
SPLIT_CHAR,
RETURN_INDEX_CHAR,
index
)
- local link = CreateLink(id)
- link.forward = {}
for _, rtnObj in ipairs(rtnObjs) do
- link.forward[#link.forward+1] = getID(rtnObj)
+ pushForward(returnID, getID(rtnObj))
end
end
end
end
- if #list == 0 then
- return nil
- else
- TempList = {}
- return list
- end
end
---后退
---@param source parser.guide.object
---@return parser.guide.object[]
local function checkBackward(source)
- if not source then
- return
- end
- local list = TempList
+ local id = getID(source)
local parent = source.parent
if parent.value == source then
- list[#list+1] = getID(parent)
+ pushBackward(id, getID(parent))
end
-- name 映射回 class 与 type
if source.type == 'doc.class.name'
or source.type == 'doc.type.name' then
- list[#list+1] = getID(parent)
+ pushBackward(id, getID(parent))
end
-- class 与 type 绑定的 source
if source.type == 'doc.class'
or source.type == 'doc.type' then
if source.bindSources then
for _, src in ipairs(source.bindSources) do
- list[#list+1] = getID(src)
+ pushBackward(id, getID(src))
end
end
-- 将 @return 映射到函数返回值上
@@ -319,7 +346,7 @@ local function checkBackward(source)
for _, src in ipairs(parent.bindSources) do
if src.type == 'function' then
local fullID = ('%s%s%s%s'):format(getID(src), SPLIT_CHAR, RETURN_INDEX_CHAR, source.returnIndex)
- list[#list+1] = fullID
+ pushBackward(id, fullID)
end
end
end
@@ -328,7 +355,7 @@ local function checkBackward(source)
if parent.type == 'call' and parent.node == source then
local sel = parent.parent
if sel.type == 'select' then
- list[#list+1] = ('s:%d'):format(sel.start)
+ pushBackward(id, ('s:%d'):format(sel.start))
end
end
-- 将调用参数映射到函数调用上
@@ -341,12 +368,12 @@ local function checkBackward(source)
if not nodeID then
break
end
- list[#list+1] = ('%s%s%s%s'):format(
+ pushBackward(id, ('%s%s%s%s'):format(
nodeID,
SPLIT_CHAR,
PARAM_INDEX_CHAR,
i
- )
+ ))
break
end
end
@@ -354,46 +381,28 @@ local function checkBackward(source)
if source.type == 'doc.param' then
print(source)
end
- if #list == 0 then
- return nil
- else
- TempList = {}
- return list
- end
-end
-
----@param link link
-local function insertLinker(link)
- local id = link.id
- if not Linkers[id] then
- Linkers[id] = {}
- end
- Linkers[id][#Linkers[id]+1] = link
end
---@class link
-- 当前节点的id
---@field id string
--- 语法树单元
----@field source parser.guide.object
--- 前进的关联单元
----@field forward parser.guide.object[]
--- 后退的关联单元
----@field backward parser.guide.object[]
+-- 使用该ID的单元
+---@field sources parser.guide.object[]
+-- 前进的关联ID
+---@field forward string[]
+-- 后退的关联ID
+---@field backward string[]
---创建source的链接信息
---@param id string
----@param source? parser.guide.object
---@return link
-function CreateLink(id, source)
- local link = {
- id = id,
- source = source,
- forward = checkForward(source),
- backward = checkBackward(source),
- }
- insertLinker(link)
- return link
+function GetLink(id)
+ if not Linkers[id] then
+ Linkers[id] = {
+ id = id,
+ }
+ end
+ return Linkers[id]
end
local m = {}
@@ -405,8 +414,8 @@ m.PARAM_INDEX_CHAR = PARAM_INDEX_CHAR
---根据ID来获取所有的link
---@param root parser.guide.object
---@param id string
----@return link[]?
-function m.getLinksByID(root, id)
+---@return link?
+function m.getLinkByID(root, id)
root = guide.getRoot(root)
local linkers = root._linkers
if not linkers then
@@ -431,20 +440,6 @@ function m.getLastID(id)
return lastID
end
----获取source的链接信息
----@param source parser.guide.object
----@return link
-function m.getLink(source)
- local id = getID(source)
- if not id then
- return nil
- end
- if source._link == nil then
- source._link = CreateLink(id, source) or false
- end
- return source._link or nil
-end
-
---获取source的ID
---@param source parser.guide.object
---@return string
@@ -481,7 +476,13 @@ function m.compileLinks(source)
Linkers = {}
root._linkers = Linkers
guide.eachSource(root, function (src)
- m.getLink(src)
+ local id = getID(src)
+ if not id then
+ return
+ end
+ pushSource(id, src)
+ checkForward(src)
+ checkBackward(src)
end)
return Linkers
end
diff --git a/script/core/searcher.lua b/script/core/searcher.lua
index a22f9ab6..c784c3b8 100644
--- a/script/core/searcher.lua
+++ b/script/core/searcher.lua
@@ -33,6 +33,7 @@ function m.pushResult(status, mode, ref)
return
end
local results = status.results
+ local parent = ref.parent
if mode == 'def' then
if ref.type == 'local'
or ref.type == 'setlocal'
@@ -48,6 +49,11 @@ function m.pushResult(status, mode, ref)
or ref.type == 'doc.alias.name' then
results[#results+1] = ref
end
+ if parent.type == 'return' then
+ if linker.getID(ref) ~= status.id then
+ results[#results+1] = ref
+ end
+ end
elseif mode == 'ref' then
if ref.type == 'local'
or ref.type == 'setlocal'
@@ -146,6 +152,8 @@ function m.searchRefsByID(status, uri, expect, mode)
local root = ast.ast
linker.compileLinks(root)
+ status.id = expect
+
local mark = status.mark
local queueIDs = {}
local queueFields = {}
@@ -180,12 +188,12 @@ function m.searchRefsByID(status, uri, expect, mode)
end
local function searchFunction(id)
- local funcs = linker.getLinksByID(root, id)
- if not funcs then
+ local link = linker.getLinkByID(root, id)
+ if not link or not link.sources then
return
end
- local obj = funcs[1].source
- if obj.type ~= 'function' then
+ local obj = link.sources[1]
+ if not obj or obj.type ~= 'function' then
return
end
local returnIndex = checkFunctionReturn(obj)
@@ -203,27 +211,6 @@ function m.searchRefsByID(status, uri, expect, mode)
search(parentID, linker.SPLIT_CHAR .. linker.RETURN_INDEX_CHAR .. returnIndex)
end
- local function checkForward(link, field)
- if not link.forward then
- return
- end
- for _, id in ipairs(link.forward) do
- searchID(id, field)
- end
- end
-
- local function checkBackward(link, field)
- if not link.backward then
- return
- end
- if mode == 'def' and not field then
- return
- end
- for _, id in ipairs(link.backward) do
- searchID(id, field)
- end
- end
-
search(expect)
searchFunction(expect)
@@ -235,14 +222,22 @@ function m.searchRefsByID(status, uri, expect, mode)
local field = queueFields[index]
index = index - 1
- local links = linker.getLinksByID(root, id)
- if links then
- for _, eachLink in ipairs(links) do
- if field == nil then
- m.pushResult(status, mode, eachLink.source)
+ local link = linker.getLinkByID(root, id)
+ if link then
+ if field == nil and link.sources then
+ for _, source in ipairs(link.sources) do
+ m.pushResult(status, mode, source)
+ end
+ end
+ if link.forward then
+ for _, forwardID in ipairs(link.forward) do
+ searchID(forwardID, field)
+ end
+ end
+ if link.backward and (mode == 'ref' or field) then
+ for _, backwardID in ipairs(link.backward) do
+ searchID(backwardID, field)
end
- checkForward(eachLink, field)
- checkBackward(eachLink, field)
end
end
checkLastID(id, field)