summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author最萌小汐 <sumneko@hotmail.com>2021-04-26 17:53:13 +0800
committer最萌小汐 <sumneko@hotmail.com>2021-04-26 17:53:13 +0800
commitb8ec5ae4fa6eed01bd093f39cd76f01d315f330a (patch)
treecc071308420dcec3a8f7a8d5494cd9797caa143c
parentda4bc8f26bad8997cf82b8e3752eb6c197eb08fd (diff)
downloadlua-language-server-b8ec5ae4fa6eed01bd093f39cd76f01d315f330a.zip
generic forward
-rw-r--r--script/core/linker.lua63
-rw-r--r--script/core/searcher.lua65
-rw-r--r--test/definition/luadoc.lua17
3 files changed, 91 insertions, 54 deletions
diff --git a/script/core/linker.lua b/script/core/linker.lua
index 9741d7b4..97ff4927 100644
--- a/script/core/linker.lua
+++ b/script/core/linker.lua
@@ -1,13 +1,25 @@
local util = require 'utility'
local guide = require 'parser.guide'
-local Linkers, GetLink
+local Linkers
local LastIDCache = {}
local SPLIT_CHAR = '\x1F'
local SPLIT_REGEX = SPLIT_CHAR .. '[^' .. SPLIT_CHAR .. ']+$'
local RETURN_INDEX_CHAR = '#'
local PARAM_INDEX_CHAR = '@'
+---创建source的链接信息
+---@param id string
+---@return link
+local function getLink(id)
+ if not Linkers[id] then
+ Linkers[id] = {
+ id = id,
+ }
+ end
+ return Linkers[id]
+end
+
---是否是全局变量(包括 _G.XXX 形式)
---@param source parser.guide.object
---@return boolean
@@ -107,7 +119,11 @@ local function getKey(source)
or source.type == 'doc.alias.name'
or source.type == 'doc.extends.name'
or source.type == 'doc.see.name' then
- return source[1], nil
+ local name = source[1]
+ if source.typeGeneric then
+ return source.start, nil
+ end
+ return name, nil
elseif source.type == 'doc.class'
or source.type == 'doc.type'
or source.type == 'doc.alias'
@@ -226,7 +242,7 @@ end
---@param id string
---@param source parser.guide.object
local function pushSource(id, source)
- local link = GetLink(id)
+ local link = getLink(id)
if not link.sources then
link.sources = {}
end
@@ -243,7 +259,7 @@ local function pushForward(id, forwardID)
or id == forwardID then
return
end
- local link = GetLink(id)
+ local link = getLink(id)
if not link.forward then
link.forward = {}
end
@@ -260,7 +276,7 @@ local function pushBackward(id, backwardID)
or id == backwardID then
return
end
- local link = GetLink(id)
+ local link = getLink(id)
if not link.backward then
link.backward = {}
end
@@ -396,6 +412,7 @@ local function compileLink(source)
pushForward(id, callID)
pushBackward(callID, id)
end
+ getLink(selectID).callinfo = source
end)
-- 将setmetatable映射到 param1 以及 param2.__index 上
if node.special == 'setmetatable' then
@@ -463,8 +480,8 @@ local function compileLink(source)
RETURN_INDEX_CHAR,
rtn.returnIndex
)
- pushForward(getID(rtn), fullID)
- pushBackward(fullID, getID(rtn))
+ pushForward(fullID, getID(rtn))
+ pushBackward(getID(rtn), fullID)
end
end
if doc.type == 'doc.param' then
@@ -496,18 +513,8 @@ end
---@field forward string[]
-- 后退的关联ID
---@field backward string[]
-
----创建source的链接信息
----@param id string
----@return link
-function GetLink(id)
- if not Linkers[id] then
- Linkers[id] = {
- id = id,
- }
- end
- return Linkers[id]
-end
+-- 函数调用参数信息(用于泛型)
+---@field callinfo parser.guide.object
local m = {}
@@ -551,24 +558,6 @@ function m.getID(source)
return getID(source)
end
----获取source的special
----@param source parser.guide.object
----@return table
-function m.getSpecial(source, key)
- if not source then
- return nil
- end
- local link = m.getLink(source)
- if not link then
- return nil
- end
- local special = link.special
- if not special then
- return nil
- end
- return special[key]
-end
-
---编译整个文件的link
---@param source parser.guide.object
---@return table
diff --git a/script/core/searcher.lua b/script/core/searcher.lua
index 46d58940..75437497 100644
--- a/script/core/searcher.lua
+++ b/script/core/searcher.lua
@@ -2,7 +2,7 @@ local linker = require 'core.linker'
local guide = require 'parser.guide'
local files = require 'files'
-local UNI_CHAR = '~'
+local MARK_CHAR = '\x1E'
local function checkFunctionReturn(source)
if source.parent
@@ -170,14 +170,15 @@ function m.searchRefsByID(status, uri, expect, mode)
status.id = expect
local mark = status.mark
- local queueIDs = {}
- local queueFields = {}
- local index = 0
+ local queueIDs = {}
+ local queueFields = {}
+ local queueCallInfos = {}
+ local index = 0
- local function search(id, field)
+ local function search(id, field, callinfo)
local fullID
if field then
- fullID = id .. '\x1E' .. field
+ fullID = id .. MARK_CHAR .. field
local _, len = field:gsub(linker.SPLIT_CHAR, '')
if len >= 10 then
return
@@ -190,26 +191,30 @@ function m.searchRefsByID(status, uri, expect, mode)
end
mark[fullID] = true
index = index + 1
- queueIDs[index] = id
- queueFields[index] = field
+ queueIDs[index] = id
+ queueFields[index] = field
+ queueCallInfos[index] = callinfo
end
- local function checkLastID(id, field)
+ local function checkLastID(id, field, callinfo)
local lastID = linker.getLastID(id)
if lastID then
local newField = id:sub(#lastID + 1)
if field then
newField = newField .. field
end
- search(lastID, newField)
+ search(lastID, newField, callinfo)
end
end
- local function searchID(id, field)
+ local function searchID(id, field, callinfo)
+ if not id then
+ return
+ end
if field then
id = id .. field
end
- search(id)
+ search(id, nil, callinfo)
end
local function searchFunction(id)
@@ -236,6 +241,30 @@ function m.searchRefsByID(status, uri, expect, mode)
search(parentID, linker.SPLIT_CHAR .. linker.RETURN_INDEX_CHAR .. returnIndex)
end
+ local function checkGeneric(link, field, callinfo)
+ if not link.sources then
+ return
+ end
+ if not callinfo or not callinfo.args then
+ return
+ end
+ local source = link.sources[1]
+ if source.typeGeneric then
+ local key = source[1]
+ local generics = source.typeGeneric[key]
+ if generics then
+ for _, docName in ipairs(generics) do
+ local docType = docName.parent
+ -- @param T
+ if docType.paramIndex then
+ local paramID = linker.getID(callinfo.args[docType.paramIndex])
+ searchID(paramID, field)
+ end
+ end
+ end
+ end
+ end
+
search(expect)
searchFunction(expect)
@@ -243,8 +272,9 @@ function m.searchRefsByID(status, uri, expect, mode)
if index <= 0 then
return
end
- local id = queueIDs[index]
- local field = queueFields[index]
+ local id = queueIDs[index]
+ local field = queueFields[index]
+ local callinfo = queueCallInfos[index]
index = index - 1
local link = linker.getLinkByID(root, id)
@@ -256,16 +286,17 @@ function m.searchRefsByID(status, uri, expect, mode)
end
if link.forward then
for _, forwardID in ipairs(link.forward) do
- searchID(forwardID, field)
+ searchID(forwardID, field, link.callinfo or callinfo)
end
end
if link.backward and (mode == 'ref' or field) then
for _, backwardID in ipairs(link.backward) do
- searchID(backwardID, field)
+ searchID(backwardID, field, link.callinfo or callinfo)
end
end
+ checkGeneric(link, field, callinfo)
end
- checkLastID(id, field)
+ checkLastID(id, field, callinfo)
end
error('too large')
end
diff --git a/test/definition/luadoc.lua b/test/definition/luadoc.lua
index e04c62df..438425f3 100644
--- a/test/definition/luadoc.lua
+++ b/test/definition/luadoc.lua
@@ -209,6 +209,23 @@ TEST [[
]]
TEST [[
+---@return <!fun()!>
+local function f() end
+
+local <?<!r!>?> = f()
+]]
+
+TEST [[
+---@generic T
+---@param p T
+---@return T
+local function f(p) end
+
+local <!k!>
+local <?<!r!>?> = f(k)
+]]
+
+TEST [[
---@class Foo
local Foo = {}
function Foo:<!bar1!>() end