summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--server/src/core/definition.lua70
-rw-r--r--server/src/core/implementation.lua69
-rw-r--r--server/src/method/textDocument/definition.lua4
-rw-r--r--server/src/method/textDocument/implementation.lua2
-rw-r--r--server/test/crossfile/definition.lua21
-rw-r--r--server/test/definition/init.lua2
6 files changed, 111 insertions, 57 deletions
diff --git a/server/src/core/definition.lua b/server/src/core/definition.lua
index 60038724..cd44f94f 100644
--- a/server/src/core/definition.lua
+++ b/server/src/core/definition.lua
@@ -1,3 +1,32 @@
+local function findFieldBySource(positions, source, obj, result)
+ if source.type == 'name' and source[1] == result.key then
+ if obj.type == 'field' then
+ for _, info in ipairs(obj) do
+ if info.type == 'set' and info.source == source then
+ positions[#positions+1] = {
+ source.start,
+ source.finish,
+ source.uri,
+ }
+ end
+ end
+ end
+ end
+end
+
+local function findFieldByName(positions, vm, result)
+ for source, obj in pairs(vm.results.sources) do
+ if source.type == 'multi-source' then
+ for i = 1, #obj do
+ findFieldBySource(positions, source, obj[i], result)
+ end
+ else
+ findFieldBySource(positions, source, obj, result)
+ end
+ end
+end
+
+
local function parseResultAcrossUri(positions, vm, result)
-- 跨越文件时,遍历的是值的绑定信息
for _, info in ipairs(result.value) do
@@ -28,35 +57,23 @@ local function parseResultAcrossUri(positions, vm, result)
end
end
-local function findFieldBySource(positions, source, obj, result)
- if source.type == 'name' and source[1] == result.key then
- if obj.type == 'field' then
- for _, info in ipairs(obj) do
- if info.type == 'set' and info.source == source then
- positions[#positions+1] = {
- source.start,
- source.finish,
- source.uri,
- }
- end
- end
- end
+local function findFieldCrossUriByName(positions, vm, result, lsp)
+ if not lsp then
+ return
end
-end
-
-local function findFieldByName(positions, vm, result)
- for source, obj in pairs(vm.results.sources) do
- if source.type == 'multi-source' then
- for i = 1, #obj do
- findFieldBySource(positions, source, obj[i], result)
- end
- else
- findFieldBySource(positions, source, obj, result)
+ local parentValue = result.parentValue
+ if not parentValue then
+ return
+ end
+ if parentValue.uri ~= vm.uri then
+ local destVM = lsp:loadVM(parentValue.uri)
+ if destVM then
+ findFieldByName(positions, destVM, result)
end
end
end
-local function parseResult(vm, result)
+local function parseResult(vm, result, lsp)
local positions = {}
local tp = result.type
if tp == 'local' then
@@ -88,6 +105,7 @@ local function parseResult(vm, result)
end
if #positions == 0 then
findFieldByName(positions, vm, result)
+ findFieldCrossUriByName(positions, vm, result, lsp)
end
end
elseif tp == 'label' then
@@ -110,10 +128,10 @@ local function parseResult(vm, result)
return positions
end
-return function (vm, result)
+return function (vm, result, lsp)
if not result then
return nil
end
- local positions = parseResult(vm, result)
+ local positions = parseResult(vm, result, lsp)
return positions
end
diff --git a/server/src/core/implementation.lua b/server/src/core/implementation.lua
index c10cfaf9..f6593cf2 100644
--- a/server/src/core/implementation.lua
+++ b/server/src/core/implementation.lua
@@ -1,3 +1,31 @@
+local function findFieldBySource(positions, source, obj, result)
+ if source.type == 'name' and source[1] == result.key then
+ if obj.type == 'field' then
+ for _, info in ipairs(obj) do
+ if info.type == 'set' and info.source == source then
+ positions[#positions+1] = {
+ source.start,
+ source.finish,
+ source.uri,
+ }
+ end
+ end
+ end
+ end
+end
+
+local function findFieldByName(positions, vm, result)
+ for source, obj in pairs(vm.results.sources) do
+ if source.type == 'multi-source' then
+ for i = 1, #obj do
+ findFieldBySource(positions, source, obj[i], result)
+ end
+ else
+ findFieldBySource(positions, source, obj, result)
+ end
+ end
+end
+
local function parseResultAcrossUri(positions, vm, result)
-- 跨越文件时,遍历的是值的绑定信息
for _, info in ipairs(result.value) do
@@ -28,35 +56,23 @@ local function parseResultAcrossUri(positions, vm, result)
end
end
-local function findFieldBySource(positions, source, obj, result)
- if source.type == 'name' and source[1] == result.key then
- if obj.type == 'field' then
- for _, info in ipairs(obj) do
- if info.type == 'set' and info.source == source then
- positions[#positions+1] = {
- source.start,
- source.finish,
- source.uri,
- }
- end
- end
- end
+local function findFieldCrossUriByName(positions, vm, result, lsp)
+ if not lsp then
+ return
end
-end
-
-local function findFieldByName(positions, vm, result)
- for source, obj in pairs(vm.results.sources) do
- if source.type == 'multi-source' then
- for i = 1, #obj do
- findFieldBySource(positions, source, obj[i], result)
- end
- else
- findFieldBySource(positions, source, obj, result)
+ local parentValue = result.parentValue
+ if not parentValue then
+ return
+ end
+ if parentValue.uri ~= vm.uri then
+ local destVM = lsp:loadVM(parentValue.uri)
+ if destVM then
+ findFieldByName(positions, destVM, result)
end
end
end
-local function parseResult(vm, result)
+local function parseResult(vm, result, lsp)
local positions = {}
local tp = result.type
if tp == 'local' then
@@ -88,6 +104,7 @@ local function parseResult(vm, result)
end
if #positions == 0 then
findFieldByName(positions, vm, result)
+ findFieldCrossUriByName(positions, vm, result, lsp)
end
end
elseif tp == 'label' then
@@ -110,10 +127,10 @@ local function parseResult(vm, result)
return positions
end
-return function (vm, result)
+return function (vm, result, lsp)
if not result then
return nil
end
- local positions = parseResult(vm, result)
+ local positions = parseResult(vm, result, lsp)
return positions
end
diff --git a/server/src/method/textDocument/definition.lua b/server/src/method/textDocument/definition.lua
index a9f92631..17351318 100644
--- a/server/src/method/textDocument/definition.lua
+++ b/server/src/method/textDocument/definition.lua
@@ -24,14 +24,14 @@ return function (lsp, params)
end
-- lua是从1开始的,因此都要+1
local position = lines:position(params.position.line + 1, params.position.character + 1)
- local result = core.findResult(vm, position)
+ local result = core.findResult(vm, position, lsp)
if not result then
return nil
end
checkWorkSpaceComplete(lsp, result)
- local positions = core.definition(vm, result)
+ local positions = core.definition(vm, result, lsp)
if not positions then
return nil
end
diff --git a/server/src/method/textDocument/implementation.lua b/server/src/method/textDocument/implementation.lua
index df6f027b..7117aba1 100644
--- a/server/src/method/textDocument/implementation.lua
+++ b/server/src/method/textDocument/implementation.lua
@@ -28,7 +28,7 @@ return function (lsp, params)
checkWorkSpaceComplete(lsp, result)
- local positions = core.implementation(vm, result)
+ local positions = core.implementation(vm, result, lsp)
if not positions then
return nil
end
diff --git a/server/test/crossfile/definition.lua b/server/test/crossfile/definition.lua
index 8d8c7569..dda1bf9a 100644
--- a/server/test/crossfile/definition.lua
+++ b/server/test/crossfile/definition.lua
@@ -49,7 +49,7 @@ function TEST(data)
assert(sourceVM)
local sourcePos = (sourceList[1][1] + sourceList[1][2]) // 2
local result = core.findResult(sourceVM, sourcePos)
- local positions = core.definition(sourceVM, result)
+ local positions = core.definition(sourceVM, result, lsp)
assert(positions and positions[1])
local start, finish, valueUri = positions[1][1], positions[1][2], positions[1][3]
assert(valueUri == targetUri)
@@ -148,3 +148,22 @@ TEST {
]],
},
}
+
+TEST {
+ {
+ path = 'a.lua',
+ content = [[
+ local x = {
+ <!a!> = 1,
+ }
+ return b
+ ]],
+ },
+ {
+ path = 'b.lua',
+ content = [[
+ local t = require 'a'
+ t.<?a?>()
+ ]]
+ }
+}
diff --git a/server/test/definition/init.lua b/server/test/definition/init.lua
index 5ed638dd..d55457f6 100644
--- a/server/test/definition/init.lua
+++ b/server/test/definition/init.lua
@@ -45,7 +45,7 @@ function TEST(script)
assert(vm)
local result = core.findResult(vm, pos)
- local positions = core.definition(vm, result)
+ local positions = core.definition(vm, result, nil)
if positions then
assert(founded(target, positions))
else