summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--script/core/noder.lua76
-rw-r--r--script/core/searcher.lua58
-rw-r--r--test/definition/bug.lua57
3 files changed, 180 insertions, 11 deletions
diff --git a/script/core/noder.lua b/script/core/noder.lua
index 8e2e0009..dc846cf3 100644
--- a/script/core/noder.lua
+++ b/script/core/noder.lua
@@ -94,6 +94,10 @@ local INFO_DEEP_AND_DONT_CROSS = {
---@field binfo? table<node.id, node.info>
-- 后退的关联ID与info
---@field backwards table<node.id, node.id[]|table<node.id, node.info>>
+-- 第一个继承
+---@field extend table<node.id, node.id>
+-- 其他继承
+---@field extends table<node.id, node.id[]>
-- 函数调用参数信息(用于泛型)
---@field call table<node.id, parser.guide.object>
---@field require table<node.id, string>
@@ -548,8 +552,8 @@ end
---添加关联的前进ID
---@param noders noders
----@param id string
----@param forwardID string
+---@param id node.id
+---@param forwardID node.id
---@param info? node.info
local function pushForward(noders, id, forwardID, info)
if not id
@@ -580,8 +584,8 @@ end
---添加关联的后退ID
---@param noders noders
----@param id string
----@param backwardID string
+---@param id node.id
+---@param backwardID node.id
---@param info? node.info
local function pushBackward(noders, id, backwardID, info)
if not id
@@ -610,6 +614,36 @@ local function pushBackward(noders, id, backwardID, info)
backwards[#backwards+1] = backwardID
end
+---添加继承的关联ID
+---@param noders noders
+---@param id node.id
+---@param extendID node.id
+local function pushExtend(noders, id, extendID)
+ if not id
+ or not extendID
+ or extendID == ''
+ or id == extendID then
+ return
+ end
+ if not noders.extend[id] then
+ noders.extend[id] = extendID
+ return
+ end
+ if noders.extend[id] == extendID then
+ return
+ end
+ local extends = noders.extends[id]
+ if not extends then
+ extends = {}
+ noders.extends[id] = extends
+ end
+ if extends[extendID] ~= nil then
+ return
+ end
+ extends[extendID] = false
+ extends[#extends+1] = extendID
+end
+
---@class noder
local m = {}
@@ -750,6 +784,31 @@ function m.eachBackward(noders, id)
end
end
+---遍历extend
+---@param noders noders
+---@param id node.id
+---@return fun():string, node.info
+function m.eachExtend(noders, id)
+ local extend = noders.extend[id]
+ if not extend then
+ return DUMMY_FUNCTION
+ end
+ local index
+ local extends = noders.extends[id]
+ return function ()
+ if not index then
+ index = 0
+ return extend
+ end
+ if not extends then
+ return nil
+ end
+ index = index + 1
+ local id = extends[index]
+ return id
+ end
+end
+
local function bindValue(noders, source, id)
local value = source.value
if not value then
@@ -1049,7 +1108,7 @@ compileNodeMap = util.switch()
pushForward(noders, getID(source.class), id)
if source.extends then
for _, ext in ipairs(source.extends) do
- pushForward(noders, id, getID(ext), INFO_CLASS_TO_EXNTENDS)
+ pushExtend(noders, id, getID(ext))
end
end
if source.bindSources then
@@ -1517,6 +1576,11 @@ function m.getLastID(id)
return lastID
end
+function m.getFieldID(id)
+ local fieldID = smatch(id, LAST_REGEX)
+ return fieldID
+end
+
---获取ID的长度
---@param id string
---@return integer
@@ -1641,6 +1705,8 @@ function m.getNoders(source)
backward = {},
binfo = {},
backwards = {},
+ extend = {},
+ extends = {},
call = {},
require = {},
skip = {},
diff --git a/script/core/searcher.lua b/script/core/searcher.lua
index e5de6395..9c0f0faa 100644
--- a/script/core/searcher.lua
+++ b/script/core/searcher.lua
@@ -40,6 +40,7 @@ local getHeadID = noder.getHeadID
local eachForward = noder.eachForward
local getUriAndID = noder.getUriAndID
local eachBackward = noder.eachBackward
+local eachExtend = noder.eachExtend
local eachSource = noder.eachSource
local compileAllNodes = noder.compileAllNodes
local compilePartNoders = noder.compilePartNodes
@@ -192,17 +193,17 @@ local pushRefResultsMap = util.switch()
---@param force boolean
local function pushResult(status, mode, source, force)
if not source then
- return
+ return false
end
local results = status.results
local mark = status.rmark
if mark[source] then
- return
+ return true
end
mark[source] = true
if force then
results[#results+1] = source
- return
+ return true
end
if mode == 'def'
@@ -210,7 +211,7 @@ local function pushResult(status, mode, source, force)
local f = pushDefResultsMap[source.type]
if f and f(source, status) then
results[#results+1] = source
- return
+ return true
end
elseif mode == 'ref'
or mode == 'field'
@@ -219,7 +220,7 @@ local function pushResult(status, mode, source, force)
local f = pushRefResultsMap[source.type]
if f and f(source, status) then
results[#results+1] = source
- return
+ return true
end
end
@@ -227,9 +228,11 @@ local function pushResult(status, mode, source, force)
if parent.type == 'return' then
if source ~= status.source then
results[#results+1] = source
- return
+ return true
end
end
+
+ return false
end
---@param obj parser.guide.object
@@ -737,6 +740,46 @@ function m.searchRefsByID(status, suri, expect, mode)
end
end
+ local function checkExtend(uri, id, field)
+ if not field
+ and mode ~= 'field'
+ and mode ~= 'allfield' then
+ return
+ end
+ if field then
+ local results = status.results
+ for i = #results, 1, -1 do
+ local res = results[i]
+ if res.type == 'setfield'
+ or res.type == 'setmethod'
+ or res.type == 'setindex' then
+ local resField = noder.getFieldID(getID(res))
+ if field == resField then
+ return
+ end
+ end
+ if res.type == 'doc.field.name' then
+ local resField = STRING_FIELD .. res[1]
+ if field == resField then
+ return
+ end
+ end
+ end
+ end
+ for extendID in eachExtend(nodersMap[uri], id) do
+ local targetUri, targetID
+
+ targetUri, targetID = getUriAndID(extendID)
+ if targetUri and targetUri ~= uri then
+ if dontCross == 0 then
+ searchID(targetUri, targetID, field, uri)
+ end
+ else
+ searchID(uri, targetID or extendID, field)
+ end
+ end
+ end
+
local function searchSpecial(uri, id, field)
-- Special rule: ('').XX -> stringlib.XX
if id == 'str:'
@@ -892,6 +935,9 @@ function m.searchRefsByID(status, suri, expect, mode)
if noders.backward[id] then
checkBackward(uri, id, field)
end
+ if noders.extend[id] then
+ checkExtend(uri, id, field)
+ end
releaseExpanding(elock, ecall, id, field)
end
diff --git a/test/definition/bug.lua b/test/definition/bug.lua
index 7a2cc789..9ea47fd1 100644
--- a/test/definition/bug.lua
+++ b/test/definition/bug.lua
@@ -253,3 +253,60 @@ local <!v!> = t[a]
t[a] = <?v?>
]]
+
+TEST [[
+---@class A
+---@field x number
+
+---@class B: A
+---@field <!x!> boolean
+
+---@type B
+local t
+
+local <!<?v?>!> = t.x
+]]
+
+TEST [[
+---@class A
+---@field <!x!> number
+
+---@class B: A
+
+---@type B
+local t
+
+local <!<?v?>!> = t.x
+]]
+
+TEST [[
+---@class A
+local A
+
+function A:x() end
+
+---@class B: A
+local B
+
+function B:<!x!>() end
+
+---@type B
+local t
+
+local <!<?v?>!> = t.x
+]]
+
+TEST [[
+---@class A
+local A
+
+function A:<!x!>() end
+
+---@class B: A
+local B
+
+---@type B
+local t
+
+local <!<?v?>!> = t.x
+]]