diff options
author | 最萌小汐 <sumneko@hotmail.com> | 2022-06-21 20:23:41 +0800 |
---|---|---|
committer | 最萌小汐 <sumneko@hotmail.com> | 2022-06-21 20:23:41 +0800 |
commit | bb1244fa62c0158490d2e54da0a19f28f16fe994 (patch) | |
tree | bea39497ee52f7b4de442438bad50c90668a0647 /script/vm | |
parent | f76cd50992dab57a57c61e8e6f5f788745544da9 (diff) | |
download | lua-language-server-bb1244fa62c0158490d2e54da0a19f28f16fe994.zip |
resolve #871 infer called function by params num
Diffstat (limited to 'script/vm')
-rw-r--r-- | script/vm/compiler.lua | 39 | ||||
-rw-r--r-- | script/vm/function.lua | 127 | ||||
-rw-r--r-- | script/vm/init.lua | 1 |
3 files changed, 146 insertions, 21 deletions
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index 99bd0691..515a8ebe 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -554,32 +554,29 @@ local function getReturn(func, index, args) end return vm.compileNode(ast) end - local node = vm.compileNode(func) + local funcs = vm.getMatchedFunctions(func, args) ---@type vm.node? local result - for cnode in node:eachObject() do - if cnode.type == 'function' - or cnode.type == 'doc.type.function' then - local returnObject = vm.getReturnOfFunction(cnode, index) - if returnObject then - local returnNode = vm.compileNode(returnObject) + for _, mfunc in ipairs(funcs) do + local returnObject = vm.getReturnOfFunction(mfunc, index) + if returnObject then + local returnNode = vm.compileNode(returnObject) + for rnode in returnNode:eachObject() do + if rnode.type == 'generic' then + returnNode = rnode:resolve(guide.getUri(func), args) + break + end + end + if returnNode then for rnode in returnNode:eachObject() do - if rnode.type == 'generic' then - returnNode = rnode:resolve(guide.getUri(func), args) - break + -- TODO: narrow type + if rnode.type ~= 'doc.generic.name' then + result = result or vm.createNode() + result:merge(rnode) end end - if returnNode then - for rnode in returnNode:eachObject() do - -- TODO: narrow type - if rnode.type ~= 'doc.generic.name' then - result = result or vm.createNode() - result:merge(rnode) - end - end - if result and returnNode:isOptional() then - result:addOptional() - end + if result and returnNode:isOptional() then + result:addOptional() end end end diff --git a/script/vm/function.lua b/script/vm/function.lua new file mode 100644 index 00000000..69900141 --- /dev/null +++ b/script/vm/function.lua @@ -0,0 +1,127 @@ +---@class vm +local vm = require 'vm.vm' + +---@param func parser.object +---@return integer min +---@return integer max +function vm.countParamsOfFunction(func) + local min = 0 + local max = 0 + if func.type == 'function' + or func.type == 'doc.type.function' then + if func.args then + max = #func.args + min = max + for i = #func.args, 1, -1 do + local arg = func.args[i] + if arg.type == '...' + or (arg.name and arg.name[1] =='...') then + max = math.huge + elseif not vm.compileNode(arg):isNullable() then + min = i + break + end + end + end + end + return min, max +end + +---@param func parser.object +---@return integer min +---@return integer max +function vm.countReturnsOfFunction(func) + if func.type == 'function' then + if not func.returns then + return 0, 0 + end + local min, max + for _, ret in ipairs(func.returns) do + local rmin, rmax = vm.countList(ret) + if not min or rmin < min then + min = rmin + end + if not max or rmax > max then + max = rmax + end + end + return min, max + end + if func.type == 'doc.type.function' then + return vm.countList(func.returns) + end + return 0, 0 +end + +---@param func parser.object +---@return integer min +---@return integer max +function vm.countReturnsOfCall(func, args) + local funcs = vm.getMatchedFunctions(func, args) + local min + local max + for _, f in ipairs(funcs) do + local rmin, rmax = vm.countReturnsOfFunction(f) + if not min or rmin < min then + min = rmin + end + if not max or rmax > max then + max = rmax + end + end + return min or 0, max or 0 +end + +---@param list parser.object[]? +---@return integer min +---@return integer max +function vm.countList(list) + if not list then + return 0, 0 + end + local lastArg = list[#list] + if not lastArg then + return 0, 0 + end + if lastArg.type == '...' then + return #list - 1, math.huge + end + if lastArg.type == 'call' then + local rmin, rmax = vm.countReturnsOfCall(lastArg.node, lastArg.args) + return #list - 1 + rmin, #list - 1 + rmax + end + return #list, #list +end + +---@param func parser.object +---@param args parser.object[]? +---@return parser.object[] +function vm.getMatchedFunctions(func, args) + local funcs = {} + local node = vm.compileNode(func) + for n in node:eachObject() do + if n.type == 'function' + or n.type == 'doc.type.function' then + funcs[#funcs+1] = n + end + end + if #funcs <= 1 then + return funcs + end + + local amin, amax = vm.countList(args) + + local matched = {} + for _, n in ipairs(funcs) do + local min, max = vm.countParamsOfFunction(n) + if amin >= min and amax <= max then + matched[#matched+1] = n + end + end + + if #matched == 0 then + return funcs + else + return matched + end +end diff --git a/script/vm/init.lua b/script/vm/init.lua index f5003c11..4fa65766 100644 --- a/script/vm/init.lua +++ b/script/vm/init.lua @@ -17,4 +17,5 @@ require 'vm.generic' require 'vm.sign' require 'vm.local-id' require 'vm.global' +require 'vm.function' return vm |