summaryrefslogtreecommitdiff
path: root/script/vm
diff options
context:
space:
mode:
author最萌小汐 <sumneko@hotmail.com>2022-06-21 20:23:41 +0800
committer最萌小汐 <sumneko@hotmail.com>2022-06-21 20:23:41 +0800
commitbb1244fa62c0158490d2e54da0a19f28f16fe994 (patch)
treebea39497ee52f7b4de442438bad50c90668a0647 /script/vm
parentf76cd50992dab57a57c61e8e6f5f788745544da9 (diff)
downloadlua-language-server-bb1244fa62c0158490d2e54da0a19f28f16fe994.zip
resolve #871 infer called function by params num
Diffstat (limited to 'script/vm')
-rw-r--r--script/vm/compiler.lua39
-rw-r--r--script/vm/function.lua127
-rw-r--r--script/vm/init.lua1
3 files changed, 146 insertions, 21 deletions
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index 99bd0691..515a8ebe 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -554,32 +554,29 @@ local function getReturn(func, index, args)
end
return vm.compileNode(ast)
end
- local node = vm.compileNode(func)
+ local funcs = vm.getMatchedFunctions(func, args)
---@type vm.node?
local result
- for cnode in node:eachObject() do
- if cnode.type == 'function'
- or cnode.type == 'doc.type.function' then
- local returnObject = vm.getReturnOfFunction(cnode, index)
- if returnObject then
- local returnNode = vm.compileNode(returnObject)
+ for _, mfunc in ipairs(funcs) do
+ local returnObject = vm.getReturnOfFunction(mfunc, index)
+ if returnObject then
+ local returnNode = vm.compileNode(returnObject)
+ for rnode in returnNode:eachObject() do
+ if rnode.type == 'generic' then
+ returnNode = rnode:resolve(guide.getUri(func), args)
+ break
+ end
+ end
+ if returnNode then
for rnode in returnNode:eachObject() do
- if rnode.type == 'generic' then
- returnNode = rnode:resolve(guide.getUri(func), args)
- break
+ -- TODO: narrow type
+ if rnode.type ~= 'doc.generic.name' then
+ result = result or vm.createNode()
+ result:merge(rnode)
end
end
- if returnNode then
- for rnode in returnNode:eachObject() do
- -- TODO: narrow type
- if rnode.type ~= 'doc.generic.name' then
- result = result or vm.createNode()
- result:merge(rnode)
- end
- end
- if result and returnNode:isOptional() then
- result:addOptional()
- end
+ if result and returnNode:isOptional() then
+ result:addOptional()
end
end
end
diff --git a/script/vm/function.lua b/script/vm/function.lua
new file mode 100644
index 00000000..69900141
--- /dev/null
+++ b/script/vm/function.lua
@@ -0,0 +1,127 @@
+---@class vm
+local vm = require 'vm.vm'
+
+---@param func parser.object
+---@return integer min
+---@return integer max
+function vm.countParamsOfFunction(func)
+ local min = 0
+ local max = 0
+ if func.type == 'function'
+ or func.type == 'doc.type.function' then
+ if func.args then
+ max = #func.args
+ min = max
+ for i = #func.args, 1, -1 do
+ local arg = func.args[i]
+ if arg.type == '...'
+ or (arg.name and arg.name[1] =='...') then
+ max = math.huge
+ elseif not vm.compileNode(arg):isNullable() then
+ min = i
+ break
+ end
+ end
+ end
+ end
+ return min, max
+end
+
+---@param func parser.object
+---@return integer min
+---@return integer max
+function vm.countReturnsOfFunction(func)
+ if func.type == 'function' then
+ if not func.returns then
+ return 0, 0
+ end
+ local min, max
+ for _, ret in ipairs(func.returns) do
+ local rmin, rmax = vm.countList(ret)
+ if not min or rmin < min then
+ min = rmin
+ end
+ if not max or rmax > max then
+ max = rmax
+ end
+ end
+ return min, max
+ end
+ if func.type == 'doc.type.function' then
+ return vm.countList(func.returns)
+ end
+ return 0, 0
+end
+
+---@param func parser.object
+---@return integer min
+---@return integer max
+function vm.countReturnsOfCall(func, args)
+ local funcs = vm.getMatchedFunctions(func, args)
+ local min
+ local max
+ for _, f in ipairs(funcs) do
+ local rmin, rmax = vm.countReturnsOfFunction(f)
+ if not min or rmin < min then
+ min = rmin
+ end
+ if not max or rmax > max then
+ max = rmax
+ end
+ end
+ return min or 0, max or 0
+end
+
+---@param list parser.object[]?
+---@return integer min
+---@return integer max
+function vm.countList(list)
+ if not list then
+ return 0, 0
+ end
+ local lastArg = list[#list]
+ if not lastArg then
+ return 0, 0
+ end
+ if lastArg.type == '...' then
+ return #list - 1, math.huge
+ end
+ if lastArg.type == 'call' then
+ local rmin, rmax = vm.countReturnsOfCall(lastArg.node, lastArg.args)
+ return #list - 1 + rmin, #list - 1 + rmax
+ end
+ return #list, #list
+end
+
+---@param func parser.object
+---@param args parser.object[]?
+---@return parser.object[]
+function vm.getMatchedFunctions(func, args)
+ local funcs = {}
+ local node = vm.compileNode(func)
+ for n in node:eachObject() do
+ if n.type == 'function'
+ or n.type == 'doc.type.function' then
+ funcs[#funcs+1] = n
+ end
+ end
+ if #funcs <= 1 then
+ return funcs
+ end
+
+ local amin, amax = vm.countList(args)
+
+ local matched = {}
+ for _, n in ipairs(funcs) do
+ local min, max = vm.countParamsOfFunction(n)
+ if amin >= min and amax <= max then
+ matched[#matched+1] = n
+ end
+ end
+
+ if #matched == 0 then
+ return funcs
+ else
+ return matched
+ end
+end
diff --git a/script/vm/init.lua b/script/vm/init.lua
index f5003c11..4fa65766 100644
--- a/script/vm/init.lua
+++ b/script/vm/init.lua
@@ -17,4 +17,5 @@ require 'vm.generic'
require 'vm.sign'
require 'vm.local-id'
require 'vm.global'
+require 'vm.function'
return vm