diff options
-rw-r--r-- | changelog.md | 4 | ||||
-rw-r--r-- | script/vm/compiler.lua | 70 | ||||
-rw-r--r-- | script/vm/node.lua | 13 | ||||
-rw-r--r-- | test/hover/init.lua | 43 | ||||
-rw-r--r-- | test/type_inference/init.lua | 12 |
5 files changed, 116 insertions, 26 deletions
diff --git a/changelog.md b/changelog.md index d4b5df88..48793c71 100644 --- a/changelog.md +++ b/changelog.md @@ -8,8 +8,8 @@ ``` * `CHG` infer called function by params num ```lua - ---@overload fun(number, number):string - ---@overload fun(number):number + ---@overload fun(x, y):string + ---@overload fun(x):number ---@return boolean local function f() end diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index 515a8ebe..ff1e3a15 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -554,29 +554,32 @@ local function getReturn(func, index, args) end return vm.compileNode(ast) end - local funcs = vm.getMatchedFunctions(func, args) + local funcNode = vm.compileNode(func) ---@type vm.node? local result - for _, mfunc in ipairs(funcs) do - local returnObject = vm.getReturnOfFunction(mfunc, index) - if returnObject then - local returnNode = vm.compileNode(returnObject) - for rnode in returnNode:eachObject() do - if rnode.type == 'generic' then - returnNode = rnode:resolve(guide.getUri(func), args) - break - end - end - if returnNode then + for mfunc in funcNode:eachObject() do + if mfunc.type == 'function' + or mfunc.type == 'doc.type.function' then + local returnObject = vm.getReturnOfFunction(mfunc, index) + if returnObject then + local returnNode = vm.compileNode(returnObject) for rnode in returnNode:eachObject() do - -- TODO: narrow type - if rnode.type ~= 'doc.generic.name' then - result = result or vm.createNode() - result:merge(rnode) + if rnode.type == 'generic' then + returnNode = rnode:resolve(guide.getUri(func), args) + break end end - if result and returnNode:isOptional() then - result:addOptional() + if returnNode then + for rnode in returnNode:eachObject() do + -- TODO: narrow type + if rnode.type ~= 'doc.generic.name' then + result = result or vm.createNode() + result:merge(rnode) + end + end + if result and returnNode:isOptional() then + result:addOptional() + end end end end @@ -1821,6 +1824,36 @@ local function compileByGlobal(source) end end +---@param source parser.object +local function compileByCall(source) + local call = source.parent + if not call + or call.type ~= 'call' + or call.node ~= source then + return + end + local funcs = vm.getMatchedFunctions(source, call.args) + local myNode = vm.getNode(source) + if not myNode then + return + end + local needRemove + for n in myNode:eachObject() do + if n.type == 'function' + or n.type == 'doc.type.function' then + if not util.arrayHas(funcs, n) then + if not needRemove then + needRemove = vm.createNode() + end + needRemove:merge(n) + end + end + end + if needRemove then + myNode:removeNode(needRemove) + end +end + ---@param source vm.object ---@return vm.node function vm.compileNode(source) @@ -1845,6 +1878,7 @@ function vm.compileNode(source) vm.setNode(source, node, true) compileByGlobal(source) compileByNode(source) + compileByCall(source) node = vm.getNode(source) diff --git a/script/vm/node.lua b/script/vm/node.lua index 9433733e..61781e5f 100644 --- a/script/vm/node.lua +++ b/script/vm/node.lua @@ -252,6 +252,17 @@ function mt:narrow(name) return self end +---@param obj vm.object +function mt:removeObject(obj) + for index, c in ipairs(self) do + if c == obj then + table.remove(self, index) + self[c] = nil + return + end + end +end + ---@param node vm.node function mt:removeNode(node) for _, c in ipairs(node) do @@ -265,6 +276,8 @@ function mt:removeNode(node) else self:remove 'false' end + else + self:removeObject(c) end end end diff --git a/test/hover/init.lua b/test/hover/init.lua index 0925db77..1b6b2231 100644 --- a/test/hover/init.lua +++ b/test/hover/init.lua @@ -1947,3 +1947,46 @@ x({}, <?function?> () end) [[ (async) function () ]] + +TEST [[ +---@overload fun(x, y):string +---@overload fun(x):number +---@return boolean +local function f() end + +local n1 = <?f?>() +local n2 = f(0) +local n3 = f(0, 0) +]] +[[ +function f() + -> boolean +]] + +TEST [[ +---@overload fun(x, y):string +---@overload fun(x):number +---@return boolean +local function f() end + +local n1 = f() +local n2 = <?f?>(0) +local n3 = f(0, 0) +]] +[[ +local f: fun(x: any):number +]] + +TEST [[ +---@overload fun(x, y):string +---@overload fun(x):number +---@return boolean +local function f() end + +local n1 = f() +local n2 = f(0) +local n3 = <?f?>(0, 0) +]] +[[ +local f: fun(x: any, y: any):string +]] diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua index 394ce263..ad1df8e0 100644 --- a/test/type_inference/init.lua +++ b/test/type_inference/init.lua @@ -2929,8 +2929,8 @@ local <?x?> = f(r1()) ]] TEST 'boolean' [[ ----@overload fun(number, number):string ----@overload fun(number):number +---@overload fun(x, y):string +---@overload fun(x):number ---@return boolean local function f() end @@ -2940,8 +2940,8 @@ local n3 = f(0, 0) ]] TEST 'number' [[ ----@overload fun(number, number):string ----@overload fun(number):number +---@overload fun(x, y):string +---@overload fun(x):number ---@return boolean local function f() end @@ -2951,8 +2951,8 @@ local n3 = f(0, 0) ]] TEST 'string' [[ ----@overload fun(number, number):string ----@overload fun(number):number +---@overload fun(x, y):string +---@overload fun(x):number ---@return boolean local function f() end |