summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--changelog.md4
-rw-r--r--script/vm/compiler.lua70
-rw-r--r--script/vm/node.lua13
-rw-r--r--test/hover/init.lua43
-rw-r--r--test/type_inference/init.lua12
5 files changed, 116 insertions, 26 deletions
diff --git a/changelog.md b/changelog.md
index d4b5df88..48793c71 100644
--- a/changelog.md
+++ b/changelog.md
@@ -8,8 +8,8 @@
```
* `CHG` infer called function by params num
```lua
- ---@overload fun(number, number):string
- ---@overload fun(number):number
+ ---@overload fun(x, y):string
+ ---@overload fun(x):number
---@return boolean
local function f() end
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index 515a8ebe..ff1e3a15 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -554,29 +554,32 @@ local function getReturn(func, index, args)
end
return vm.compileNode(ast)
end
- local funcs = vm.getMatchedFunctions(func, args)
+ local funcNode = vm.compileNode(func)
---@type vm.node?
local result
- for _, mfunc in ipairs(funcs) do
- local returnObject = vm.getReturnOfFunction(mfunc, index)
- if returnObject then
- local returnNode = vm.compileNode(returnObject)
- for rnode in returnNode:eachObject() do
- if rnode.type == 'generic' then
- returnNode = rnode:resolve(guide.getUri(func), args)
- break
- end
- end
- if returnNode then
+ for mfunc in funcNode:eachObject() do
+ if mfunc.type == 'function'
+ or mfunc.type == 'doc.type.function' then
+ local returnObject = vm.getReturnOfFunction(mfunc, index)
+ if returnObject then
+ local returnNode = vm.compileNode(returnObject)
for rnode in returnNode:eachObject() do
- -- TODO: narrow type
- if rnode.type ~= 'doc.generic.name' then
- result = result or vm.createNode()
- result:merge(rnode)
+ if rnode.type == 'generic' then
+ returnNode = rnode:resolve(guide.getUri(func), args)
+ break
end
end
- if result and returnNode:isOptional() then
- result:addOptional()
+ if returnNode then
+ for rnode in returnNode:eachObject() do
+ -- TODO: narrow type
+ if rnode.type ~= 'doc.generic.name' then
+ result = result or vm.createNode()
+ result:merge(rnode)
+ end
+ end
+ if result and returnNode:isOptional() then
+ result:addOptional()
+ end
end
end
end
@@ -1821,6 +1824,36 @@ local function compileByGlobal(source)
end
end
+---@param source parser.object
+local function compileByCall(source)
+ local call = source.parent
+ if not call
+ or call.type ~= 'call'
+ or call.node ~= source then
+ return
+ end
+ local funcs = vm.getMatchedFunctions(source, call.args)
+ local myNode = vm.getNode(source)
+ if not myNode then
+ return
+ end
+ local needRemove
+ for n in myNode:eachObject() do
+ if n.type == 'function'
+ or n.type == 'doc.type.function' then
+ if not util.arrayHas(funcs, n) then
+ if not needRemove then
+ needRemove = vm.createNode()
+ end
+ needRemove:merge(n)
+ end
+ end
+ end
+ if needRemove then
+ myNode:removeNode(needRemove)
+ end
+end
+
---@param source vm.object
---@return vm.node
function vm.compileNode(source)
@@ -1845,6 +1878,7 @@ function vm.compileNode(source)
vm.setNode(source, node, true)
compileByGlobal(source)
compileByNode(source)
+ compileByCall(source)
node = vm.getNode(source)
diff --git a/script/vm/node.lua b/script/vm/node.lua
index 9433733e..61781e5f 100644
--- a/script/vm/node.lua
+++ b/script/vm/node.lua
@@ -252,6 +252,17 @@ function mt:narrow(name)
return self
end
+---@param obj vm.object
+function mt:removeObject(obj)
+ for index, c in ipairs(self) do
+ if c == obj then
+ table.remove(self, index)
+ self[c] = nil
+ return
+ end
+ end
+end
+
---@param node vm.node
function mt:removeNode(node)
for _, c in ipairs(node) do
@@ -265,6 +276,8 @@ function mt:removeNode(node)
else
self:remove 'false'
end
+ else
+ self:removeObject(c)
end
end
end
diff --git a/test/hover/init.lua b/test/hover/init.lua
index 0925db77..1b6b2231 100644
--- a/test/hover/init.lua
+++ b/test/hover/init.lua
@@ -1947,3 +1947,46 @@ x({}, <?function?> () end)
[[
(async) function ()
]]
+
+TEST [[
+---@overload fun(x, y):string
+---@overload fun(x):number
+---@return boolean
+local function f() end
+
+local n1 = <?f?>()
+local n2 = f(0)
+local n3 = f(0, 0)
+]]
+[[
+function f()
+ -> boolean
+]]
+
+TEST [[
+---@overload fun(x, y):string
+---@overload fun(x):number
+---@return boolean
+local function f() end
+
+local n1 = f()
+local n2 = <?f?>(0)
+local n3 = f(0, 0)
+]]
+[[
+local f: fun(x: any):number
+]]
+
+TEST [[
+---@overload fun(x, y):string
+---@overload fun(x):number
+---@return boolean
+local function f() end
+
+local n1 = f()
+local n2 = f(0)
+local n3 = <?f?>(0, 0)
+]]
+[[
+local f: fun(x: any, y: any):string
+]]
diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua
index 394ce263..ad1df8e0 100644
--- a/test/type_inference/init.lua
+++ b/test/type_inference/init.lua
@@ -2929,8 +2929,8 @@ local <?x?> = f(r1())
]]
TEST 'boolean' [[
----@overload fun(number, number):string
----@overload fun(number):number
+---@overload fun(x, y):string
+---@overload fun(x):number
---@return boolean
local function f() end
@@ -2940,8 +2940,8 @@ local n3 = f(0, 0)
]]
TEST 'number' [[
----@overload fun(number, number):string
----@overload fun(number):number
+---@overload fun(x, y):string
+---@overload fun(x):number
---@return boolean
local function f() end
@@ -2951,8 +2951,8 @@ local n3 = f(0, 0)
]]
TEST 'string' [[
----@overload fun(number, number):string
----@overload fun(number):number
+---@overload fun(x, y):string
+---@overload fun(x):number
---@return boolean
local function f() end