diff options
-rw-r--r-- | changelog.md | 11 | ||||
-rw-r--r-- | script/vm/compiler.lua | 5 | ||||
-rw-r--r-- | script/vm/function.lua | 54 | ||||
-rw-r--r-- | script/vm/infer.lua | 3 | ||||
-rw-r--r-- | script/vm/node.lua | 2 | ||||
-rw-r--r-- | test/type_inference/param_match.lua | 32 |
6 files changed, 96 insertions, 11 deletions
diff --git a/changelog.md b/changelog.md index 087fbf85..743d6129 100644 --- a/changelog.md +++ b/changelog.md @@ -1,7 +1,16 @@ # changelog -## 3.8.4 +## 3.9.0 * `NEW` goto implementation +* `NEW` narrow the function prototype based on the parameter type + ```lua + ---@overload fun(a: boolean): A + ---@overload fun(a: number): B + local function f(...) end + + local r1 = f(true) --> r1 is `A` + local r2 = f(10) --> r2 is `B` + ``` ## 3.8.3 `2024-4-23` diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index 11ba07ab..78b62b27 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -550,11 +550,14 @@ local function matchCall(source) or call.node ~= source then return end - local funcs = vm.getMatchedFunctions(source, call.args) local myNode = vm.getNode(source) if not myNode then return end + local funcs = vm.getExactMatchedFunctions(source, call.args) + if not funcs then + return + end local needRemove for n in myNode:eachObject() do if n.type == 'function' diff --git a/script/vm/function.lua b/script/vm/function.lua index dde8ecb2..1a916a1c 100644 --- a/script/vm/function.lua +++ b/script/vm/function.lua @@ -267,6 +267,9 @@ end ---@return integer def function vm.countReturnsOfCall(func, args, mark) local funcs = vm.getMatchedFunctions(func, args, mark) + if not funcs then + return 0, math.huge, 0 + end ---@type integer?, number?, integer? local min, max, def for _, f in ipairs(funcs) do @@ -329,10 +332,52 @@ function vm.countList(list, mark) return min, max, def end +---@param uri uri +---@param args parser.object[] +---@return boolean +local function isAllParamMatched(uri, args, params) + if not params then + return false + end + for i = 1, #args do + if not params[i] then + break + end + local argNode = vm.compileNode(args[i]) + local defNode = vm.compileNode(params[i]) + if not vm.canCastType(uri, defNode, argNode) then + return false + end + end + return true +end + ---@param func parser.object ----@param args parser.object[]? +---@param args? parser.object[] +---@return parser.object[]? +function vm.getExactMatchedFunctions(func, args) + local funcs = vm.getMatchedFunctions(func, args) + if not args or not funcs then + return funcs + end + local uri = guide.getUri(func) + local result = {} + for _, n in ipairs(funcs) do + if not vm.isVarargFunctionWithOverloads(n) + and isAllParamMatched(uri, args, n.args) then + result[#result+1] = n + end + end + if #result == 0 then + return nil + end + return result +end + +---@param func parser.object +---@param args? parser.object[] ---@param mark? table ----@return parser.object[] +---@return parser.object[]? function vm.getMatchedFunctions(func, args, mark) local funcs = {} local node = vm.compileNode(func) @@ -342,9 +387,6 @@ function vm.getMatchedFunctions(func, args, mark) funcs[#funcs+1] = n end end - if #funcs <= 1 then - return funcs - end local amin, amax = vm.countList(args, mark) @@ -357,7 +399,7 @@ function vm.getMatchedFunctions(func, args, mark) end if #matched == 0 then - return funcs + return nil else return matched end diff --git a/script/vm/infer.lua b/script/vm/infer.lua index 3f3d0e3a..bb06ee3a 100644 --- a/script/vm/infer.lua +++ b/script/vm/infer.lua @@ -242,9 +242,6 @@ local viewNodeSwitch;viewNodeSwitch = util.switch() return vm.viewKey(source, uri) end) ----@class vm.node ----@field lastInfer? vm.infer - ---@param node? vm.node ---@return vm.infer local function createInfer(node) diff --git a/script/vm/node.lua b/script/vm/node.lua index bc1dfcb1..fae79cbc 100644 --- a/script/vm/node.lua +++ b/script/vm/node.lua @@ -16,6 +16,7 @@ vm.nodeCache = setmetatable({}, util.MODE_K) ---@field [vm.node.object] true ---@field fields? table<vm.node|string, vm.node> ---@field undefinedGlobal boolean? +---@field lastInfer? vm.infer local mt = {} mt.__index = mt mt.id = 0 @@ -31,6 +32,7 @@ function mt:merge(node) if not node then return self end + self.lastInfer = nil if node.type == 'vm.node' then if node == self then return self diff --git a/test/type_inference/param_match.lua b/test/type_inference/param_match.lua index 3bb167bc..8ead05ef 100644 --- a/test/type_inference/param_match.lua +++ b/test/type_inference/param_match.lua @@ -105,3 +105,35 @@ local r1 local <?x?> = f(r1()) ]] + +TEST '1' [[ +---@overload fun(a: 'x'): 1 +---@overload fun(a: 'y'): 2 +local function f(...) end + +local <?r?> = f('x') +]] + +TEST '2' [[ +---@overload fun(a: 'x'): 1 +---@overload fun(a: 'y'): 2 +local function f(...) end + +local <?r?> = f('y') +]] + +TEST '1' [[ +---@overload fun(a: boolean): 1 +---@overload fun(a: number): 2 +local function f(...) end + +local <?r?> = f(true) +]] + +TEST '2' [[ +---@overload fun(a: boolean): 1 +---@overload fun(a: number): 2 +local function f(...) end + +local <?r?> = f(10) +]] |