summaryrefslogtreecommitdiff
path: root/server
diff options
context:
space:
mode:
Diffstat (limited to 'server')
-rw-r--r--server/src/core/definition.lua32
-rw-r--r--server/src/core/find_result.lua27
-rw-r--r--server/src/core/vm.lua45
-rw-r--r--server/test/definition/arg.lua1
4 files changed, 57 insertions, 48 deletions
diff --git a/server/src/core/definition.lua b/server/src/core/definition.lua
index 806007b1..d7f0019d 100644
--- a/server/src/core/definition.lua
+++ b/server/src/core/definition.lua
@@ -1,7 +1,8 @@
-local function findFieldBySource(positions, source, obj, result)
+local function findFieldBySource(positions, source, vm, result)
if source.type == 'name' and source[1] == result.key then
+ local obj = source.object
if obj.type == 'field' then
- for _, info in ipairs(obj) do
+ vm:eachInfo(obj, function (info)
if info.type == 'set' and info.source == source then
positions[#positions+1] = {
source.start,
@@ -9,20 +10,14 @@ local function findFieldBySource(positions, source, obj, result)
source.uri,
}
end
- end
+ end)
end
end
end
local function findFieldByName(positions, vm, result)
- for source, obj in pairs(vm.results.sources) do
- if source.type == 'multi-source' then
- for i = 1, #obj do
- findFieldBySource(positions, source, obj[i], result)
- end
- else
- findFieldBySource(positions, source, obj, result)
- end
+ for _, source in pairs(vm.results.sources) do
+ findFieldBySource(positions, source, vm, result)
end
end
@@ -82,8 +77,9 @@ local function parseResultAsVar(vm, result, lsp)
end
if result.value.uri ~= vm.uri then
parseResultAcrossUri(positions, vm, result)
- else
- for _, info in ipairs(result) do
+ elseif result.link then
+ result = result.link
+ vm:eachInfo(result, function (info)
if info.type == 'local' then
positions[#positions+1] = {
info.source.start,
@@ -91,7 +87,7 @@ local function parseResultAsVar(vm, result, lsp)
info.source.uri,
}
end
- end
+ end)
end
elseif tp == 'field' then
if result.value.lib then
@@ -100,7 +96,7 @@ local function parseResultAsVar(vm, result, lsp)
if result.value.uri ~= vm.uri then
parseResultAcrossUri(positions, vm, result)
else
- for _, info in ipairs(result) do
+ vm:eachInfo(result, function (info)
if info.type == 'set' then
positions[#positions+1] = {
info.source.start,
@@ -108,21 +104,21 @@ local function parseResultAsVar(vm, result, lsp)
info.source.uri,
}
end
- end
+ end)
if #positions == 0 then
findFieldByName(positions, vm, result)
findFieldCrossUriByName(positions, vm, result, lsp)
end
end
elseif tp == 'label' then
- for _, info in ipairs(result) do
+ vm:eachInfo(result, function (info)
if info.type == 'set' then
positions[#positions+1] = {
info.source.start,
info.source.finish,
}
end
- end
+ end)
end
return positions
end
diff --git a/server/src/core/find_result.lua b/server/src/core/find_result.lua
index e4074879..dae0e792 100644
--- a/server/src/core/find_result.lua
+++ b/server/src/core/find_result.lua
@@ -11,26 +11,13 @@ end
local function findAtPos(results, pos, level)
local res = {}
- for sources, object in pairs(results.sources) do
- if sources.type == 'multi-source' then
- for _, source in ipairs(sources) do
- if isValidSource(source) and isContainPos(source, pos) then
- res[#res+1] = {
- object = object,
- source = source,
- range = source.finish - source.start,
- }
- end
- end
- else
- local source = sources
- if isValidSource(source) and isContainPos(source, pos) then
- res[#res+1] = {
- object = object,
- source = source,
- range = source.finish - source.start,
- }
- end
+ for _, source in ipairs(results.sources) do
+ if isValidSource(source) and isContainPos(source, pos) then
+ res[#res+1] = {
+ object = source.object,
+ source = source,
+ range = source.finish - source.start,
+ }
end
end
if #res == 0 then
diff --git a/server/src/core/vm.lua b/server/src/core/vm.lua
index 6ccccd4b..e19ac0e7 100644
--- a/server/src/core/vm.lua
+++ b/server/src/core/vm.lua
@@ -91,8 +91,8 @@ function mt:createDummyVar(source, value)
end
function mt:createLocal(key, source, value)
- if self.results.sources[source] then
- return self.results.sources[source]
+ if source and source.object then
+ return source.object
end
local loc = {
type = 'local',
@@ -102,7 +102,8 @@ function mt:createLocal(key, source, value)
}
if source then
- self.results.sources[source] = loc
+ source.object = loc
+ self.results.sources[#self.results.sources+1] = source
source.isLocal = true
end
@@ -125,18 +126,22 @@ function mt:createLocal(key, source, value)
if source then
self:addInfo(loc, 'local', source, value)
+ if value then
+ value:addInfo('local', source, loc)
+ end
end
self:setValue(loc, value, source)
return loc
end
function mt:createField(value, index, source)
- if self.results.sources[source] then
- return self.results.sources[source]
+ if source and source.object then
+ return source.object
end
local field = value:createField(index, source)
if source then
- self.results.sources[source] = field
+ source.object = field
+ self.results.sources[#self.results.sources+1] = source
end
return field
end
@@ -147,7 +152,8 @@ function mt:getField(value, index, source)
return nil
end
if source then
- self.results.sources[source] = field
+ source.object = field
+ self.results.sources[#self.results.sources+1] = source
end
return field
end
@@ -193,6 +199,19 @@ function mt:addInfo(var, type, source, value)
self.results.infos[var][#self.results.infos[var]+1] = info
end
+function mt:eachInfo(var, callback)
+ if not self.results.infos[var] then
+ return nil
+ end
+ for _, info in ipairs(self.results.infos[var]) do
+ local res = callback(info)
+ if res ~= nil then
+ return res
+ end
+ end
+ return nil
+end
+
function mt:createDots(index, source)
local dots = {
type = 'dots',
@@ -319,8 +338,9 @@ function mt:runFunction(func)
local index = 0
if func.object then
- local var = self:createArg('self', func.object.source, self:getValue(func.object))
+ local var = self:createArg('self', func.object.colon, self:getValue(func.object))
var.hide = true
+ var.link = func.object
self:setValue(var, func.argValues[1] or self:createValue('nil'))
index = 1
func.args[index] = var
@@ -464,7 +484,7 @@ function mt:tryRequireOne(strValue, mode)
if type(str) == 'string' then
-- 支持 require 'xxx' 的转到定义
local strSource = strValue.source
- self.results.sources[strSource] = strValue
+ strSource.object = strValue
strValue.isRequire = true
local uri
@@ -726,6 +746,8 @@ end
function mt:getName(name, source)
local loc = self.scope.locals[name]
if loc then
+ source.object = loc
+ self.results.sources[#self.results.sources+1] = source
return loc
end
source.uri = self.uri
@@ -752,6 +774,8 @@ function mt:setName(name, source, value)
source.uri = self.uri
local loc = self.scope.locals[name]
if loc then
+ source.object = loc
+ self.results.sources[#self.results.sources+1] = source
self:setValue(loc, value, source)
return
end
@@ -940,6 +964,7 @@ function mt:getSimple(simple, mode)
parentName = parentName .. '.' .. field.key
elseif tp == ':' then
object = field
+ object.colon = obj
simple[i-1].colon = obj
elseif tp == '.' then
simple[i-1].dot = obj
@@ -1438,10 +1463,10 @@ local function compile(ast, lsp, uri)
labels = {},
funcs = {},
calls = {},
- sources= {},
strings= {},
indexs = {},
infos = {},
+ sources= {},
main = nil,
},
lsp = lsp,
diff --git a/server/test/definition/arg.lua b/server/test/definition/arg.lua
index 3d92da9a..9e88b2bf 100644
--- a/server/test/definition/arg.lua
+++ b/server/test/definition/arg.lua
@@ -9,6 +9,7 @@ local <!mt!>
function mt:x()
<?self?>()
end
+mt:x()
]]
TEST [[