summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--changelog.md4
-rw-r--r--script/core/diagnostics/missing-parameter.lua122
-rw-r--r--script/core/diagnostics/redundant-parameter.lua63
-rw-r--r--script/vm/function.lua107
-rw-r--r--test/diagnostics/common.lua11
-rw-r--r--test/hover/init.lua16
-rw-r--r--test/type_inference/init.lua12
7 files changed, 135 insertions, 200 deletions
diff --git a/changelog.md b/changelog.md
index 48793c71..f0a43feb 100644
--- a/changelog.md
+++ b/changelog.md
@@ -8,8 +8,8 @@
```
* `CHG` infer called function by params num
```lua
- ---@overload fun(x, y):string
- ---@overload fun(x):number
+ ---@overload fun(x: number, y: number):string
+ ---@overload fun(x: number):number
---@return boolean
local function f() end
diff --git a/script/core/diagnostics/missing-parameter.lua b/script/core/diagnostics/missing-parameter.lua
index 9844046f..b6067175 100644
--- a/script/core/diagnostics/missing-parameter.lua
+++ b/script/core/diagnostics/missing-parameter.lua
@@ -3,116 +3,6 @@ local guide = require 'parser.guide'
local vm = require 'vm'
local lang = require 'language'
----@param source parser.object
----@return integer
-local function countReturnsOfFunction(source)
- local n = 0
-
- local docs = source.bindDocs
- if docs then
- for _, doc in ipairs(docs) do
- if doc.type == 'doc.return' then
- for _, rtn in ipairs(doc.returns) do
- if rtn.returnIndex and rtn.returnIndex > n then
- n = rtn.returnIndex
- end
- end
- end
- end
- end
-
- local returns = source.returns
- if returns then
- for _, rtn in ipairs(returns) do
- if #rtn > n then
- n = #rtn
- end
- end
- end
-
- return n
-end
-
----@param source parser.object
----@return integer
-local function countReturnsOfDocFunction(source)
- return #source.returns
-end
-
-local function countMaxReturns(source)
- local hasFounded
- local n = 0
- for _, def in ipairs(vm.getDefs(source)) do
- if def.type == 'function' then
- hasFounded = true
- local rets = countReturnsOfFunction(def)
- if rets > n then
- n = rets
- end
- elseif def.type == 'doc.type.function' then
- hasFounded = true
- local rets = countReturnsOfDocFunction(def)
- if rets > n then
- n = rets
- end
- end
- end
-
- if hasFounded then
- return n
- else
- return math.huge
- end
-end
-
-local function countCallArgs(source)
- local result = 0
- if not source.args then
- return 0
- end
- local lastArg = source.args[#source.args]
- if lastArg.type == 'varargs' then
- return math.huge
- end
- if lastArg.type == 'call' then
- result = result + countMaxReturns(lastArg.node) - 1
- end
- result = result + #source.args
- return result
-end
-
----@return integer
-local function countFuncArgs(source)
- if not source.args or #source.args == 0 then
- return 0
- end
- local count = 0
- for i = #source.args, 1, -1 do
- local arg = source.args[i]
- if arg.type ~= '...'
- and not (arg.name and arg.name[1] =='...')
- and not vm.compileNode(arg):isNullable() then
- return i
- end
- end
- return count
-end
-
-local function getFuncArgs(func)
- local funcArgs
- local defs = vm.getDefs(func)
- for _, def in ipairs(defs) do
- if def.type == 'function'
- or def.type == 'doc.type.function' then
- local args = countFuncArgs(def)
- if not funcArgs or args < funcArgs then
- funcArgs = args
- end
- end
- end
- return funcArgs
-end
-
return function (uri, callback)
local state = files.getState(uri)
if not state then
@@ -120,19 +10,15 @@ return function (uri, callback)
end
guide.eachSourceType(state.ast, 'call', function (source)
- local callArgs = countCallArgs(source)
+ local _, callArgs = vm.countList(source.args)
- local func = source.node
- local funcArgs = getFuncArgs(func)
+ local funcNode = vm.compileNode(source.node)
+ local funcArgs = vm.countParamsOfNode(funcNode)
- if not funcArgs then
+ if callArgs >= funcArgs then
return
end
- local delta = callArgs - funcArgs
- if delta >= 0 then
- return
- end
callback {
start = source.start,
finish = source.finish,
diff --git a/script/core/diagnostics/redundant-parameter.lua b/script/core/diagnostics/redundant-parameter.lua
index 41781df8..2b7f1230 100644
--- a/script/core/diagnostics/redundant-parameter.lua
+++ b/script/core/diagnostics/redundant-parameter.lua
@@ -3,43 +3,6 @@ local guide = require 'parser.guide'
local vm = require 'vm'
local lang = require 'language'
-local function countCallArgs(source)
- local result = 0
- if not source.args then
- return 0
- end
- result = result + #source.args
- return result
-end
-
-local function countFuncArgs(source)
- if not source.args or #source.args == 0 then
- return 0
- end
- local lastArg = source.args[#source.args]
- if lastArg.type == '...'
- or (lastArg.name and lastArg.name[1] == '...') then
- return math.maxinteger
- else
- return #source.args
- end
-end
-
-local function getFuncArgs(func)
- local funcArgs
- local defs = vm.getDefs(func)
- for _, def in ipairs(defs) do
- if def.type == 'function'
- or def.type == 'doc.type.function' then
- local args = countFuncArgs(def)
- if not funcArgs or args > funcArgs then
- funcArgs = args
- end
- end
- end
- return funcArgs
-end
-
return function (uri, callback)
local state = files.getState(uri)
if not state then
@@ -47,28 +10,30 @@ return function (uri, callback)
end
guide.eachSourceType(state.ast, 'call', function (source)
- local callArgs = countCallArgs(source)
+ local callArgs = vm.countList(source.args)
if callArgs == 0 then
return
end
- local func = source.node
- local funcArgs = getFuncArgs(func)
-
- if not funcArgs then
- return
- end
+ local funcNode = vm.compileNode(source.node)
+ local _, funcArgs = vm.countParamsOfNode(funcNode)
- local delta = callArgs - funcArgs
- if delta <= 0 then
+ if callArgs <= funcArgs then
return
end
if callArgs == 1 and source.node.type == 'getmethod' then
return
end
- for i = #source.args - delta + 1, #source.args do
- local arg = source.args[i]
- if arg then
+ if funcArgs + 1 > #source.args then
+ local lastArg = source.args[#source.args]
+ callback {
+ start = lastArg.start,
+ finish = lastArg.finish,
+ message = lang.script('DIAG_OVER_MAX_ARGS', funcArgs, callArgs)
+ }
+ else
+ for i = funcArgs + 1, #source.args do
+ local arg = source.args[i]
callback {
start = arg.start,
finish = arg.finish,
diff --git a/script/vm/function.lua b/script/vm/function.lua
index 69900141..45f8c0df 100644
--- a/script/vm/function.lua
+++ b/script/vm/function.lua
@@ -1,21 +1,48 @@
---@class vm
local vm = require 'vm.vm'
+---@param arg parser.object
+---@return parser.object?
+local function getDocParam(arg)
+ if not arg.bindDocs then
+ return nil
+ end
+ for _, doc in ipairs(arg.bindDocs) do
+ if doc.type == 'doc.param'
+ and doc.param[1] == arg[1] then
+ return doc
+ end
+ end
+ return nil
+end
+
---@param func parser.object
---@return integer min
---@return integer max
function vm.countParamsOfFunction(func)
local min = 0
local max = 0
- if func.type == 'function'
- or func.type == 'doc.type.function' then
+ if func.type == 'function' then
if func.args then
max = #func.args
- min = max
for i = #func.args, 1, -1 do
local arg = func.args[i]
- if arg.type == '...'
- or (arg.name and arg.name[1] =='...') then
+ if arg.type == '...' then
+ max = math.huge
+ elseif getDocParam(arg)
+ and not vm.compileNode(arg):isNullable() then
+ min = i
+ break
+ end
+ end
+ end
+ end
+ if func.type == 'doc.type.function' then
+ if func.args then
+ max = #func.args
+ for i = #func.args, 1, -1 do
+ local arg = func.args[i]
+ if arg.name and arg.name[1] =='...' then
max = math.huge
elseif not vm.compileNode(arg):isNullable() then
min = i
@@ -27,22 +54,67 @@ function vm.countParamsOfFunction(func)
return min, max
end
+---@param node vm.node
+---@return integer min
+---@return integer max
+function vm.countParamsOfNode(node)
+ local min, max
+ for n in node:eachObject() do
+ if n.type == 'function'
+ or n.type == 'doc.type.function' then
+ local fmin, fmax = vm.countParamsOfFunction(n)
+ if not min or fmin < min then
+ min = fmin
+ end
+ if not max or fmax > max then
+ max = fmax
+ end
+ end
+ end
+ return min or 0, max or math.huge
+end
+
---@param func parser.object
---@return integer min
---@return integer max
function vm.countReturnsOfFunction(func)
if func.type == 'function' then
- if not func.returns then
- return 0, 0
- end
local min, max
- for _, ret in ipairs(func.returns) do
- local rmin, rmax = vm.countList(ret)
- if not min or rmin < min then
- min = rmin
+ if func.returns then
+ for _, ret in ipairs(func.returns) do
+ local rmin, rmax = vm.countList(ret)
+ if not min or rmin < min then
+ min = rmin
+ end
+ if not max or rmax > max then
+ max = rmax
+ end
+ end
+ end
+ if func.bindDocs then
+ local lastReturn
+ local n = 0
+ local dmin, dmax
+ for _, doc in ipairs(func.bindDocs) do
+ if doc.type == 'doc.return' then
+ for _, ret in ipairs(doc) do
+ n = n + 1
+ lastReturn = ret
+ dmax = n
+ if not vm.compileNode(ret):isNullable() then
+ dmin = n
+ end
+ end
+ end
+ end
+ if lastReturn and lastReturn.types[1][1] == '...' then
+ dmax = math.huge
+ end
+ if dmin and (not min or (dmin < min)) then
+ min = dmin
end
- if not max or rmax > max then
- max = rmax
+ if dmax and (not max or (dmax > max)) then
+ max = dmax
end
end
return min, max
@@ -50,7 +122,7 @@ function vm.countReturnsOfFunction(func)
if func.type == 'doc.type.function' then
return vm.countList(func.returns)
end
- return 0, 0
+ error('not a function')
end
---@param func parser.object
@@ -69,7 +141,7 @@ function vm.countReturnsOfCall(func, args)
max = rmax
end
end
- return min or 0, max or 0
+ return min or 0, max or math.huge
end
---@param list parser.object[]?
@@ -83,7 +155,8 @@ function vm.countList(list)
if not lastArg then
return 0, 0
end
- if lastArg.type == '...' then
+ if lastArg.type == '...'
+ or lastArg.type == 'varargs' then
return #list - 1, math.huge
end
if lastArg.type == 'call' then
diff --git a/test/diagnostics/common.lua b/test/diagnostics/common.lua
index 4edb9703..d640e45a 100644
--- a/test/diagnostics/common.lua
+++ b/test/diagnostics/common.lua
@@ -1620,3 +1620,14 @@ local n
print(n.x)
]]
+
+TEST [[
+---@diagnostic disable: unused-local, unused-function, undefined-global
+
+function F() end
+
+---@param x boolean
+function F(x) end
+
+F(k())
+]]
diff --git a/test/hover/init.lua b/test/hover/init.lua
index 1b6b2231..7fec9254 100644
--- a/test/hover/init.lua
+++ b/test/hover/init.lua
@@ -1949,8 +1949,8 @@ x({}, <?function?> () end)
]]
TEST [[
----@overload fun(x, y):string
----@overload fun(x):number
+---@overload fun(x: number, y: number):string
+---@overload fun(x: number):number
---@return boolean
local function f() end
@@ -1964,8 +1964,8 @@ function f()
]]
TEST [[
----@overload fun(x, y):string
----@overload fun(x):number
+---@overload fun(x: number, y: number):string
+---@overload fun(x: number):number
---@return boolean
local function f() end
@@ -1974,12 +1974,12 @@ local n2 = <?f?>(0)
local n3 = f(0, 0)
]]
[[
-local f: fun(x: any):number
+local f: fun(x: number):number
]]
TEST [[
----@overload fun(x, y):string
----@overload fun(x):number
+---@overload fun(x: number, y: number):string
+---@overload fun(x: number):number
---@return boolean
local function f() end
@@ -1988,5 +1988,5 @@ local n2 = f(0)
local n3 = <?f?>(0, 0)
]]
[[
-local f: fun(x: any, y: any):string
+local f: fun(x: number, y: number):string
]]
diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua
index ad1df8e0..74ac7120 100644
--- a/test/type_inference/init.lua
+++ b/test/type_inference/init.lua
@@ -2929,8 +2929,8 @@ local <?x?> = f(r1())
]]
TEST 'boolean' [[
----@overload fun(x, y):string
----@overload fun(x):number
+---@overload fun(x: number, y: number):string
+---@overload fun(x: number):number
---@return boolean
local function f() end
@@ -2940,8 +2940,8 @@ local n3 = f(0, 0)
]]
TEST 'number' [[
----@overload fun(x, y):string
----@overload fun(x):number
+---@overload fun(x: number, y: number):string
+---@overload fun(x: number):number
---@return boolean
local function f() end
@@ -2951,8 +2951,8 @@ local n3 = f(0, 0)
]]
TEST 'string' [[
----@overload fun(x, y):string
----@overload fun(x):number
+---@overload fun(x: number, y: number):string
+---@overload fun(x: number):number
---@return boolean
local function f() end