summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--changelog.md1
-rw-r--r--script/vm/function.lua59
-rw-r--r--script/vm/value.lua7
-rw-r--r--test/type_inference/param_match.lua34
4 files changed, 88 insertions, 13 deletions
diff --git a/changelog.md b/changelog.md
index bfd1f467..3efffd98 100644
--- a/changelog.md
+++ b/changelog.md
@@ -6,6 +6,7 @@
* `NEW` Setting: `Lua.docScriptPath`: Path to a script that overrides `cli.doc.export`, allowing user-specified documentation exporting.
* `FIX` Fix `VM.OnCompileFunctionParam` function in plugins
* `FIX` Lua 5.1: fix incorrect warning when using setfenv with an int as first parameter
+* `FIX` Improve type narrow by checking exact match on literal type params
## 3.10.5
`2024-8-19`
diff --git a/script/vm/function.lua b/script/vm/function.lua
index 1e308317..7a15ac5a 100644
--- a/script/vm/function.lua
+++ b/script/vm/function.lua
@@ -353,6 +353,35 @@ local function isAllParamMatched(uri, args, params)
return true
end
+---@param uri uri
+---@param args parser.object[]
+---@param func parser.object
+---@return number
+local function calcFunctionMatchScore(uri, args, func)
+ if vm.isVarargFunctionWithOverloads(func)
+ or not isAllParamMatched(uri, args, func.args)
+ then
+ return -1
+ end
+ local matchScore = 0
+ for i = 1, math.min(#args, #func.args) do
+ local arg, param = args[i], func.args[i]
+ local defLiterals, literalsCount = vm.getLiterals(param)
+ if defLiterals then
+ for n in vm.compileNode(arg):eachObject() do
+ -- if param's literals map contains arg's literal, this is narrower than a subtype match
+ if defLiterals[guide.getLiteral(n)] then
+ -- the more the literals defined in the param, the less bonus score will be added
+ -- this favors matching overload param with exact literal value, over alias/enum that has many literal values
+ matchScore = matchScore + 1/literalsCount
+ break
+ end
+ end
+ end
+ end
+ return matchScore
+end
+
---@param func parser.object
---@param args? parser.object[]
---@return parser.object[]?
@@ -365,21 +394,29 @@ function vm.getExactMatchedFunctions(func, args)
return funcs
end
local uri = guide.getUri(func)
- local needRemove
+ local matchScores = {}
for i, n in ipairs(funcs) do
- if vm.isVarargFunctionWithOverloads(n)
- or not isAllParamMatched(uri, args, n.args) then
- if not needRemove then
- needRemove = {}
- end
- needRemove[#needRemove+1] = i
- end
+ matchScores[i] = calcFunctionMatchScore(uri, args, n)
+ end
+
+ local maxMatchScore = math.max(table.unpack(matchScores))
+ if maxMatchScore == -1 then
+ -- all should be removed
+ return nil
end
- if not needRemove then
+
+ local minMatchScore = math.min(table.unpack(matchScores))
+ if minMatchScore == maxMatchScore then
+ -- all should be kept
return funcs
end
- if #needRemove == #funcs then
- return nil
+
+ -- remove functions that have matchScore < maxMatchScore
+ local needRemove = {}
+ for i, matchScore in ipairs(matchScores) do
+ if matchScore < maxMatchScore then
+ needRemove[#needRemove + 1] = i
+ end
end
util.tableMultiRemove(funcs, needRemove)
return funcs
diff --git a/script/vm/value.lua b/script/vm/value.lua
index 7eab4a8e..ce031357 100644
--- a/script/vm/value.lua
+++ b/script/vm/value.lua
@@ -213,11 +213,13 @@ end
---@param v vm.object
---@return table<any, boolean>?
+---@return integer
function vm.getLiterals(v)
if not v then
- return nil
+ return nil, 0
end
local map
+ local count = 0
local node = vm.compileNode(v)
for n in node:eachObject() do
local literal
@@ -237,7 +239,8 @@ function vm.getLiterals(v)
map = {}
end
map[literal] = true
+ count = count + 1
end
end
- return map
+ return map, count
end
diff --git a/test/type_inference/param_match.lua b/test/type_inference/param_match.lua
index 1079e433..906b9305 100644
--- a/test/type_inference/param_match.lua
+++ b/test/type_inference/param_match.lua
@@ -138,6 +138,40 @@ local function f(...) end
local <?r?> = f(10)
]]
+TEST '1' [[
+---@overload fun(a: string): 1
+---@overload fun(a: 'y'): 2
+local function f(...) end
+
+local <?r?> = f('x')
+]]
+
+TEST '2' [[
+---@overload fun(a: string): 1
+---@overload fun(a: 'y'): 2
+local function f(...) end
+
+local <?r?> = f('y')
+]]
+
+TEST '1' [[
+---@overload fun(a: string): 1
+---@overload fun(a: 'y'): 2
+local function f(...) end
+
+local v = 'x'
+local <?r?> = f(v)
+]]
+
+TEST '2' [[
+---@overload fun(a: string): 1
+---@overload fun(a: 'y'): 2
+local function f(...) end
+
+local v = 'y'
+local <?r?> = f(v)
+]]
+
TEST 'number' [[
---@overload fun(a: 1, c: fun(x: number))
---@overload fun(a: 2, c: fun(x: string))