From bb1244fa62c0158490d2e54da0a19f28f16fe994 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=80=E8=90=8C=E5=B0=8F=E6=B1=90?= Date: Tue, 21 Jun 2022 20:23:41 +0800 Subject: resolve #871 infer called function by params num --- script/vm/function.lua | 127 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 127 insertions(+) create mode 100644 script/vm/function.lua (limited to 'script/vm/function.lua') diff --git a/script/vm/function.lua b/script/vm/function.lua new file mode 100644 index 00000000..69900141 --- /dev/null +++ b/script/vm/function.lua @@ -0,0 +1,127 @@ +---@class vm +local vm = require 'vm.vm' + +---@param func parser.object +---@return integer min +---@return integer max +function vm.countParamsOfFunction(func) + local min = 0 + local max = 0 + if func.type == 'function' + or func.type == 'doc.type.function' then + if func.args then + max = #func.args + min = max + for i = #func.args, 1, -1 do + local arg = func.args[i] + if arg.type == '...' + or (arg.name and arg.name[1] =='...') then + max = math.huge + elseif not vm.compileNode(arg):isNullable() then + min = i + break + end + end + end + end + return min, max +end + +---@param func parser.object +---@return integer min +---@return integer max +function vm.countReturnsOfFunction(func) + if func.type == 'function' then + if not func.returns then + return 0, 0 + end + local min, max + for _, ret in ipairs(func.returns) do + local rmin, rmax = vm.countList(ret) + if not min or rmin < min then + min = rmin + end + if not max or rmax > max then + max = rmax + end + end + return min, max + end + if func.type == 'doc.type.function' then + return vm.countList(func.returns) + end + return 0, 0 +end + +---@param func parser.object +---@return integer min +---@return integer max +function vm.countReturnsOfCall(func, args) + local funcs = vm.getMatchedFunctions(func, args) + local min + local max + for _, f in ipairs(funcs) do + local rmin, rmax = vm.countReturnsOfFunction(f) + if not min or rmin < min then + min = rmin + end + if not max or rmax > max then + max = rmax + end + end + return min or 0, max or 0 +end + +---@param list parser.object[]? +---@return integer min +---@return integer max +function vm.countList(list) + if not list then + return 0, 0 + end + local lastArg = list[#list] + if not lastArg then + return 0, 0 + end + if lastArg.type == '...' then + return #list - 1, math.huge + end + if lastArg.type == 'call' then + local rmin, rmax = vm.countReturnsOfCall(lastArg.node, lastArg.args) + return #list - 1 + rmin, #list - 1 + rmax + end + return #list, #list +end + +---@param func parser.object +---@param args parser.object[]? +---@return parser.object[] +function vm.getMatchedFunctions(func, args) + local funcs = {} + local node = vm.compileNode(func) + for n in node:eachObject() do + if n.type == 'function' + or n.type == 'doc.type.function' then + funcs[#funcs+1] = n + end + end + if #funcs <= 1 then + return funcs + end + + local amin, amax = vm.countList(args) + + local matched = {} + for _, n in ipairs(funcs) do + local min, max = vm.countParamsOfFunction(n) + if amin >= min and amax <= max then + matched[#matched+1] = n + end + end + + if #matched == 0 then + return funcs + else + return matched + end +end -- cgit v1.2.3