diff options
-rw-r--r-- | server/src/core/definition.lua | 32 | ||||
-rw-r--r-- | server/src/core/find_result.lua | 27 | ||||
-rw-r--r-- | server/src/core/vm.lua | 45 | ||||
-rw-r--r-- | server/test/definition/arg.lua | 1 |
4 files changed, 57 insertions, 48 deletions
diff --git a/server/src/core/definition.lua b/server/src/core/definition.lua index 806007b1..d7f0019d 100644 --- a/server/src/core/definition.lua +++ b/server/src/core/definition.lua @@ -1,7 +1,8 @@ -local function findFieldBySource(positions, source, obj, result) +local function findFieldBySource(positions, source, vm, result) if source.type == 'name' and source[1] == result.key then + local obj = source.object if obj.type == 'field' then - for _, info in ipairs(obj) do + vm:eachInfo(obj, function (info) if info.type == 'set' and info.source == source then positions[#positions+1] = { source.start, @@ -9,20 +10,14 @@ local function findFieldBySource(positions, source, obj, result) source.uri, } end - end + end) end end end local function findFieldByName(positions, vm, result) - for source, obj in pairs(vm.results.sources) do - if source.type == 'multi-source' then - for i = 1, #obj do - findFieldBySource(positions, source, obj[i], result) - end - else - findFieldBySource(positions, source, obj, result) - end + for _, source in pairs(vm.results.sources) do + findFieldBySource(positions, source, vm, result) end end @@ -82,8 +77,9 @@ local function parseResultAsVar(vm, result, lsp) end if result.value.uri ~= vm.uri then parseResultAcrossUri(positions, vm, result) - else - for _, info in ipairs(result) do + elseif result.link then + result = result.link + vm:eachInfo(result, function (info) if info.type == 'local' then positions[#positions+1] = { info.source.start, @@ -91,7 +87,7 @@ local function parseResultAsVar(vm, result, lsp) info.source.uri, } end - end + end) end elseif tp == 'field' then if result.value.lib then @@ -100,7 +96,7 @@ local function parseResultAsVar(vm, result, lsp) if result.value.uri ~= vm.uri then parseResultAcrossUri(positions, vm, result) else - for _, info in ipairs(result) do + vm:eachInfo(result, function (info) if info.type == 'set' then positions[#positions+1] = { info.source.start, @@ -108,21 +104,21 @@ local function parseResultAsVar(vm, result, lsp) info.source.uri, } end - end + end) if #positions == 0 then findFieldByName(positions, vm, result) findFieldCrossUriByName(positions, vm, result, lsp) end end elseif tp == 'label' then - for _, info in ipairs(result) do + vm:eachInfo(result, function (info) if info.type == 'set' then positions[#positions+1] = { info.source.start, info.source.finish, } end - end + end) end return positions end diff --git a/server/src/core/find_result.lua b/server/src/core/find_result.lua index e4074879..dae0e792 100644 --- a/server/src/core/find_result.lua +++ b/server/src/core/find_result.lua @@ -11,26 +11,13 @@ end local function findAtPos(results, pos, level) local res = {} - for sources, object in pairs(results.sources) do - if sources.type == 'multi-source' then - for _, source in ipairs(sources) do - if isValidSource(source) and isContainPos(source, pos) then - res[#res+1] = { - object = object, - source = source, - range = source.finish - source.start, - } - end - end - else - local source = sources - if isValidSource(source) and isContainPos(source, pos) then - res[#res+1] = { - object = object, - source = source, - range = source.finish - source.start, - } - end + for _, source in ipairs(results.sources) do + if isValidSource(source) and isContainPos(source, pos) then + res[#res+1] = { + object = source.object, + source = source, + range = source.finish - source.start, + } end end if #res == 0 then diff --git a/server/src/core/vm.lua b/server/src/core/vm.lua index 6ccccd4b..e19ac0e7 100644 --- a/server/src/core/vm.lua +++ b/server/src/core/vm.lua @@ -91,8 +91,8 @@ function mt:createDummyVar(source, value) end function mt:createLocal(key, source, value) - if self.results.sources[source] then - return self.results.sources[source] + if source and source.object then + return source.object end local loc = { type = 'local', @@ -102,7 +102,8 @@ function mt:createLocal(key, source, value) } if source then - self.results.sources[source] = loc + source.object = loc + self.results.sources[#self.results.sources+1] = source source.isLocal = true end @@ -125,18 +126,22 @@ function mt:createLocal(key, source, value) if source then self:addInfo(loc, 'local', source, value) + if value then + value:addInfo('local', source, loc) + end end self:setValue(loc, value, source) return loc end function mt:createField(value, index, source) - if self.results.sources[source] then - return self.results.sources[source] + if source and source.object then + return source.object end local field = value:createField(index, source) if source then - self.results.sources[source] = field + source.object = field + self.results.sources[#self.results.sources+1] = source end return field end @@ -147,7 +152,8 @@ function mt:getField(value, index, source) return nil end if source then - self.results.sources[source] = field + source.object = field + self.results.sources[#self.results.sources+1] = source end return field end @@ -193,6 +199,19 @@ function mt:addInfo(var, type, source, value) self.results.infos[var][#self.results.infos[var]+1] = info end +function mt:eachInfo(var, callback) + if not self.results.infos[var] then + return nil + end + for _, info in ipairs(self.results.infos[var]) do + local res = callback(info) + if res ~= nil then + return res + end + end + return nil +end + function mt:createDots(index, source) local dots = { type = 'dots', @@ -319,8 +338,9 @@ function mt:runFunction(func) local index = 0 if func.object then - local var = self:createArg('self', func.object.source, self:getValue(func.object)) + local var = self:createArg('self', func.object.colon, self:getValue(func.object)) var.hide = true + var.link = func.object self:setValue(var, func.argValues[1] or self:createValue('nil')) index = 1 func.args[index] = var @@ -464,7 +484,7 @@ function mt:tryRequireOne(strValue, mode) if type(str) == 'string' then -- 支持 require 'xxx' 的转到定义 local strSource = strValue.source - self.results.sources[strSource] = strValue + strSource.object = strValue strValue.isRequire = true local uri @@ -726,6 +746,8 @@ end function mt:getName(name, source) local loc = self.scope.locals[name] if loc then + source.object = loc + self.results.sources[#self.results.sources+1] = source return loc end source.uri = self.uri @@ -752,6 +774,8 @@ function mt:setName(name, source, value) source.uri = self.uri local loc = self.scope.locals[name] if loc then + source.object = loc + self.results.sources[#self.results.sources+1] = source self:setValue(loc, value, source) return end @@ -940,6 +964,7 @@ function mt:getSimple(simple, mode) parentName = parentName .. '.' .. field.key elseif tp == ':' then object = field + object.colon = obj simple[i-1].colon = obj elseif tp == '.' then simple[i-1].dot = obj @@ -1438,10 +1463,10 @@ local function compile(ast, lsp, uri) labels = {}, funcs = {}, calls = {}, - sources= {}, strings= {}, indexs = {}, infos = {}, + sources= {}, main = nil, }, lsp = lsp, diff --git a/server/test/definition/arg.lua b/server/test/definition/arg.lua index 3d92da9a..9e88b2bf 100644 --- a/server/test/definition/arg.lua +++ b/server/test/definition/arg.lua @@ -9,6 +9,7 @@ local <!mt!> function mt:x() <?self?>() end +mt:x() ]] TEST [[ |