diff options
author | 最萌小汐 <sumneko@hotmail.com> | 2022-06-22 01:41:16 +0800 |
---|---|---|
committer | 最萌小汐 <sumneko@hotmail.com> | 2022-06-22 01:41:16 +0800 |
commit | 7b3e7f4b0a62b7c9c3b895e89b4fe79bb8f350ba (patch) | |
tree | ee441dc6545127e650c6fcf435c96d33079ac177 | |
parent | d1e977ca8a6d72673649282bcd1973b1d279a8a1 (diff) | |
download | lua-language-server-7b3e7f4b0a62b7c9c3b895e89b4fe79bb8f350ba.zip |
fix
-rw-r--r-- | changelog.md | 4 | ||||
-rw-r--r-- | script/core/diagnostics/missing-parameter.lua | 122 | ||||
-rw-r--r-- | script/core/diagnostics/redundant-parameter.lua | 63 | ||||
-rw-r--r-- | script/vm/function.lua | 107 | ||||
-rw-r--r-- | test/diagnostics/common.lua | 11 | ||||
-rw-r--r-- | test/hover/init.lua | 16 | ||||
-rw-r--r-- | test/type_inference/init.lua | 12 |
7 files changed, 135 insertions, 200 deletions
diff --git a/changelog.md b/changelog.md index 48793c71..f0a43feb 100644 --- a/changelog.md +++ b/changelog.md @@ -8,8 +8,8 @@ ``` * `CHG` infer called function by params num ```lua - ---@overload fun(x, y):string - ---@overload fun(x):number + ---@overload fun(x: number, y: number):string + ---@overload fun(x: number):number ---@return boolean local function f() end diff --git a/script/core/diagnostics/missing-parameter.lua b/script/core/diagnostics/missing-parameter.lua index 9844046f..b6067175 100644 --- a/script/core/diagnostics/missing-parameter.lua +++ b/script/core/diagnostics/missing-parameter.lua @@ -3,116 +3,6 @@ local guide = require 'parser.guide' local vm = require 'vm' local lang = require 'language' ----@param source parser.object ----@return integer -local function countReturnsOfFunction(source) - local n = 0 - - local docs = source.bindDocs - if docs then - for _, doc in ipairs(docs) do - if doc.type == 'doc.return' then - for _, rtn in ipairs(doc.returns) do - if rtn.returnIndex and rtn.returnIndex > n then - n = rtn.returnIndex - end - end - end - end - end - - local returns = source.returns - if returns then - for _, rtn in ipairs(returns) do - if #rtn > n then - n = #rtn - end - end - end - - return n -end - ----@param source parser.object ----@return integer -local function countReturnsOfDocFunction(source) - return #source.returns -end - -local function countMaxReturns(source) - local hasFounded - local n = 0 - for _, def in ipairs(vm.getDefs(source)) do - if def.type == 'function' then - hasFounded = true - local rets = countReturnsOfFunction(def) - if rets > n then - n = rets - end - elseif def.type == 'doc.type.function' then - hasFounded = true - local rets = countReturnsOfDocFunction(def) - if rets > n then - n = rets - end - end - end - - if hasFounded then - return n - else - return math.huge - end -end - -local function countCallArgs(source) - local result = 0 - if not source.args then - return 0 - end - local lastArg = source.args[#source.args] - if lastArg.type == 'varargs' then - return math.huge - end - if lastArg.type == 'call' then - result = result + countMaxReturns(lastArg.node) - 1 - end - result = result + #source.args - return result -end - ----@return integer -local function countFuncArgs(source) - if not source.args or #source.args == 0 then - return 0 - end - local count = 0 - for i = #source.args, 1, -1 do - local arg = source.args[i] - if arg.type ~= '...' - and not (arg.name and arg.name[1] =='...') - and not vm.compileNode(arg):isNullable() then - return i - end - end - return count -end - -local function getFuncArgs(func) - local funcArgs - local defs = vm.getDefs(func) - for _, def in ipairs(defs) do - if def.type == 'function' - or def.type == 'doc.type.function' then - local args = countFuncArgs(def) - if not funcArgs or args < funcArgs then - funcArgs = args - end - end - end - return funcArgs -end - return function (uri, callback) local state = files.getState(uri) if not state then @@ -120,19 +10,15 @@ return function (uri, callback) end guide.eachSourceType(state.ast, 'call', function (source) - local callArgs = countCallArgs(source) + local _, callArgs = vm.countList(source.args) - local func = source.node - local funcArgs = getFuncArgs(func) + local funcNode = vm.compileNode(source.node) + local funcArgs = vm.countParamsOfNode(funcNode) - if not funcArgs then + if callArgs >= funcArgs then return end - local delta = callArgs - funcArgs - if delta >= 0 then - return - end callback { start = source.start, finish = source.finish, diff --git a/script/core/diagnostics/redundant-parameter.lua b/script/core/diagnostics/redundant-parameter.lua index 41781df8..2b7f1230 100644 --- a/script/core/diagnostics/redundant-parameter.lua +++ b/script/core/diagnostics/redundant-parameter.lua @@ -3,43 +3,6 @@ local guide = require 'parser.guide' local vm = require 'vm' local lang = require 'language' -local function countCallArgs(source) - local result = 0 - if not source.args then - return 0 - end - result = result + #source.args - return result -end - -local function countFuncArgs(source) - if not source.args or #source.args == 0 then - return 0 - end - local lastArg = source.args[#source.args] - if lastArg.type == '...' - or (lastArg.name and lastArg.name[1] == '...') then - return math.maxinteger - else - return #source.args - end -end - -local function getFuncArgs(func) - local funcArgs - local defs = vm.getDefs(func) - for _, def in ipairs(defs) do - if def.type == 'function' - or def.type == 'doc.type.function' then - local args = countFuncArgs(def) - if not funcArgs or args > funcArgs then - funcArgs = args - end - end - end - return funcArgs -end - return function (uri, callback) local state = files.getState(uri) if not state then @@ -47,28 +10,30 @@ return function (uri, callback) end guide.eachSourceType(state.ast, 'call', function (source) - local callArgs = countCallArgs(source) + local callArgs = vm.countList(source.args) if callArgs == 0 then return end - local func = source.node - local funcArgs = getFuncArgs(func) - - if not funcArgs then - return - end + local funcNode = vm.compileNode(source.node) + local _, funcArgs = vm.countParamsOfNode(funcNode) - local delta = callArgs - funcArgs - if delta <= 0 then + if callArgs <= funcArgs then return end if callArgs == 1 and source.node.type == 'getmethod' then return end - for i = #source.args - delta + 1, #source.args do - local arg = source.args[i] - if arg then + if funcArgs + 1 > #source.args then + local lastArg = source.args[#source.args] + callback { + start = lastArg.start, + finish = lastArg.finish, + message = lang.script('DIAG_OVER_MAX_ARGS', funcArgs, callArgs) + } + else + for i = funcArgs + 1, #source.args do + local arg = source.args[i] callback { start = arg.start, finish = arg.finish, diff --git a/script/vm/function.lua b/script/vm/function.lua index 69900141..45f8c0df 100644 --- a/script/vm/function.lua +++ b/script/vm/function.lua @@ -1,21 +1,48 @@ ---@class vm local vm = require 'vm.vm' +---@param arg parser.object +---@return parser.object? +local function getDocParam(arg) + if not arg.bindDocs then + return nil + end + for _, doc in ipairs(arg.bindDocs) do + if doc.type == 'doc.param' + and doc.param[1] == arg[1] then + return doc + end + end + return nil +end + ---@param func parser.object ---@return integer min ---@return integer max function vm.countParamsOfFunction(func) local min = 0 local max = 0 - if func.type == 'function' - or func.type == 'doc.type.function' then + if func.type == 'function' then if func.args then max = #func.args - min = max for i = #func.args, 1, -1 do local arg = func.args[i] - if arg.type == '...' - or (arg.name and arg.name[1] =='...') then + if arg.type == '...' then + max = math.huge + elseif getDocParam(arg) + and not vm.compileNode(arg):isNullable() then + min = i + break + end + end + end + end + if func.type == 'doc.type.function' then + if func.args then + max = #func.args + for i = #func.args, 1, -1 do + local arg = func.args[i] + if arg.name and arg.name[1] =='...' then max = math.huge elseif not vm.compileNode(arg):isNullable() then min = i @@ -27,22 +54,67 @@ function vm.countParamsOfFunction(func) return min, max end +---@param node vm.node +---@return integer min +---@return integer max +function vm.countParamsOfNode(node) + local min, max + for n in node:eachObject() do + if n.type == 'function' + or n.type == 'doc.type.function' then + local fmin, fmax = vm.countParamsOfFunction(n) + if not min or fmin < min then + min = fmin + end + if not max or fmax > max then + max = fmax + end + end + end + return min or 0, max or math.huge +end + ---@param func parser.object ---@return integer min ---@return integer max function vm.countReturnsOfFunction(func) if func.type == 'function' then - if not func.returns then - return 0, 0 - end local min, max - for _, ret in ipairs(func.returns) do - local rmin, rmax = vm.countList(ret) - if not min or rmin < min then - min = rmin + if func.returns then + for _, ret in ipairs(func.returns) do + local rmin, rmax = vm.countList(ret) + if not min or rmin < min then + min = rmin + end + if not max or rmax > max then + max = rmax + end + end + end + if func.bindDocs then + local lastReturn + local n = 0 + local dmin, dmax + for _, doc in ipairs(func.bindDocs) do + if doc.type == 'doc.return' then + for _, ret in ipairs(doc) do + n = n + 1 + lastReturn = ret + dmax = n + if not vm.compileNode(ret):isNullable() then + dmin = n + end + end + end + end + if lastReturn and lastReturn.types[1][1] == '...' then + dmax = math.huge + end + if dmin and (not min or (dmin < min)) then + min = dmin end - if not max or rmax > max then - max = rmax + if dmax and (not max or (dmax > max)) then + max = dmax end end return min, max @@ -50,7 +122,7 @@ function vm.countReturnsOfFunction(func) if func.type == 'doc.type.function' then return vm.countList(func.returns) end - return 0, 0 + error('not a function') end ---@param func parser.object @@ -69,7 +141,7 @@ function vm.countReturnsOfCall(func, args) max = rmax end end - return min or 0, max or 0 + return min or 0, max or math.huge end ---@param list parser.object[]? @@ -83,7 +155,8 @@ function vm.countList(list) if not lastArg then return 0, 0 end - if lastArg.type == '...' then + if lastArg.type == '...' + or lastArg.type == 'varargs' then return #list - 1, math.huge end if lastArg.type == 'call' then diff --git a/test/diagnostics/common.lua b/test/diagnostics/common.lua index 4edb9703..d640e45a 100644 --- a/test/diagnostics/common.lua +++ b/test/diagnostics/common.lua @@ -1620,3 +1620,14 @@ local n print(n.x) ]] + +TEST [[ +---@diagnostic disable: unused-local, unused-function, undefined-global + +function F() end + +---@param x boolean +function F(x) end + +F(k()) +]] diff --git a/test/hover/init.lua b/test/hover/init.lua index 1b6b2231..7fec9254 100644 --- a/test/hover/init.lua +++ b/test/hover/init.lua @@ -1949,8 +1949,8 @@ x({}, <?function?> () end) ]] TEST [[ ----@overload fun(x, y):string ----@overload fun(x):number +---@overload fun(x: number, y: number):string +---@overload fun(x: number):number ---@return boolean local function f() end @@ -1964,8 +1964,8 @@ function f() ]] TEST [[ ----@overload fun(x, y):string ----@overload fun(x):number +---@overload fun(x: number, y: number):string +---@overload fun(x: number):number ---@return boolean local function f() end @@ -1974,12 +1974,12 @@ local n2 = <?f?>(0) local n3 = f(0, 0) ]] [[ -local f: fun(x: any):number +local f: fun(x: number):number ]] TEST [[ ----@overload fun(x, y):string ----@overload fun(x):number +---@overload fun(x: number, y: number):string +---@overload fun(x: number):number ---@return boolean local function f() end @@ -1988,5 +1988,5 @@ local n2 = f(0) local n3 = <?f?>(0, 0) ]] [[ -local f: fun(x: any, y: any):string +local f: fun(x: number, y: number):string ]] diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua index ad1df8e0..74ac7120 100644 --- a/test/type_inference/init.lua +++ b/test/type_inference/init.lua @@ -2929,8 +2929,8 @@ local <?x?> = f(r1()) ]] TEST 'boolean' [[ ----@overload fun(x, y):string ----@overload fun(x):number +---@overload fun(x: number, y: number):string +---@overload fun(x: number):number ---@return boolean local function f() end @@ -2940,8 +2940,8 @@ local n3 = f(0, 0) ]] TEST 'number' [[ ----@overload fun(x, y):string ----@overload fun(x):number +---@overload fun(x: number, y: number):string +---@overload fun(x: number):number ---@return boolean local function f() end @@ -2951,8 +2951,8 @@ local n3 = f(0, 0) ]] TEST 'string' [[ ----@overload fun(x, y):string ----@overload fun(x):number +---@overload fun(x: number, y: number):string +---@overload fun(x: number):number ---@return boolean local function f() end |