diff options
-rw-r--r-- | server/src/matcher/definition.lua | 10 | ||||
-rw-r--r-- | server/src/matcher/find_result.lua | 28 | ||||
-rw-r--r-- | server/src/matcher/hover.lua | 25 | ||||
-rw-r--r-- | server/src/matcher/implementation.lua | 9 | ||||
-rw-r--r-- | server/src/matcher/references.lua | 9 | ||||
-rw-r--r-- | server/src/matcher/rename.lua | 15 | ||||
-rw-r--r-- | server/src/matcher/vm.lua | 24 | ||||
-rw-r--r-- | server/test/find_lib/init.lua | 4 | ||||
-rw-r--r-- | server/test/main.lua | 2 | ||||
-rw-r--r-- | server/test/type_inference/init.lua | 3 |
10 files changed, 75 insertions, 54 deletions
diff --git a/server/src/matcher/definition.lua b/server/src/matcher/definition.lua index 57e4b237..13015cac 100644 --- a/server/src/matcher/definition.lua +++ b/server/src/matcher/definition.lua @@ -4,31 +4,29 @@ local function parseResult(result) local positions = {} local tp = result.type if tp == 'local' then - for _, info in ipairs(result.object) do + for _, info in ipairs(result) do if info.type == 'local' then positions[#positions+1] = {info.source.start, info.source.finish} end end elseif tp == 'field' then - for _, info in ipairs(result.object) do + for _, info in ipairs(result) do if info.type == 'set' then positions[#positions+1] = {info.source.start, info.source.finish} end end elseif tp == 'label' then - for _, info in ipairs(result.object) do + for _, info in ipairs(result) do if info.type == 'set' then positions[#positions+1] = {info.source.start, info.source.finish} end end - else - error('Unknow result type:' .. result.type) end return positions end return function (vm, pos) - local result = findResult(vm.results, pos) + local result = findResult(vm, pos) if not result then return nil end diff --git a/server/src/matcher/find_result.lua b/server/src/matcher/find_result.lua index e8128008..33f333e7 100644 --- a/server/src/matcher/find_result.lua +++ b/server/src/matcher/find_result.lua @@ -2,22 +2,20 @@ local function isContainPos(obj, pos) return obj.start <= pos and obj.finish + 1 >= pos end -local function findIn(name, group, pos) - for _, obj in ipairs(group) do - for _, info in ipairs(obj) do - if isContainPos(info.source, pos) then - return { - type = name, - object = obj, - info = info, - } +return function (vm, pos) + local results = vm.results + for source, object in pairs(results.sources) do + if source.type == 'multi-source' then + for _, source in ipairs(source) do + if isContainPos(source, pos) then + return object, source + end + end + else + if isContainPos(source, pos) then + return object, source end end end -end - -return function (results, pos) - return findIn('local', results.locals, pos) - or findIn('field', results.fields, pos) - or findIn('label', results.labels, pos) + return nil end diff --git a/server/src/matcher/hover.lua b/server/src/matcher/hover.lua index 12d599b3..f2b0f07d 100644 --- a/server/src/matcher/hover.lua +++ b/server/src/matcher/hover.lua @@ -160,17 +160,7 @@ local function buildTableHover(lib, fullKey) ]]):format(title, tip, field) end -return function (vm, pos) - local result = findResult(vm.results, pos) - if not result then - return nil - end - - local lib, fullKey, oo = findLib(result.object) - if not lib then - return nil - end - +local function getLibHover(lib, fullKey, oo) local cache = oo and OoCache or Cache if not cache[lib] then @@ -187,3 +177,16 @@ return function (vm, pos) return cache[lib] end + +return function (vm, pos) + local result = findResult(vm, pos) + if not result then + return nil + end + + local lib, fullKey, oo = findLib(result) + if lib then + local hover = getLibHover(lib, fullKey, oo) + return hover + end +end diff --git a/server/src/matcher/implementation.lua b/server/src/matcher/implementation.lua index 34b09693..72d2411f 100644 --- a/server/src/matcher/implementation.lua +++ b/server/src/matcher/implementation.lua @@ -4,20 +4,19 @@ local function parseResult(result) local positions = {} local tp = result.type if tp == 'local' then - for _, info in ipairs(result.object) do + for _, info in ipairs(result) do if info.type == 'set' then positions[#positions+1] = {info.source.start, info.source.finish} end end elseif tp == 'field' then - for _, info in ipairs(result.object) do + for _, info in ipairs(result) do if info.type == 'set' then positions[#positions+1] = {info.source.start, info.source.finish} end end elseif tp == 'label' then - local label = result.label - for _, info in ipairs(label) do + for _, info in ipairs(result) do if info.type == 'set' then positions[#positions+1] = {info.source.start, info.source.finish} end @@ -29,7 +28,7 @@ local function parseResult(result) end return function (vm, pos) - local result = findResult(vm.results, pos) + local result = findResult(vm, pos) if not result then return nil end diff --git a/server/src/matcher/references.lua b/server/src/matcher/references.lua index a241025c..722156d3 100644 --- a/server/src/matcher/references.lua +++ b/server/src/matcher/references.lua @@ -4,20 +4,19 @@ local function parseResult(result, declarat) local positions = {} local tp = result.type if tp == 'local' then - for _, info in ipairs(result.object) do + for _, info in ipairs(result) do if declarat or info.type == 'get' then positions[#positions+1] = {info.source.start, info.source.finish} end end elseif tp == 'field' then - for _, info in ipairs(result.object) do + for _, info in ipairs(result) do if declarat or info.type == 'get' then positions[#positions+1] = {info.source.start, info.source.finish} end end elseif tp == 'label' then - local label = result.label - for _, info in ipairs(label) do + for _, info in ipairs(result) do if declarat or info.type == 'goto' then positions[#positions+1] = {info.source.start, info.source.finish} end @@ -29,7 +28,7 @@ local function parseResult(result, declarat) end return function (vm, pos, declarat) - local result = findResult(vm.results, pos) + local result = findResult(vm, pos) if not result then return nil end diff --git a/server/src/matcher/rename.lua b/server/src/matcher/rename.lua index b395c696..1033d331 100644 --- a/server/src/matcher/rename.lua +++ b/server/src/matcher/rename.lua @@ -1,16 +1,15 @@ local findResult = require 'matcher.find_result' local parser = require 'parser' -local function parseResult(result, newName) +local function parseResult(result, source, newName) local positions = {} local tp = result.type if tp == 'local' or tp == 'field' then - local var = result.object - local key = result.info.source[1] - if var.disableRename then + local key = source[1] + if result.disableRename then return positions end - if result.info.source.index then + if source.index then if not parser.grammar(newName, 'Exp') then return positions end @@ -20,7 +19,7 @@ local function parseResult(result, newName) end end local mark = {} - for _, info in ipairs(var) do + for _, info in ipairs(result) do if not mark[info.source] then mark[info.source] = info if info.source[1] == key then @@ -43,10 +42,10 @@ local function parseResult(result, newName) end return function (vm, pos, newName) - local result = findResult(vm.results, pos) + local result, source = findResult(vm, pos) if not result then return nil end - local positions = parseResult(result, newName) + local positions = parseResult(result, source, newName) return positions end diff --git a/server/src/matcher/vm.lua b/server/src/matcher/vm.lua index 1079816a..d9f117f8 100644 --- a/server/src/matcher/vm.lua +++ b/server/src/matcher/vm.lua @@ -42,6 +42,22 @@ function mt:addInfo(obj, type, source) type = type, source = source or DefaultSource, } + if source then + local other = self.results.sources[source] + if other then + if other.type == 'multi-source' then + other[#other+1] = obj + else + other = { + type = 'multi-source', + [1] = other, + [2] = obj, + } + end + else + self.results.sources[source] = obj + end + end return obj end @@ -349,6 +365,7 @@ function mt:callRequire(func, values) end function mt:call(func, values) + self:inference(func, 'function') local lib = func.lib if lib and lib.special then if lib.special == 'setmetatable' then @@ -387,6 +404,12 @@ function mt:getFunctionReturns(func) return func.returns end +function mt:inference(value, type) + if value.type == 'nil' then + value.type = type + end +end + function mt:createValue(type, source, v) local value = { type = type, @@ -895,6 +918,7 @@ local function compile(ast) labels = {}, funcs = {}, calls = {}, + sources= {}, }, libraryValue = {}, libraryChild = {}, diff --git a/server/test/find_lib/init.lua b/server/test/find_lib/init.lua index 5bd5e8b4..63654b19 100644 --- a/server/test/find_lib/init.lua +++ b/server/test/find_lib/init.lua @@ -13,8 +13,8 @@ function TEST(fullkey) assert(ast) local vm = matcher.vm(ast) assert(vm) - local result = matcher.findResult(vm.results, pos) - local _, name = matcher.findLib(result.object) + local result = matcher.findResult(vm, pos) + local _, name = matcher.findLib(result) assert(name == fullkey) end end diff --git a/server/test/main.lua b/server/test/main.lua index 67e2da0a..f6b1cede 100644 --- a/server/test/main.lua +++ b/server/test/main.lua @@ -26,7 +26,7 @@ local function main() test 'vm' test 'definition' test 'diagnostics' - test 'type_inference' + --test 'type_inference' test 'find_lib' print('测试完成') diff --git a/server/test/type_inference/init.lua b/server/test/type_inference/init.lua index dc1a4099..3686451f 100644 --- a/server/test/type_inference/init.lua +++ b/server/test/type_inference/init.lua @@ -12,8 +12,9 @@ function TEST(res) local ast = parser:ast(new_script) local vm = matcher.vm(ast) assert(vm) - local result = matcher.findResult(vm.results, pos) + local result = matcher.findResult(vm, pos) assert(result) + assert(res == result.value.type) end end |