summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--changelog.md1
-rw-r--r--script/vm/compiler.lua10
-rw-r--r--script/vm/function.lua31
-rw-r--r--test/type_inference/param_match.lua38
4 files changed, 79 insertions, 1 deletions
diff --git a/changelog.md b/changelog.md
index 2f1f2e6c..cd566573 100644
--- a/changelog.md
+++ b/changelog.md
@@ -9,6 +9,7 @@
* `FIX` Lua 5.1: fix incorrect warning when using setfenv with an int as first parameter
* `FIX` Improve type narrow by checking exact match on literal type params
* `FIX` Correctly list enums for function overload arguments [#2840](https://github.com/LuaLS/lua-language-server/pull/2840)
+* `FIX` Incorrect function params' type infer when there is only `@overload` [#2509](https://github.com/LuaLS/lua-language-server/issues/2509) [#2708](https://github.com/LuaLS/lua-language-server/issues/2708) [#2709](https://github.com/LuaLS/lua-language-server/issues/2709)
## 3.10.5
`2024-8-19`
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index 041d287e..50f260c2 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -1099,6 +1099,7 @@ local function compileFunctionParam(func, source)
-- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number
local funcNode = vm.compileNode(func)
+ local found = false
for n in funcNode:eachObject() do
if n.type == 'doc.type.function' and n.args[aindex] then
local argNode = vm.compileNode(n.args[aindex])
@@ -1107,9 +1108,16 @@ local function compileFunctionParam(func, source)
vm.setNode(source, an)
end
end
- return true
+ -- NOTE: keep existing behavior for local call which only set type based on the 1st match
+ if func.parent.type == 'callargs' then
+ return true
+ end
+ found = true
end
end
+ if found then
+ return true
+ end
local derviationParam = config.get(guide.getUri(func), 'Lua.type.inferParamType')
if derviationParam and func.parent.type == 'local' and func.parent.ref then
diff --git a/script/vm/function.lua b/script/vm/function.lua
index 7a15ac5a..21a432c1 100644
--- a/script/vm/function.lua
+++ b/script/vm/function.lua
@@ -359,6 +359,7 @@ end
---@return number
local function calcFunctionMatchScore(uri, args, func)
if vm.isVarargFunctionWithOverloads(func)
+ or vm.isFunctionWithOnlyOverloads(func)
or not isAllParamMatched(uri, args, func.args)
then
return -1
@@ -490,6 +491,36 @@ function vm.isVarargFunctionWithOverloads(func)
return false
end
+---@param func table
+---@return boolean
+function vm.isFunctionWithOnlyOverloads(func)
+ if func.type ~= 'function' then
+ return false
+ end
+ if func._onlyOverloadFunction ~= nil then
+ return func._onlyOverloadFunction
+ end
+
+ if not func.bindDocs then
+ func._onlyOverloadFunction = false
+ return false
+ end
+ local hasOverload = false
+ for _, doc in ipairs(func.bindDocs) do
+ if doc.type == 'doc.overload' then
+ hasOverload = true
+ elseif doc.type == 'doc.param'
+ or doc.type == 'doc.return'
+ then
+ -- has specified @param or @return, thus not only @overload
+ func._onlyOverloadFunction = false
+ return false
+ end
+ end
+ func._onlyOverloadFunction = hasOverload
+ return true
+end
+
---@param func parser.object
---@return boolean
function vm.isEmptyFunction(func)
diff --git a/test/type_inference/param_match.lua b/test/type_inference/param_match.lua
index 906b9305..21dcf4d3 100644
--- a/test/type_inference/param_match.lua
+++ b/test/type_inference/param_match.lua
@@ -172,6 +172,44 @@ local v = 'y'
local <?r?> = f(v)
]]
+TEST 'string|number' [[
+---@overload fun(a: string)
+---@overload fun(a: number)
+local function f(<?a?>) end
+]]
+
+TEST '1|2' [[
+---@overload fun(a: 1)
+---@overload fun(a: 2)
+local function f(<?a?>) end
+]]
+
+TEST 'string' [[
+---@overload fun(a: 1): string
+---@overload fun(a: 2): number
+local function f(a) end
+
+local <?r?> = f(1)
+]]
+
+TEST 'number' [[
+---@overload fun(a: 1): string
+---@overload fun(a: 2): number
+local function f(a) end
+
+local <?r?> = f(2)
+]]
+
+TEST 'string|number' [[
+---@overload fun(a: 1): string
+---@overload fun(a: 2): number
+local function f(a) end
+
+---@type number
+local v
+local <?r?> = f(v)
+]]
+
TEST 'number' [[
---@overload fun(a: 1, c: fun(x: number))
---@overload fun(a: 2, c: fun(x: string))