summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--server/src/matcher/completion.lua64
-rw-r--r--server/src/service.lua2
-rw-r--r--server/src/workspace.lua46
-rw-r--r--server/test/crossfile/completion.lua131
-rw-r--r--server/test/crossfile/init.lua1
5 files changed, 236 insertions, 8 deletions
diff --git a/server/src/matcher/completion.lua b/server/src/matcher/completion.lua
index d75efab4..1ee06dbc 100644
--- a/server/src/matcher/completion.lua
+++ b/server/src/matcher/completion.lua
@@ -277,15 +277,61 @@ local function isInString(vm, pos)
return false
end
+local function findArgCount(args, pos)
+ for i, arg in ipairs(args) do
+ if isContainPos(arg, pos) then
+ return i, arg
+ end
+ end
+ return #args + 1, nil
+end
+
+-- 找出范围包含pos的call
+local function findCall(vm, pos)
+ local results = {}
+ for _, call in ipairs(vm.results.calls) do
+ if isContainPos(call.args, pos) then
+ local n, arg = findArgCount(call.args, pos)
+ if arg and arg.type ~= 'string' then
+ return nil
+ end
+ local var = vm.results.sources[call.lastObj]
+ if var then
+ results[#results+1] = {
+ func = call.func,
+ var = var,
+ source = call.lastObj,
+ select = n,
+ args = call.args,
+ }
+ end
+ end
+ end
+ if #results == 0 then
+ return nil
+ end
+ -- 可能处于 'func1(func2(' 的嵌套中,因此距离越远的函数层级越低
+ table.sort(results, function (a, b)
+ return a.args.start < b.args.start
+ end)
+ return results
+end
+
return function (vm, pos)
local result, source = findResult(vm, pos)
+ local inCall
if not result then
result, source = findClosePos(vm, pos)
if not result then
return nil
end
if isInString(vm, pos) then
- return nil
+ do return nil end
+ local calls = findCall(vm, pos)
+ if not calls then
+ return nil
+ end
+ inCall = calls[#calls]
end
end
@@ -317,13 +363,17 @@ return function (vm, pos)
end
end
- if result.type == 'local' then
- searchAsGlobal(vm, pos, result, callback)
- elseif result.type == 'field' then
- if result.parent and result.parent.value and result.parent.value.ENV == true then
+ if inCall then
+ searchAsArg(vm, pos, result, callback)
+ else
+ if result.type == 'local' then
searchAsGlobal(vm, pos, result, callback)
- else
- searchAsSuffix(result, callback)
+ elseif result.type == 'field' then
+ if result.parent and result.parent.value and result.parent.value.ENV == true then
+ searchAsGlobal(vm, pos, result, callback)
+ else
+ searchAsSuffix(result, callback)
+ end
end
end
if #list == 0 then
diff --git a/server/src/service.lua b/server/src/service.lua
index 0c167570..3ac3bb87 100644
--- a/server/src/service.lua
+++ b/server/src/service.lua
@@ -96,7 +96,7 @@ function mt:_doDiagnostic()
end
for uri in pairs(copy) do
local obj = self._file[uri]
- if obj then
+ if obj and obj.vm then
local data = {
uri = uri,
vm = obj.vm,
diff --git a/server/src/workspace.lua b/server/src/workspace.lua
index 42ce3343..76ee7b0b 100644
--- a/server/src/workspace.lua
+++ b/server/src/workspace.lua
@@ -141,6 +141,50 @@ function mt:findPath(baseUri, searchers)
return uri
end
+function mt:compileLuaPath()
+ for i, luapath in ipairs(self.luapath) do
+ self.compiledpath[i] = '^' .. luapath:gsub('%?', '(.-)'):gsub('%.', '%%.') .. '$'
+ end
+end
+
+function mt:convertPathAsRequire(filename, start)
+ local list
+ for _, luapath in ipairs(self.compiledpath) do
+ local str = filename:match(luapath, start)
+ if str then
+ if not list then
+ list = {}
+ end
+ list[#list+1] = str
+ end
+ end
+ return list
+end
+
+function mt:matchPath(baseUri, str)
+ local first = str:match '[^%.]+'
+ if not first then
+ return nil
+ end
+ local rootLen = #self.root:string()
+ local results = {}
+ for filename in pairs(self.files) do
+ local start = filename:find('/' .. first, true, rootLen + 1)
+ if start then
+ local list = self:convertPathAsRequire(filename, start + 1)
+ if list then
+ for _, str in ipairs(list) do
+ if not results[str] then
+ results[str] = true
+ results[#results+1] = str
+ end
+ end
+ end
+ end
+ end
+ return results
+end
+
function mt:searchPath(baseUri, str)
if self.searched[str] then
return self.searched[str]
@@ -196,6 +240,8 @@ return function (lsp, name)
'?/init.lua',
'?/?.lua',
},
+ compiledpath = {}
}, mt)
+ workspace:compileLuaPath()
return workspace
end
diff --git a/server/test/crossfile/completion.lua b/server/test/crossfile/completion.lua
new file mode 100644
index 00000000..d98de0cc
--- /dev/null
+++ b/server/test/crossfile/completion.lua
@@ -0,0 +1,131 @@
+local service = require 'service'
+local workspace = require 'workspace'
+local fs = require 'bee.filesystem'
+local matcher = require 'matcher'
+
+rawset(_G, 'TEST', true)
+
+local CompletionItemKind = {
+ Text = 1,
+ Method = 2,
+ Function = 3,
+ Constructor = 4,
+ Field = 5,
+ Variable = 6,
+ Class = 7,
+ Interface = 8,
+ Module = 9,
+ Property = 10,
+ Unit = 11,
+ Value = 12,
+ Enum = 13,
+ Keyword = 14,
+ Snippet = 15,
+ Color = 16,
+ File = 17,
+ Reference = 18,
+ Folder = 19,
+ EnumMember = 20,
+ Constant = 21,
+ Struct = 22,
+ Event = 23,
+ Operator = 24,
+ TypeParameter = 25,
+}
+
+local EXISTS = {}
+
+local function eq(a, b)
+ if a == EXISTS and b ~= nil then
+ return true
+ end
+ local tp1, tp2 = type(a), type(b)
+ if tp1 ~= tp2 then
+ return false
+ end
+ if tp1 == 'table' then
+ local mark = {}
+ for k in pairs(a) do
+ if not eq(a[k], b[k]) then
+ return false
+ end
+ mark[k] = true
+ end
+ for k in pairs(b) do
+ if not mark[k] then
+ return false
+ end
+ end
+ return true
+ end
+ return a == b
+end
+
+function TEST(data)
+ local lsp = service()
+ local ws = workspace(lsp, 'test')
+ lsp.workspace = ws
+ ws.root = ROOT
+
+ local mainUri
+ local pos
+ for _, info in ipairs(data) do
+ local uri = ws:uriEncode(fs.path(info.path))
+ local script = info.content
+ if info.main then
+ pos = script:find('@', 1, true)
+ script = script:gsub('@', '')
+ mainUri = uri
+ end
+ lsp:saveText(uri, 1, script)
+ ws:addFile(uri)
+ end
+
+ lsp:compileAll()
+ lsp:compileAll()
+
+ local vm = lsp:loadVM(mainUri)
+ assert(vm)
+ local result = matcher.completion(vm, pos)
+ local expect = data.completion
+ if expect then
+ assert(result)
+ assert(eq(expect, result))
+ else
+ assert(result == nil)
+ end
+end
+
+TEST {
+ {
+ path = 'abc.lua',
+ content = '',
+ },
+ {
+ path = 'abc/aaa.lua',
+ content = '',
+ },
+ {
+ path = 'xxx/abcde.lua',
+ content = '',
+ },
+ {
+ path = 'test.lua',
+ content = 'require "a@"',
+ main = true,
+ },
+ completion = {
+ {
+ label = 'abc',
+ kind = CompletionItemKind.Module,
+ },
+ {
+ label = 'abc.aaa',
+ kind = CompletionItemKind.Module,
+ },
+ {
+ label = 'abcde',
+ kind = CompletionItemKind.Module,
+ },
+ }
+}
diff --git a/server/test/crossfile/init.lua b/server/test/crossfile/init.lua
index e68a253c..b74514af 100644
--- a/server/test/crossfile/init.lua
+++ b/server/test/crossfile/init.lua
@@ -1,2 +1,3 @@
require 'crossfile.definition'
require 'crossfile.hover'
+require 'crossfile.completion'