diff options
-rw-r--r-- | server/src/core/definition.lua | 70 | ||||
-rw-r--r-- | server/src/core/implementation.lua | 69 | ||||
-rw-r--r-- | server/src/method/textDocument/definition.lua | 4 | ||||
-rw-r--r-- | server/src/method/textDocument/implementation.lua | 2 | ||||
-rw-r--r-- | server/test/crossfile/definition.lua | 21 | ||||
-rw-r--r-- | server/test/definition/init.lua | 2 |
6 files changed, 111 insertions, 57 deletions
diff --git a/server/src/core/definition.lua b/server/src/core/definition.lua index 60038724..cd44f94f 100644 --- a/server/src/core/definition.lua +++ b/server/src/core/definition.lua @@ -1,3 +1,32 @@ +local function findFieldBySource(positions, source, obj, result) + if source.type == 'name' and source[1] == result.key then + if obj.type == 'field' then + for _, info in ipairs(obj) do + if info.type == 'set' and info.source == source then + positions[#positions+1] = { + source.start, + source.finish, + source.uri, + } + end + end + end + end +end + +local function findFieldByName(positions, vm, result) + for source, obj in pairs(vm.results.sources) do + if source.type == 'multi-source' then + for i = 1, #obj do + findFieldBySource(positions, source, obj[i], result) + end + else + findFieldBySource(positions, source, obj, result) + end + end +end + + local function parseResultAcrossUri(positions, vm, result) -- 跨越文件时,遍历的是值的绑定信息 for _, info in ipairs(result.value) do @@ -28,35 +57,23 @@ local function parseResultAcrossUri(positions, vm, result) end end -local function findFieldBySource(positions, source, obj, result) - if source.type == 'name' and source[1] == result.key then - if obj.type == 'field' then - for _, info in ipairs(obj) do - if info.type == 'set' and info.source == source then - positions[#positions+1] = { - source.start, - source.finish, - source.uri, - } - end - end - end +local function findFieldCrossUriByName(positions, vm, result, lsp) + if not lsp then + return end -end - -local function findFieldByName(positions, vm, result) - for source, obj in pairs(vm.results.sources) do - if source.type == 'multi-source' then - for i = 1, #obj do - findFieldBySource(positions, source, obj[i], result) - end - else - findFieldBySource(positions, source, obj, result) + local parentValue = result.parentValue + if not parentValue then + return + end + if parentValue.uri ~= vm.uri then + local destVM = lsp:loadVM(parentValue.uri) + if destVM then + findFieldByName(positions, destVM, result) end end end -local function parseResult(vm, result) +local function parseResult(vm, result, lsp) local positions = {} local tp = result.type if tp == 'local' then @@ -88,6 +105,7 @@ local function parseResult(vm, result) end if #positions == 0 then findFieldByName(positions, vm, result) + findFieldCrossUriByName(positions, vm, result, lsp) end end elseif tp == 'label' then @@ -110,10 +128,10 @@ local function parseResult(vm, result) return positions end -return function (vm, result) +return function (vm, result, lsp) if not result then return nil end - local positions = parseResult(vm, result) + local positions = parseResult(vm, result, lsp) return positions end diff --git a/server/src/core/implementation.lua b/server/src/core/implementation.lua index c10cfaf9..f6593cf2 100644 --- a/server/src/core/implementation.lua +++ b/server/src/core/implementation.lua @@ -1,3 +1,31 @@ +local function findFieldBySource(positions, source, obj, result) + if source.type == 'name' and source[1] == result.key then + if obj.type == 'field' then + for _, info in ipairs(obj) do + if info.type == 'set' and info.source == source then + positions[#positions+1] = { + source.start, + source.finish, + source.uri, + } + end + end + end + end +end + +local function findFieldByName(positions, vm, result) + for source, obj in pairs(vm.results.sources) do + if source.type == 'multi-source' then + for i = 1, #obj do + findFieldBySource(positions, source, obj[i], result) + end + else + findFieldBySource(positions, source, obj, result) + end + end +end + local function parseResultAcrossUri(positions, vm, result) -- 跨越文件时,遍历的是值的绑定信息 for _, info in ipairs(result.value) do @@ -28,35 +56,23 @@ local function parseResultAcrossUri(positions, vm, result) end end -local function findFieldBySource(positions, source, obj, result) - if source.type == 'name' and source[1] == result.key then - if obj.type == 'field' then - for _, info in ipairs(obj) do - if info.type == 'set' and info.source == source then - positions[#positions+1] = { - source.start, - source.finish, - source.uri, - } - end - end - end +local function findFieldCrossUriByName(positions, vm, result, lsp) + if not lsp then + return end -end - -local function findFieldByName(positions, vm, result) - for source, obj in pairs(vm.results.sources) do - if source.type == 'multi-source' then - for i = 1, #obj do - findFieldBySource(positions, source, obj[i], result) - end - else - findFieldBySource(positions, source, obj, result) + local parentValue = result.parentValue + if not parentValue then + return + end + if parentValue.uri ~= vm.uri then + local destVM = lsp:loadVM(parentValue.uri) + if destVM then + findFieldByName(positions, destVM, result) end end end -local function parseResult(vm, result) +local function parseResult(vm, result, lsp) local positions = {} local tp = result.type if tp == 'local' then @@ -88,6 +104,7 @@ local function parseResult(vm, result) end if #positions == 0 then findFieldByName(positions, vm, result) + findFieldCrossUriByName(positions, vm, result, lsp) end end elseif tp == 'label' then @@ -110,10 +127,10 @@ local function parseResult(vm, result) return positions end -return function (vm, result) +return function (vm, result, lsp) if not result then return nil end - local positions = parseResult(vm, result) + local positions = parseResult(vm, result, lsp) return positions end diff --git a/server/src/method/textDocument/definition.lua b/server/src/method/textDocument/definition.lua index a9f92631..17351318 100644 --- a/server/src/method/textDocument/definition.lua +++ b/server/src/method/textDocument/definition.lua @@ -24,14 +24,14 @@ return function (lsp, params) end -- lua是从1开始的,因此都要+1 local position = lines:position(params.position.line + 1, params.position.character + 1) - local result = core.findResult(vm, position) + local result = core.findResult(vm, position, lsp) if not result then return nil end checkWorkSpaceComplete(lsp, result) - local positions = core.definition(vm, result) + local positions = core.definition(vm, result, lsp) if not positions then return nil end diff --git a/server/src/method/textDocument/implementation.lua b/server/src/method/textDocument/implementation.lua index df6f027b..7117aba1 100644 --- a/server/src/method/textDocument/implementation.lua +++ b/server/src/method/textDocument/implementation.lua @@ -28,7 +28,7 @@ return function (lsp, params) checkWorkSpaceComplete(lsp, result) - local positions = core.implementation(vm, result) + local positions = core.implementation(vm, result, lsp) if not positions then return nil end diff --git a/server/test/crossfile/definition.lua b/server/test/crossfile/definition.lua index 8d8c7569..dda1bf9a 100644 --- a/server/test/crossfile/definition.lua +++ b/server/test/crossfile/definition.lua @@ -49,7 +49,7 @@ function TEST(data) assert(sourceVM) local sourcePos = (sourceList[1][1] + sourceList[1][2]) // 2 local result = core.findResult(sourceVM, sourcePos) - local positions = core.definition(sourceVM, result) + local positions = core.definition(sourceVM, result, lsp) assert(positions and positions[1]) local start, finish, valueUri = positions[1][1], positions[1][2], positions[1][3] assert(valueUri == targetUri) @@ -148,3 +148,22 @@ TEST { ]], }, } + +TEST { + { + path = 'a.lua', + content = [[ + local x = { + <!a!> = 1, + } + return b + ]], + }, + { + path = 'b.lua', + content = [[ + local t = require 'a' + t.<?a?>() + ]] + } +} diff --git a/server/test/definition/init.lua b/server/test/definition/init.lua index 5ed638dd..d55457f6 100644 --- a/server/test/definition/init.lua +++ b/server/test/definition/init.lua @@ -45,7 +45,7 @@ function TEST(script) assert(vm) local result = core.findResult(vm, pos) - local positions = core.definition(vm, result) + local positions = core.definition(vm, result, nil) if positions then assert(founded(target, positions)) else |