summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--server/src/core/completion.lua99
-rw-r--r--server/src/emmy/manager.lua3
-rw-r--r--server/src/emmy/param.lua11
-rw-r--r--server/src/vm/function.lua9
-rw-r--r--server/test/completion/init.lua44
5 files changed, 133 insertions, 33 deletions
diff --git a/server/src/core/completion.lua b/server/src/core/completion.lua
index 59a3cb3f..83d740a3 100644
--- a/server/src/core/completion.lua
+++ b/server/src/core/completion.lua
@@ -509,39 +509,7 @@ local function searchInRequire(vm, select, source, callback)
end
end
-local function searchCallArg(vm, source, word, callback, pos)
- local results = {}
- vm:eachSource(function (src)
- if src.type == 'call'
- and src.start <= pos
- and src.finish >= pos
- then
- results[#results+1] = src
- end
- end)
- if #results == 0 then
- return nil
- end
- -- 可能处于 'func1(func2(' 的嵌套中,将最近的call放到最前面
- table.sort(results, function (a, b)
- return a.start > b.start
- end)
- local call = results[1]
- local args = call:bindCall()
- if not args then
- return
- end
-
- local value = call:findCallFunction()
- if not value then
- return
- end
-
- local lib = value:getLib()
- if not lib then
- return
- end
-
+local function searchEnumAsLib(vm, source, word, callback, pos, args, lib)
local select = #args + 1
for i, arg in ipairs(args) do
if arg.start <= pos and arg.finish >= pos - 1 then
@@ -584,6 +552,71 @@ local function searchCallArg(vm, source, word, callback, pos)
end
end
+local function searchEnumAsEmmyParams(vm, source, word, callback, pos, args, func)
+ local select = #args + 1
+ for i, arg in ipairs(args) do
+ if arg.start <= pos and arg.finish >= pos - 1 then
+ select = i
+ break
+ end
+ end
+
+ local param = func:findEmmyParamByIndex(select)
+ if not param then
+ return
+ end
+
+ param:eachEnum(function (enum)
+ if matchKey(word, enum) then
+ callback(enum, nil, CompletionItemKind.EnumMember, {
+ label = enum,
+ })
+ end
+ end)
+end
+
+local function searchCallArg(vm, source, word, callback, pos)
+ local results = {}
+ vm:eachSource(function (src)
+ if src.type == 'call'
+ and src.start <= pos
+ and src.finish >= pos
+ then
+ results[#results+1] = src
+ end
+ end)
+ if #results == 0 then
+ return nil
+ end
+ -- 可能处于 'func1(func2(' 的嵌套中,将最近的call放到最前面
+ table.sort(results, function (a, b)
+ return a.start > b.start
+ end)
+ local call = results[1]
+ local args = call:bindCall()
+ if not args then
+ return
+ end
+
+ local value = call:findCallFunction()
+ if not value then
+ return
+ end
+
+ local lib = value:getLib()
+ if lib then
+ searchEnumAsLib(vm, source, word, callback, pos, args, lib)
+ return
+ end
+
+ ---@type function
+ local func = value:getFunction()
+ if func then
+ searchEnumAsEmmyParams(vm, source, word, callback, pos, args, func)
+ return
+ end
+end
+
local function searchAllWords(vm, source, word, callback, pos)
if word == '' then
return
diff --git a/server/src/emmy/manager.lua b/server/src/emmy/manager.lua
index 128ed029..85f25bd1 100644
--- a/server/src/emmy/manager.lua
+++ b/server/src/emmy/manager.lua
@@ -138,6 +138,9 @@ function mt:addParam(source, bind)
elseif bind.type == 'emmy.generic' then
paramObj:bindGeneric(bind)
end
+ for i = 3, #source do
+ paramObj:addEnum(source[i][1])
+ end
return paramObj
end
diff --git a/server/src/emmy/param.lua b/server/src/emmy/param.lua
index 9a2d407f..5894e28e 100644
--- a/server/src/emmy/param.lua
+++ b/server/src/emmy/param.lua
@@ -29,10 +29,21 @@ function mt:bindGeneric(generic)
end
end
+function mt:addEnum(str)
+ self._enum[#self._enum+1] = str
+end
+
+function mt:eachEnum(callback)
+ for _, str in ipairs(self._enum) do
+ callback(str)
+ end
+end
+
return function (manager, source)
local self = setmetatable({
source = source.id,
_manager = manager,
+ _enum = {},
}, mt)
if source.type == 'emmyParam' then
self.name = source[1][1]
diff --git a/server/src/vm/function.lua b/server/src/vm/function.lua
index e605e1c9..593d40dd 100644
--- a/server/src/vm/function.lua
+++ b/server/src/vm/function.lua
@@ -349,6 +349,15 @@ function mt:findEmmyParamByName(name)
return nil
end
+function mt:findEmmyParamByIndex(index)
+ local arg = self.args[index]
+ if not arg then
+ return nil
+ end
+ local name = arg:getName()
+ return self:findEmmyParamByName(name)
+end
+
function mt:addArg(name, source, value)
local loc = localMgr.create(name, source, value)
self:saveUpvalue(name, loc)
diff --git a/server/test/completion/init.lua b/server/test/completion/init.lua
index 17ab0369..de924bd7 100644
--- a/server/test/completion/init.lua
+++ b/server/test/completion/init.lua
@@ -1055,3 +1055,47 @@ end
kind = CompletionItemKind.Keyword,
},
}
+
+TEST [[
+---@param x string | "'AAA'" | "'BBB'" | "'CCC'"
+function f(y, x)
+end
+
+f(1, $)
+]]
+{
+ {
+ label = "'AAA'",
+ kind = CompletionItemKind.EnumMember,
+ },
+ {
+ label = "'BBB'",
+ kind = CompletionItemKind.EnumMember,
+ },
+ {
+ label = "'CCC'",
+ kind = CompletionItemKind.EnumMember,
+ }
+}
+
+TEST [[
+---@param x string | "'AAA'" | "'BBB'" | "'CCC'"
+function f(x)
+end
+
+f($)
+]]
+{
+ {
+ label = "'AAA'",
+ kind = CompletionItemKind.EnumMember,
+ },
+ {
+ label = "'BBB'",
+ kind = CompletionItemKind.EnumMember,
+ },
+ {
+ label = "'CCC'",
+ kind = CompletionItemKind.EnumMember,
+ }
+}