diff options
author | 最萌小汐 <sumneko@hotmail.com> | 2024-09-05 18:36:02 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-09-05 18:36:02 +0800 |
commit | c636fdd95cee883d6d57330dd9c119a9a67680d4 (patch) | |
tree | 8c9af81b66bbeb5febb7901126dfc0b5742a4a72 /script/vm/function.lua | |
parent | 08dd0ca8956c8ca53876a95142b4531e454c9e48 (diff) | |
parent | 30deedc444355d7940dedbc039a59e8cd8bc938a (diff) | |
download | lua-language-server-c636fdd95cee883d6d57330dd9c119a9a67680d4.zip |
Merge pull request #2822 from tomlau10/fix/type_narrow
fix: improve function type narrow by checking params' literal identical
Diffstat (limited to 'script/vm/function.lua')
-rw-r--r-- | script/vm/function.lua | 59 |
1 files changed, 48 insertions, 11 deletions
diff --git a/script/vm/function.lua b/script/vm/function.lua index 1e308317..7a15ac5a 100644 --- a/script/vm/function.lua +++ b/script/vm/function.lua @@ -353,6 +353,35 @@ local function isAllParamMatched(uri, args, params) return true end +---@param uri uri +---@param args parser.object[] +---@param func parser.object +---@return number +local function calcFunctionMatchScore(uri, args, func) + if vm.isVarargFunctionWithOverloads(func) + or not isAllParamMatched(uri, args, func.args) + then + return -1 + end + local matchScore = 0 + for i = 1, math.min(#args, #func.args) do + local arg, param = args[i], func.args[i] + local defLiterals, literalsCount = vm.getLiterals(param) + if defLiterals then + for n in vm.compileNode(arg):eachObject() do + -- if param's literals map contains arg's literal, this is narrower than a subtype match + if defLiterals[guide.getLiteral(n)] then + -- the more the literals defined in the param, the less bonus score will be added + -- this favors matching overload param with exact literal value, over alias/enum that has many literal values + matchScore = matchScore + 1/literalsCount + break + end + end + end + end + return matchScore +end + ---@param func parser.object ---@param args? parser.object[] ---@return parser.object[]? @@ -365,21 +394,29 @@ function vm.getExactMatchedFunctions(func, args) return funcs end local uri = guide.getUri(func) - local needRemove + local matchScores = {} for i, n in ipairs(funcs) do - if vm.isVarargFunctionWithOverloads(n) - or not isAllParamMatched(uri, args, n.args) then - if not needRemove then - needRemove = {} - end - needRemove[#needRemove+1] = i - end + matchScores[i] = calcFunctionMatchScore(uri, args, n) + end + + local maxMatchScore = math.max(table.unpack(matchScores)) + if maxMatchScore == -1 then + -- all should be removed + return nil end - if not needRemove then + + local minMatchScore = math.min(table.unpack(matchScores)) + if minMatchScore == maxMatchScore then + -- all should be kept return funcs end - if #needRemove == #funcs then - return nil + + -- remove functions that have matchScore < maxMatchScore + local needRemove = {} + for i, matchScore in ipairs(matchScores) do + if matchScore < maxMatchScore then + needRemove[#needRemove + 1] = i + end end util.tableMultiRemove(funcs, needRemove) return funcs |