summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--script/core/definition.lua3
-rw-r--r--script/core/linker.lua15
-rw-r--r--script/core/searcher.lua105
-rw-r--r--script/parser/ast.lua2
-rw-r--r--script/parser/luadoc.lua1
-rw-r--r--test/basic/linker.txt31
-rw-r--r--test/definition/luadoc.lua10
7 files changed, 112 insertions, 55 deletions
diff --git a/script/core/definition.lua b/script/core/definition.lua
index 84c0d4f4..5d996a88 100644
--- a/script/core/definition.lua
+++ b/script/core/definition.lua
@@ -149,9 +149,6 @@ return function (uri, offset)
goto CONTINUE
end
src = src.field or src.method or src.index or src
- if src.type == 'table' and src.parent.type ~= 'return' then
- goto CONTINUE
- end
if src.type == 'doc.class.name'
and source.type ~= 'doc.type.name'
and source.type ~= 'doc.extends.name'
diff --git a/script/core/linker.lua b/script/core/linker.lua
index faf051ea..76d8bf35 100644
--- a/script/core/linker.lua
+++ b/script/core/linker.lua
@@ -287,6 +287,10 @@ local function pushBackward(id, backwardID)
link.backward[#link.backward+1] = backwardID
end
+local function findDocState()
+
+end
+
---前进
---@param source parser.guide.object
---@return parser.guide.object[]
@@ -438,7 +442,7 @@ local function compileLink(source)
)
pushForward(id, callID)
pushBackward(callID, id)
- getLink(id).callinfo = source.vararg
+ getLink(id).call = source.vararg
end
end
if source.type == 'doc.type.function' then
@@ -509,9 +513,12 @@ local function compileLink(source)
end
if doc.type == 'doc.param' then
local paramName = doc.param[1]
- for _, param in ipairs(source.args) do
- if param[1] == paramName then
+ if source.docParamMap then
+ local paramIndex = source.docParamMap[paramName]
+ local param = source.args[paramIndex]
+ if param then
pushForward(getID(param), getID(doc))
+ param.docParam = doc
end
end
end
@@ -537,7 +544,7 @@ end
-- 后退的关联ID
---@field backward string[]
-- 函数调用参数信息(用于泛型)
----@field callinfo parser.guide.object
+---@field call parser.guide.object
local m = {}
diff --git a/script/core/searcher.lua b/script/core/searcher.lua
index 30d2e33a..7f17088b 100644
--- a/script/core/searcher.lua
+++ b/script/core/searcher.lua
@@ -45,11 +45,13 @@ function m.pushResult(status, mode, source)
or source.type == 'tableindex'
or source.type == 'tablefield'
or source.type == 'function'
+ or source.type == 'table'
or source.type == 'doc.class.name'
or source.type == 'doc.alias.name'
or source.type == 'doc.field.name'
or source.type == 'doc.type.function' then
results[#results+1] = source
+ return
end
if source.type == 'call' then
if source.node.special == 'rawset' then
@@ -78,6 +80,7 @@ function m.pushResult(status, mode, source)
or source.type == 'tableindex'
or source.type == 'tablefield'
or source.type == 'function'
+ or source.type == 'table'
or source.type == 'doc.class.name'
or source.type == 'doc.type.name'
or source.type == 'doc.alias.name'
@@ -85,6 +88,7 @@ function m.pushResult(status, mode, source)
or source.type == 'doc.field.name'
or source.type == 'doc.type.function' then
results[#results+1] = source
+ return
end
if source.type == 'call' then
if source.node.special == 'rawset'
@@ -170,12 +174,15 @@ function m.searchRefsByID(status, uri, expect, mode)
status.id = expect
local mark = status.mark
- local queueIDs = {}
- local queueFields = {}
- local queueCallInfos = {}
- local index = 0
+ local queueIDs = {}
+ local queueFields = {}
+ local queueCalls = {}
+ local queueIndex = 0
- local function search(id, field, callinfo)
+ -- 缓存过程中的泛型,以泛型关联表为key
+ local genericStashMap = {}
+
+ local function search(id, field, call)
local fieldLen
if field then
local _, len = field:gsub(linker.SPLIT_CHAR, '')
@@ -187,10 +194,10 @@ function m.searchRefsByID(status, uri, expect, mode)
return
end
mark[id] = fieldLen
- index = index + 1
- queueIDs[index] = id
- queueFields[index] = field
- queueCallInfos[index] = callinfo
+ queueIndex = queueIndex + 1
+ queueIDs[queueIndex] = id
+ queueFields[queueIndex] = field
+ queueCalls[queueIndex] = call
end
local function checkLastID(id, field, callinfo)
@@ -238,51 +245,61 @@ function m.searchRefsByID(status, uri, expect, mode)
search(parentID, linker.SPLIT_CHAR .. linker.RETURN_INDEX_CHAR .. returnIndex)
end
- local function checkGeneric(link, field, callinfo)
- if not link.sources then
+ local function genericStash(source, call)
+ if not call or not call.args then
return
end
- if not callinfo or not callinfo.args then
- return
- end
- local source = link.sources[1]
- if source.typeGeneric then
- local key = source[1]
- local generics = source.typeGeneric[key]
- if generics then
- for _, docName in ipairs(generics) do
- -- @param T
- local docType = docName.parent
- local param = callinfo.args[docType.paramIndex]
- if param then
- if docName.literal then
- -- @param `T`
- if param.type == 'string' and param[1] then
- local paramID = 'dn:' .. param[1]
- searchID(paramID, field)
+ if source.type == 'function' then
+ if not source.docParamMap then
+ return
+ end
+ for index, param in ipairs(source.args) do
+ local docParam = param.docParam
+ if docParam then
+ for _, typeUnit in ipairs(docParam.extends.types) do
+ if typeUnit.typeGeneric then
+ local key = typeUnit[1]
+ local generics = typeUnit.typeGeneric[key]
+ local callParam = call.args[index]
+ if callParam then
+ if typeUnit.literal then
+ if callParam.type == 'string' then
+ genericStashMap[generics] = ('dn:%s'):format(callParam[1] or '')
+ end
+ else
+ genericStashMap[generics] = linker.getID(callParam)
+ end
end
- else
- local paramID = linker.getID(param)
- searchID(paramID, field)
end
- return
end
end
end
end
end
+ local function genericResolve(source, field)
+ if not source.typeGeneric then
+ return
+ end
+ local key = source[1]
+ local generics = source.typeGeneric[key]
+ local paramID = genericStashMap[generics]
+ if paramID then
+ searchID(paramID, field)
+ end
+ end
+
search(expect)
searchFunction(expect)
for _ = 1, 1000 do
- if index <= 0 then
+ if queueIndex <= 0 then
return
end
- local id = queueIDs[index]
- local field = queueFields[index]
- local callinfo = queueCallInfos[index]
- index = index - 1
+ local id = queueIDs[queueIndex]
+ local field = queueFields[queueIndex]
+ local call = queueCalls[queueIndex]
+ queueIndex = queueIndex - 1
local link = linker.getLinkByID(root, id)
if link then
@@ -293,17 +310,21 @@ function m.searchRefsByID(status, uri, expect, mode)
end
if link.forward then
for _, forwardID in ipairs(link.forward) do
- searchID(forwardID, field, link.callinfo or callinfo)
+ searchID(forwardID, field, link.call or call)
end
end
if link.backward and (mode == 'ref' or field) then
for _, backwardID in ipairs(link.backward) do
- searchID(backwardID, field, link.callinfo or callinfo)
+ searchID(backwardID, field, link.call or call)
end
end
- checkGeneric(link, field, callinfo)
+
+ if link.sources then
+ genericStash(link.sources[1], call)
+ genericResolve(link.sources[1], field)
+ end
end
- checkLastID(id, field, callinfo)
+ checkLastID(id, field, call)
end
error('too large')
end
diff --git a/script/parser/ast.lua b/script/parser/ast.lua
index 45801cf6..0a188da4 100644
--- a/script/parser/ast.lua
+++ b/script/parser/ast.lua
@@ -1454,7 +1454,7 @@ local Defs = {
if func then
local call = createCall(exp, func.finish + 1, exp.finish)
call.node = func
- call.start = func.start
+ call.start = inA
func.next = call
func.iterator = true
values = { call }
diff --git a/script/parser/luadoc.lua b/script/parser/luadoc.lua
index 31a6af48..af5071b3 100644
--- a/script/parser/luadoc.lua
+++ b/script/parser/luadoc.lua
@@ -1145,6 +1145,7 @@ local function bindParamAndReturnIndex(binded)
paramMap[param[1]] = paramIndex
end
end
+ func.docParamMap = paramMap
end
local returnIndex = 0
for _, doc in ipairs(binded) do
diff --git a/test/basic/linker.txt b/test/basic/linker.txt
index 5d3f54e8..ea3ba180 100644
--- a/test/basic/linker.txt
+++ b/test/basic/linker.txt
@@ -113,8 +113,29 @@ local <?r?> = f(<!k!>)
```
l:r
-s:1#1
-l:f#1
-dg:T
-l:f@1
-k
+s:1#1 call
+l:f#1 call
+f:1#1 call -> f:1&T = l:k
+l:f@1 --> 从保存的call信息里找到 f:1&T = l:k
+l:k
+
+
+
+```
+---@generic T, V
+---@param p T
+---@return fun(V):T, V
+local function f(p) end
+
+local f2 = f(<!k!>)
+local <?r?> = f2()
+```
+
+l:r
+s:2|#1 call1
+l:f2|#1 call1
+f:2|#1 call1
+s:1#1|#1 call2
+f:1#1|#1 call2 -> f:1&T = l:k
+dfun:1|#1
+dn:V -> f:1&T = l:k
diff --git a/test/definition/luadoc.lua b/test/definition/luadoc.lua
index 61f203f5..a971dfe1 100644
--- a/test/definition/luadoc.lua
+++ b/test/definition/luadoc.lua
@@ -328,6 +328,16 @@ local <?<!f2!>?> = f()
]]
TEST [[
+---@generic T
+---@param x T
+---@return fun():T
+local function f(x) end
+
+local v1 = f(<!{}!>)
+local <?<!v2!>?> = v1()
+]]
+
+TEST [[
---@class Foo
local Foo = {}
function Foo:<!bar1!>() end