summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author最萌小汐 <sumneko@hotmail.com>2019-12-14 15:54:35 +0800
committer最萌小汐 <sumneko@hotmail.com>2019-12-14 15:54:35 +0800
commitbbd6f89a2a024e4b9fa7f7f1f5cd5564187c622b (patch)
treec95d327ef00296f346e7c7dba53a0450df329791
parent0824ee5122f50b824bc677a8838b3e1d0102258e (diff)
downloadlua-language-server-bbd6f89a2a024e4b9fa7f7f1f5cd5564187c622b.zip
整理ref实现
-rw-r--r--script-beta/core/reference.lua5
-rw-r--r--script-beta/vm/eachRef.lua123
-rw-r--r--test-beta/references/init.lua11
3 files changed, 68 insertions, 71 deletions
diff --git a/script-beta/core/reference.lua b/script-beta/core/reference.lua
index 0f4efbfc..6edd63a7 100644
--- a/script-beta/core/reference.lua
+++ b/script-beta/core/reference.lua
@@ -72,8 +72,13 @@ return function (uri, offset)
return nil
end
local results = {}
+ local mark = {}
guide.eachSourceContain(ast.ast, offset, function (source)
findRef(source, offset, function (target, uri)
+ if mark[target] then
+ return
+ end
+ mark[target] = true
results[#results+1] = {
target = target,
uri = files.getOriginUri(uri),
diff --git a/script-beta/vm/eachRef.lua b/script-beta/vm/eachRef.lua
index 6e571822..a3ce1518 100644
--- a/script-beta/vm/eachRef.lua
+++ b/script-beta/vm/eachRef.lua
@@ -45,45 +45,6 @@ local function ofCallSelect(call, index, callback)
end
end
-local function ofReturn(rtn, index, callback)
- local func = guide.getParentFunction(rtn)
- if not func then
- return
- end
- -- 搜索函数调用的第 index 个接收值
- if func.type == 'main' then
- local myUri = func.uri
- local uris = files.findLinkTo(myUri)
- if not uris then
- return
- end
- for _, uri in ipairs(uris) do
- local ast = files.getAst(uri)
- if ast then
- local links = vm.getLinks(ast.ast)
- if links then
- for linkUri, calls in pairs(links) do
- if files.eq(linkUri, myUri) then
- for i = 1, #calls do
- ofCallSelect(calls[i], 1, callback)
- end
- end
- end
- end
- end
- end
- else
- vm.eachRef(func, function (info)
- local source = info.source
- local call = source.parent
- if not call or call.type ~= 'call' then
- return
- end
- ofCallSelect(call, index, callback)
- end)
- end
-end
-
local function ofSpecialCall(call, func, index, callback, offset)
local name = func.special
offset = offset or 0
@@ -151,19 +112,11 @@ local function ofSpecialCall(call, func, index, callback, offset)
end
end
-local function ofValue(value, callback)
- if value.type == 'select' then
- -- 检查函数返回值
- local call = value.vararg
- if call.type == 'call' then
- ofCall(call.node, value.index, callback)
- ofSpecialCall(call, call.node, value.index, callback)
- end
- return
+local function asSetValue(value, callback)
+ if value.type == 'field'
+ or value.type == 'method' then
+ value = value.parent
end
-
- vm.eachRef(value, callback)
-
local parent = value.parent
if not parent then
return
@@ -180,13 +133,14 @@ local function ofValue(value, callback)
vm.eachRef(parent, callback)
end
end
- if parent.type == 'return' then
- for i = 1, #parent do
- if parent[i] == value then
- ofReturn(parent, i, callback)
- break
- end
- end
+end
+
+local function ofSelect(source, callback)
+ -- 检查函数返回值
+ local call = source.vararg
+ if call.type == 'call' then
+ ofCall(call.node, source.index, callback)
+ ofSpecialCall(call, call.node, source.index, callback)
end
end
@@ -225,11 +179,11 @@ local function getCallRecvs(call)
if parent.type ~= 'select' then
return nil
end
- local exParent = call.exParent
+ local extParent = call.extParent
local recvs = {}
recvs[1] = parent.parent
- if exParent then
- for _, p in ipairs(exParent) do
+ if extParent then
+ for _, p in ipairs(extParent) do
recvs[#recvs+1] = p.parent
end
end
@@ -271,11 +225,35 @@ end
--- 自己作为函数的返回值
local function asReturn(source, callback)
local parent = source.parent
+ if source.type == 'field'
+ or source.type == 'method' then
+ parent = parent.parent
+ end
if not parent or parent.type ~= 'return' then
return
end
local func = guide.getParentFunction(source)
if func.type == 'main' then
+ local myUri = func.uri
+ local uris = files.findLinkTo(myUri)
+ if not uris then
+ return
+ end
+ for _, uri in ipairs(uris) do
+ local ast = files.getAst(uri)
+ if ast then
+ local links = vm.getLinks(ast.ast)
+ if links then
+ for linkUri, calls in pairs(links) do
+ if files.eq(linkUri, myUri) then
+ for i = 1, #calls do
+ ofCallSelect(calls[i], 1, callback)
+ end
+ end
+ end
+ end
+ end
+ end
else
local index
for i = 1, #parent do
@@ -321,15 +299,15 @@ local function ofLocal(loc, callback)
source = ref,
mode = 'get',
}
- asValue(ref, callback)
- asReturn(ref, callback)
+ vm.eachRef(ref, callback)
elseif ref.type == 'setlocal' then
callback {
source = ref,
mode = 'set',
}
+ vm.eachRef(ref, callback)
if ref.value then
- ofValue(ref.value, callback)
+ vm.eachRef(ref.value, callback)
end
end
end
@@ -338,7 +316,7 @@ local function ofLocal(loc, callback)
ofSelf(loc, callback)
end
if loc.value then
- ofValue(loc.value, callback)
+ vm.eachRef(loc.value, callback)
end
if loc.tag == '_ENV' and loc.ref then
for _, ref in ipairs(loc.ref) do
@@ -377,7 +355,7 @@ local function ofGlobal(source, callback)
for _, info in ipairs(globals[key]) do
callback(info)
if info.value then
- ofValue(info.value, callback)
+ vm.eachRef(info.value, callback)
end
end
end
@@ -390,7 +368,7 @@ local function ofGlobal(source, callback)
mode = info.mode,
}
if info.value then
- ofValue(info.value, callback)
+ vm.eachRef(info.value, callback)
end
end
end)
@@ -413,7 +391,7 @@ local function ofField(source, callback)
mode = info.mode,
}
if info.value then
- ofValue(info.value, callback)
+ vm.eachRef(info.value, callback)
end
end
end)
@@ -426,7 +404,7 @@ local function ofField(source, callback)
mode = info.mode,
}
if info.value then
- ofValue(info.value, callback)
+ vm.eachRef(info.value, callback)
end
end
end)
@@ -502,7 +480,8 @@ local function eachRef(source, callback)
or stype == 'method' then
ofField(source, callback)
elseif stype == 'setfield'
- or stype == 'getfield' then
+ or stype == 'getfield'
+ or stype == 'tablefield' then
ofField(source.field, callback)
elseif stype == 'setmethod'
or stype == 'getmethod' then
@@ -518,8 +497,9 @@ local function eachRef(source, callback)
ofLabel(source, callback)
elseif stype == 'table'
or stype == 'function' then
- ofValue(source, callback)
ofSelfValue(source, callback)
+ elseif stype == 'select' then
+ ofSelect(source, callback)
elseif stype == 'call' then
ofCall(source.node, 1, callback)
ofSpecialCall(source, source.node, 1, callback)
@@ -531,6 +511,7 @@ local function eachRef(source, callback)
asArg(source, callback)
asReturn(source, callback)
asParen(source, callback)
+ asSetValue(source, callback)
end
--- 判断2个对象是否拥有相同的引用
diff --git a/test-beta/references/init.lua b/test-beta/references/init.lua
index 3cbb5ed8..04ca9936 100644
--- a/test-beta/references/init.lua
+++ b/test-beta/references/init.lua
@@ -204,6 +204,17 @@ t.x = 1
t[a.b.<?x?>] = 1
]]
+TEST [[
+local t
+local <!f!> = t.<?f?>
+
+<!f!>()
+
+return {
+ <!f!> = <!f!>,
+}
+]]
+
--TEST [[
-----@class <!Class!>
-----@type <?Class?>