diff options
author | 最萌小汐 <sumneko@hotmail.com> | 2021-12-02 17:40:15 +0800 |
---|---|---|
committer | 最萌小汐 <sumneko@hotmail.com> | 2021-12-02 17:40:15 +0800 |
commit | 9f236d0485ac4cee2b62261b84ef93f473596495 (patch) | |
tree | 8e0ffdafd335042c6dc9f6f6ba30ffd616b654f2 | |
parent | b48f0819008eb50320660f34f316f1b9bde4eb59 (diff) | |
download | lua-language-server-9f236d0485ac4cee2b62261b84ef93f473596495.zip |
#832
-rw-r--r-- | script/core/noder.lua | 76 | ||||
-rw-r--r-- | script/core/searcher.lua | 58 | ||||
-rw-r--r-- | test/definition/bug.lua | 57 |
3 files changed, 180 insertions, 11 deletions
diff --git a/script/core/noder.lua b/script/core/noder.lua index 8e2e0009..dc846cf3 100644 --- a/script/core/noder.lua +++ b/script/core/noder.lua @@ -94,6 +94,10 @@ local INFO_DEEP_AND_DONT_CROSS = { ---@field binfo? table<node.id, node.info> -- 后退的关联ID与info ---@field backwards table<node.id, node.id[]|table<node.id, node.info>> +-- 第一个继承 +---@field extend table<node.id, node.id> +-- 其他继承 +---@field extends table<node.id, node.id[]> -- 函数调用参数信息(用于泛型) ---@field call table<node.id, parser.guide.object> ---@field require table<node.id, string> @@ -548,8 +552,8 @@ end ---添加关联的前进ID ---@param noders noders ----@param id string ----@param forwardID string +---@param id node.id +---@param forwardID node.id ---@param info? node.info local function pushForward(noders, id, forwardID, info) if not id @@ -580,8 +584,8 @@ end ---添加关联的后退ID ---@param noders noders ----@param id string ----@param backwardID string +---@param id node.id +---@param backwardID node.id ---@param info? node.info local function pushBackward(noders, id, backwardID, info) if not id @@ -610,6 +614,36 @@ local function pushBackward(noders, id, backwardID, info) backwards[#backwards+1] = backwardID end +---添加继承的关联ID +---@param noders noders +---@param id node.id +---@param extendID node.id +local function pushExtend(noders, id, extendID) + if not id + or not extendID + or extendID == '' + or id == extendID then + return + end + if not noders.extend[id] then + noders.extend[id] = extendID + return + end + if noders.extend[id] == extendID then + return + end + local extends = noders.extends[id] + if not extends then + extends = {} + noders.extends[id] = extends + end + if extends[extendID] ~= nil then + return + end + extends[extendID] = false + extends[#extends+1] = extendID +end + ---@class noder local m = {} @@ -750,6 +784,31 @@ function m.eachBackward(noders, id) end end +---遍历extend +---@param noders noders +---@param id node.id +---@return fun():string, node.info +function m.eachExtend(noders, id) + local extend = noders.extend[id] + if not extend then + return DUMMY_FUNCTION + end + local index + local extends = noders.extends[id] + return function () + if not index then + index = 0 + return extend + end + if not extends then + return nil + end + index = index + 1 + local id = extends[index] + return id + end +end + local function bindValue(noders, source, id) local value = source.value if not value then @@ -1049,7 +1108,7 @@ compileNodeMap = util.switch() pushForward(noders, getID(source.class), id) if source.extends then for _, ext in ipairs(source.extends) do - pushForward(noders, id, getID(ext), INFO_CLASS_TO_EXNTENDS) + pushExtend(noders, id, getID(ext)) end end if source.bindSources then @@ -1517,6 +1576,11 @@ function m.getLastID(id) return lastID end +function m.getFieldID(id) + local fieldID = smatch(id, LAST_REGEX) + return fieldID +end + ---获取ID的长度 ---@param id string ---@return integer @@ -1641,6 +1705,8 @@ function m.getNoders(source) backward = {}, binfo = {}, backwards = {}, + extend = {}, + extends = {}, call = {}, require = {}, skip = {}, diff --git a/script/core/searcher.lua b/script/core/searcher.lua index e5de6395..9c0f0faa 100644 --- a/script/core/searcher.lua +++ b/script/core/searcher.lua @@ -40,6 +40,7 @@ local getHeadID = noder.getHeadID local eachForward = noder.eachForward local getUriAndID = noder.getUriAndID local eachBackward = noder.eachBackward +local eachExtend = noder.eachExtend local eachSource = noder.eachSource local compileAllNodes = noder.compileAllNodes local compilePartNoders = noder.compilePartNodes @@ -192,17 +193,17 @@ local pushRefResultsMap = util.switch() ---@param force boolean local function pushResult(status, mode, source, force) if not source then - return + return false end local results = status.results local mark = status.rmark if mark[source] then - return + return true end mark[source] = true if force then results[#results+1] = source - return + return true end if mode == 'def' @@ -210,7 +211,7 @@ local function pushResult(status, mode, source, force) local f = pushDefResultsMap[source.type] if f and f(source, status) then results[#results+1] = source - return + return true end elseif mode == 'ref' or mode == 'field' @@ -219,7 +220,7 @@ local function pushResult(status, mode, source, force) local f = pushRefResultsMap[source.type] if f and f(source, status) then results[#results+1] = source - return + return true end end @@ -227,9 +228,11 @@ local function pushResult(status, mode, source, force) if parent.type == 'return' then if source ~= status.source then results[#results+1] = source - return + return true end end + + return false end ---@param obj parser.guide.object @@ -737,6 +740,46 @@ function m.searchRefsByID(status, suri, expect, mode) end end + local function checkExtend(uri, id, field) + if not field + and mode ~= 'field' + and mode ~= 'allfield' then + return + end + if field then + local results = status.results + for i = #results, 1, -1 do + local res = results[i] + if res.type == 'setfield' + or res.type == 'setmethod' + or res.type == 'setindex' then + local resField = noder.getFieldID(getID(res)) + if field == resField then + return + end + end + if res.type == 'doc.field.name' then + local resField = STRING_FIELD .. res[1] + if field == resField then + return + end + end + end + end + for extendID in eachExtend(nodersMap[uri], id) do + local targetUri, targetID + + targetUri, targetID = getUriAndID(extendID) + if targetUri and targetUri ~= uri then + if dontCross == 0 then + searchID(targetUri, targetID, field, uri) + end + else + searchID(uri, targetID or extendID, field) + end + end + end + local function searchSpecial(uri, id, field) -- Special rule: ('').XX -> stringlib.XX if id == 'str:' @@ -892,6 +935,9 @@ function m.searchRefsByID(status, suri, expect, mode) if noders.backward[id] then checkBackward(uri, id, field) end + if noders.extend[id] then + checkExtend(uri, id, field) + end releaseExpanding(elock, ecall, id, field) end diff --git a/test/definition/bug.lua b/test/definition/bug.lua index 7a2cc789..9ea47fd1 100644 --- a/test/definition/bug.lua +++ b/test/definition/bug.lua @@ -253,3 +253,60 @@ local <!v!> = t[a] t[a] = <?v?> ]] + +TEST [[ +---@class A +---@field x number + +---@class B: A +---@field <!x!> boolean + +---@type B +local t + +local <!<?v?>!> = t.x +]] + +TEST [[ +---@class A +---@field <!x!> number + +---@class B: A + +---@type B +local t + +local <!<?v?>!> = t.x +]] + +TEST [[ +---@class A +local A + +function A:x() end + +---@class B: A +local B + +function B:<!x!>() end + +---@type B +local t + +local <!<?v?>!> = t.x +]] + +TEST [[ +---@class A +local A + +function A:<!x!>() end + +---@class B: A +local B + +---@type B +local t + +local <!<?v?>!> = t.x +]] |