summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--script/core/definition.lua3
-rw-r--r--script/core/linker.lua68
-rw-r--r--script/core/searcher.lua73
-rw-r--r--test/basic/linker.txt24
-rw-r--r--test/references/init.lua15
5 files changed, 141 insertions, 42 deletions
diff --git a/script/core/definition.lua b/script/core/definition.lua
index eff55771..84c0d4f4 100644
--- a/script/core/definition.lua
+++ b/script/core/definition.lua
@@ -3,6 +3,7 @@ local workspace = require 'workspace'
local files = require 'files'
local vm = require 'vm'
local findSource = require 'core.find-source'
+local guide = require 'parser.guide'
local function sortResults(results)
-- 先按照顺序排序
@@ -143,7 +144,7 @@ return function (uri, offset)
if values[src] then
goto CONTINUE
end
- local root = searcher.getRoot(src)
+ local root = guide.getRoot(src)
if not root then
goto CONTINUE
end
diff --git a/script/core/linker.lua b/script/core/linker.lua
index ef6dd010..e31d4eb9 100644
--- a/script/core/linker.lua
+++ b/script/core/linker.lua
@@ -6,7 +6,8 @@ local Linkers
local LastIDCache = {}
local SPLIT_CHAR = '\x1F'
local SPLIT_REGEX = SPLIT_CHAR .. '[^' .. SPLIT_CHAR .. ']+$'
-local INDEX_CHAR = '\x1E'
+local RETURN_INDEX_CHAR = '#'
+local PARAM_INDEX_CHAR = '@'
---是否是全局变量(包括 _G.XXX 形式)
---@param source parser.guide.object
@@ -85,14 +86,15 @@ local function getKey(source)
elseif source.type == 'function' then
return source.start, nil
elseif source.type == 'select' then
- return ('%d%s%s%d'):format(source.start, SPLIT_CHAR, INDEX_CHAR, source.index)
+ return ('%d%s%s%d'):format(source.start, SPLIT_CHAR, RETURN_INDEX_CHAR, source.index)
elseif source.type == 'doc.class.name'
or source.type == 'doc.type.name'
or source.type == 'doc.alias.name' then
return source[1], nil
elseif source.type == 'doc.class'
or source.type == 'doc.type'
- or source.type == 'doc.alias' then
+ or source.type == 'doc.alias'
+ or source.type == 'doc.param' then
return source.start, nil
end
return nil, nil
@@ -120,6 +122,9 @@ local function checkMode(source)
if source.type == 'doc.type' then
return 'dt:'
end
+ if source.type == 'doc.param' then
+ return 'dp:'
+ end
if source.type == 'doc.alias' then
return 'da:'
end
@@ -185,7 +190,10 @@ local TempList = {}
---前进
---@param source parser.guide.object
---@return parser.guide.object[]
-local function checkForward(source, id)
+local function checkForward(source)
+ if not source then
+ return
+ end
local list = TempList
local parent = source.parent
if source.value then
@@ -235,7 +243,10 @@ end
---后退
---@param source parser.guide.object
---@return parser.guide.object[]
-local function checkBackward(source, id)
+local function checkBackward(source)
+ if not source then
+ return
+ end
local list = TempList
local parent = source.parent
if parent.value == source then
@@ -266,7 +277,7 @@ local function checkBackward(source, id)
if source.returnIndex then
for _, src in ipairs(parent.bindSources) do
if src.type == 'function' then
- local fullID = ('%s%s%s%s'):format(getID(src), SPLIT_CHAR, INDEX_CHAR, source.returnIndex)
+ local fullID = ('%s%s%s%s'):format(getID(src), SPLIT_CHAR, RETURN_INDEX_CHAR, source.returnIndex)
list[#list+1] = fullID
end
end
@@ -279,6 +290,29 @@ local function checkBackward(source, id)
list[#list+1] = ('s:%d'):format(sel.start)
end
end
+ -- 将调用参数映射到函数调用上
+ if parent.type == 'callargs' then
+ for i = 1, #parent do
+ if parent[i] == source then
+ local call = parent.parent
+ local node = call.node
+ local nodeID = getID(node)
+ if not nodeID then
+ break
+ end
+ list[#list+1] = ('%s%s%s%s'):format(
+ nodeID,
+ SPLIT_CHAR,
+ PARAM_INDEX_CHAR,
+ i
+ )
+ break
+ end
+ end
+ end
+ if source.type == 'doc.param' then
+ print(source)
+ end
if #list == 0 then
return nil
else
@@ -298,18 +332,15 @@ end
---@field backward parser.guide.object[]
---创建source的链接信息
----@param source parser.guide.object
+---@param id string
+---@param source? parser.guide.object
---@return link
-local function createLink(source)
- local id = getID(source)
- if not id then
- return nil
- end
+local function createLink(id, source)
return {
id = id,
source = source,
- forward = checkForward(source, id),
- backward = checkBackward(source, id),
+ forward = checkForward(source),
+ backward = checkBackward(source),
}
end
@@ -325,7 +356,8 @@ end
local m = {}
m.SPLIT_CHAR = SPLIT_CHAR
-m.INDEX_CHAR = INDEX_CHAR
+m.RETURN_INDEX_CHAR = RETURN_INDEX_CHAR
+m.PARAM_INDEX_CHAR = PARAM_INDEX_CHAR
---根据ID来获取所有的link
---@param root parser.guide.object
@@ -360,8 +392,12 @@ end
---@param source parser.guide.object
---@return link
function m.getLink(source)
+ local id = getID(source)
+ if not id then
+ return nil
+ end
if source._link == nil then
- source._link = createLink(source) or false
+ source._link = createLink(id, source) or false
end
return source._link or nil
end
diff --git a/script/core/searcher.lua b/script/core/searcher.lua
index f99b3a4b..f57e9ad7 100644
--- a/script/core/searcher.lua
+++ b/script/core/searcher.lua
@@ -96,6 +96,43 @@ function m.isGlobal(source)
return false
end
+---@param obj parser.guide.object
+---@return parser.guide.object?
+function m.getObjectValue(obj)
+ while obj.type == 'paren' do
+ obj = obj.exp
+ if not obj then
+ return nil
+ end
+ end
+ if obj.type == 'boolean'
+ or obj.type == 'number'
+ or obj.type == 'integer'
+ or obj.type == 'string'
+ or obj.type == 'doc.type.table'
+ or obj.type == 'doc.type.arrary' then
+ return obj
+ end
+ if obj.value then
+ return obj.value
+ end
+ if obj.type == 'field'
+ or obj.type == 'method' then
+ return obj.parent and obj.parent.value
+ end
+ if obj.type == 'call' then
+ if obj.node.special == 'rawset' then
+ return obj.args and obj.args[3]
+ else
+ return obj
+ end
+ end
+ if obj.type == 'select' then
+ return obj
+ end
+ return nil
+end
+
function m.searchRefsByID(status, uri, expect, mode)
local ast = files.getAst(uri)
if not ast then
@@ -124,26 +161,6 @@ function m.searchRefsByID(status, uri, expect, mode)
search(id)
end
- local function getCallSelectByReturnIndex(func, index)
- local call = func.parent
- if call.type ~= 'call' then
- return nil
- end
- if index == 0 then
- return nil
- end
- if index == 1 then
- return call.parent
- else
- for _, sel in ipairs(call.extParent) do
- if sel.index == index then
- return sel
- end
- end
- end
- return nil
- end
-
local function searchFunction(id)
local funcs = linker.getLinksByID(root, id)
if not funcs then
@@ -165,7 +182,7 @@ function m.searchRefsByID(status, uri, expect, mode)
if not parentID then
return
end
- search(parentID, linker.SPLIT_CHAR .. linker.INDEX_CHAR .. returnIndex)
+ search(parentID, linker.SPLIT_CHAR .. linker.RETURN_INDEX_CHAR .. returnIndex)
end
local function checkForward(link, field)
@@ -269,4 +286,18 @@ function m.requestReference(obj, interface, deep)
return status.results, 0
end
+--- 请求对象的定义
+---@param obj parser.guide.object
+---@param interface table
+---@param deep integer
+---@return parser.guide.object[]
+---@return integer
+function m.requestDefinition(obj, interface, deep)
+ local status = m.status(nil, interface, deep)
+ -- 根据 field 搜索引用
+ m.searchRefs(status, obj, 'def')
+
+ return status.results, 0
+end
+
return m
diff --git a/test/basic/linker.txt b/test/basic/linker.txt
index 284623cd..86930dd1 100644
--- a/test/basic/linker.txt
+++ b/test/basic/linker.txt
@@ -82,7 +82,23 @@ local <!x!> = f()
```
'd|A'
-'f|1:1'
-'f|1' + ':1'
-'l|1' + ':1'
-'s|1' + ':1'
+'f|1|#1'
+'f|1' + '|#1'
+'l|1' + '|#1'
+'s|1' + '|#1'
+
+```lua
+---@generic T
+---@param a T
+---@return T
+local function f(a) end
+
+local <?c?>
+
+local <!v!> = f(c)
+```
+
+'l1'
+'l2|@1'
+'f|1|@1'
+'f|1|#1'
diff --git a/test/references/init.lua b/test/references/init.lua
index 31aa0cb7..f6785984 100644
--- a/test/references/init.lua
+++ b/test/references/init.lua
@@ -359,6 +359,21 @@ local a, b = f()
return a.x, b.<!x!>
]]
+-- TODO 支持泛型
+do return end
+TEST [[
+---@class Dog
+local <?Dog?> = {}
+
+---@generic T
+---@param type1 T
+---@return T
+function foobar(type1)
+end
+
+local <!v1!> = foobar(<!Dog!>)
+]]
+
TEST [[
---@class Dog
local Dog = {}