summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--script/core/diagnostics/missing-parameter.lua20
-rw-r--r--script/core/diagnostics/missing-return-value.lua23
-rw-r--r--script/core/diagnostics/missing-return.lua26
-rw-r--r--script/core/diagnostics/redundant-parameter.lua18
-rw-r--r--script/core/diagnostics/redundant-return-value.lua23
-rw-r--r--script/core/diagnostics/return-type-mismatch.lua25
-rw-r--r--script/parser/compile.lua13
-rw-r--r--script/vm/function.lua116
-rw-r--r--test.lua2
-rw-r--r--test/diagnostics/type-check.lua59
10 files changed, 228 insertions, 97 deletions
diff --git a/script/core/diagnostics/missing-parameter.lua b/script/core/diagnostics/missing-parameter.lua
index 78b94a09..1194f6d1 100644
--- a/script/core/diagnostics/missing-parameter.lua
+++ b/script/core/diagnostics/missing-parameter.lua
@@ -29,4 +29,24 @@ return function (uri, callback)
message = lang.script('DIAG_MISS_ARGS', funcArgs, callArgs),
}
end)
+
+ ---@async
+ guide.eachSourceType(state.ast, 'function', function (source)
+ await.delay()
+ if not source.args then
+ return
+ end
+ local funcArgs = vm.countParamsOfSource(source)
+ if funcArgs == 0 then
+ return
+ end
+ local myArgs = #source.args
+ if myArgs < funcArgs then
+ callback {
+ start = source.args.start,
+ finish = source.args.finish,
+ message = lang.script('DIAG_MISS_ARGS', funcArgs, myArgs),
+ }
+ end
+ end)
end
diff --git a/script/core/diagnostics/missing-return-value.lua b/script/core/diagnostics/missing-return-value.lua
index 9eab7074..5c672b54 100644
--- a/script/core/diagnostics/missing-return-value.lua
+++ b/script/core/diagnostics/missing-return-value.lua
@@ -4,27 +4,6 @@ local vm = require 'vm'
local lang = require 'language'
local await = require 'await'
----@param func parser.object
----@return integer
-local function getReturnsMin(func)
- local min = vm.countReturnsOfFunction(func, true)
- if min == 0 then
- return 0
- end
- for _, doc in ipairs(func.bindDocs) do
- if doc.type == 'doc.overload' then
- local n = vm.countReturnsOfFunction(doc.overload)
- if n == 0 then
- return 0
- end
- if n < min then
- min = n
- end
- end
- end
- return min
-end
-
---@async
return function (uri, callback)
local state = files.getState(uri)
@@ -39,7 +18,7 @@ return function (uri, callback)
if not returns then
return
end
- local min = getReturnsMin(source)
+ local min = vm.countReturnsOfSource(source)
if min == 0 then
return
end
diff --git a/script/core/diagnostics/missing-return.lua b/script/core/diagnostics/missing-return.lua
index 42ccaa9f..7333e5e3 100644
--- a/script/core/diagnostics/missing-return.lua
+++ b/script/core/diagnostics/missing-return.lua
@@ -38,27 +38,6 @@ local function hasReturn(block)
return false
end
----@param func parser.object
----@return integer
-local function getReturnsMin(func)
- local min = vm.countReturnsOfFunction(func, true)
- if min == 0 then
- return 0
- end
- for _, doc in ipairs(func.bindDocs) do
- if doc.type == 'doc.overload' then
- local n = vm.countReturnsOfFunction(doc.overload)
- if n == 0 then
- return 0
- end
- if n < min then
- min = n
- end
- end
- end
- return min
-end
-
---@async
return function (uri, callback)
local state = files.getState(uri)
@@ -75,7 +54,7 @@ return function (uri, callback)
return
end
await.delay()
- if getReturnsMin(source) == 0 then
+ if vm.countReturnsOfSource(source) == 0 then
return
end
if hasReturn(source) then
@@ -86,8 +65,7 @@ return function (uri, callback)
if lastAction then
pos = lastAction.range or lastAction.finish
else
- local row = guide.rowColOf(source.finish)
- pos = guide.positionOf(row - 1, 0)
+ pos = source.keyword[3] or source.finish
end
callback {
start = pos,
diff --git a/script/core/diagnostics/redundant-parameter.lua b/script/core/diagnostics/redundant-parameter.lua
index 9898d9bd..667f9c61 100644
--- a/script/core/diagnostics/redundant-parameter.lua
+++ b/script/core/diagnostics/redundant-parameter.lua
@@ -52,4 +52,22 @@ return function (uri, callback)
end
end
end)
+
+ ---@async
+ guide.eachSourceType(state.ast, 'function', function (source)
+ await.delay()
+ if not source.args then
+ return
+ end
+ local _, funcArgs = vm.countParamsOfSource(source)
+ local myArgs = #source.args
+ for i = funcArgs + 1, myArgs do
+ local arg = source.args[i]
+ callback {
+ start = arg.start,
+ finish = arg.finish,
+ message = lang.script('DIAG_OVER_MAX_ARGS', funcArgs, myArgs),
+ }
+ end
+ end)
end
diff --git a/script/core/diagnostics/redundant-return-value.lua b/script/core/diagnostics/redundant-return-value.lua
index 9b913438..18667840 100644
--- a/script/core/diagnostics/redundant-return-value.lua
+++ b/script/core/diagnostics/redundant-return-value.lua
@@ -4,27 +4,6 @@ local vm = require 'vm'
local lang = require 'language'
local await = require 'await'
----@param func parser.object
----@return number
-local function getReturnsMax(func)
- local _, max = vm.countReturnsOfFunction(func, true)
- if max == math.huge then
- return max
- end
- for _, doc in ipairs(func.bindDocs) do
- if doc.type == 'doc.overload' then
- local _, n = vm.countReturnsOfFunction(doc.overload)
- if n == math.huge then
- return n
- end
- if n > max then
- max = n
- end
- end
- end
- return max
-end
-
---@async
return function (uri, callback)
local state = files.getState(uri)
@@ -39,7 +18,7 @@ return function (uri, callback)
return
end
await.delay()
- local max = getReturnsMax(source)
+ local _, max = vm.countReturnsOfSource(source)
for _, ret in ipairs(returns) do
local rmin, rmax = vm.countList(ret)
if rmin > max then
diff --git a/script/core/diagnostics/return-type-mismatch.lua b/script/core/diagnostics/return-type-mismatch.lua
index 2ff8a909..1f335e9d 100644
--- a/script/core/diagnostics/return-type-mismatch.lua
+++ b/script/core/diagnostics/return-type-mismatch.lua
@@ -8,21 +8,28 @@ local util = require 'utility'
---@param func parser.object
---@return vm.node[]?
local function getDocReturns(func)
- if not func.bindDocs then
- return nil
- end
---@type table<integer, vm.node>
local returns = util.defaultTable(function ()
return vm.createNode()
end)
- for _, doc in ipairs(func.bindDocs) do
- if doc.type == 'doc.return' then
- for _, ret in ipairs(doc.returns) do
- returns[ret.returnIndex]:merge(vm.compileNode(ret))
+ if func.bindDocs then
+ for _, doc in ipairs(func.bindDocs) do
+ if doc.type == 'doc.return' then
+ for _, ret in ipairs(doc.returns) do
+ returns[ret.returnIndex]:merge(vm.compileNode(ret))
+ end
+ end
+ if doc.type == 'doc.overload' then
+ for i, ret in ipairs(doc.overload.returns) do
+ returns[i]:merge(vm.compileNode(ret))
+ end
end
end
- if doc.type == 'doc.overload' then
- for i, ret in ipairs(doc.overload.returns) do
+ end
+ for nd in vm.compileNode(func):eachObject() do
+ if nd.type == 'doc.type.function' then
+ ---@cast nd parser.object
+ for i, ret in ipairs(nd.returns) do
returns[i]:merge(vm.compileNode(ret))
end
end
diff --git a/script/parser/compile.lua b/script/parser/compile.lua
index 8b8a7770..b8040382 100644
--- a/script/parser/compile.lua
+++ b/script/parser/compile.lua
@@ -2289,16 +2289,15 @@ local function parseFunction(isLocal, isAction)
end
end
if hasLeftParen then
+ params = params or {}
local parenLeft = getPosition(Tokens[Index], 'left')
Index = Index + 2
params = parseParams(params)
- if params then
- params.type = 'funcargs'
- params.start = parenLeft
- params.finish = lastRightPosition()
- params.parent = func
- func.args = params
- end
+ params.type = 'funcargs'
+ params.start = parenLeft
+ params.finish = lastRightPosition()
+ params.parent = func
+ func.args = params
skipSpace(true)
if Tokens[Index + 1] == ')' then
local parenRight = getPosition(Tokens[Index], 'right')
diff --git a/script/vm/function.lua b/script/vm/function.lua
index 5ad12acf..b4466668 100644
--- a/script/vm/function.lua
+++ b/script/vm/function.lua
@@ -63,6 +63,54 @@ function vm.countParamsOfFunction(func)
return min, max, def
end
+---@param source parser.object
+---@return integer min
+---@return number max
+---@return integer def
+function vm.countParamsOfSource(source)
+ local min = 0
+ local max = 0
+ local def = 0
+ local overloads = {}
+ if source.bindDocs then
+ for _, doc in ipairs(source.bindDocs) do
+ if doc.type == 'doc.overload' then
+ overloads[doc.overload] = true
+ end
+ end
+ end
+ local hasDocFunction
+ for nd in vm.compileNode(source):eachObject() do
+ if nd.type == 'doc.type.function' and not overloads[nd] then
+ hasDocFunction = true
+ ---@cast nd parser.object
+ local dmin, dmax, ddef = vm.countParamsOfFunction(nd)
+ if dmin > min then
+ min = dmin
+ end
+ if dmax > max then
+ max = dmax
+ end
+ if ddef > def then
+ def = ddef
+ end
+ end
+ end
+ if not hasDocFunction then
+ local dmin, dmax, ddef = vm.countParamsOfFunction(source)
+ if dmin > min then
+ min = dmin
+ end
+ if dmax > max then
+ max = dmax
+ end
+ if ddef > def then
+ def = ddef
+ end
+ end
+ return min, max, def
+end
+
---@param node vm.node
---@return integer min
---@return number max
@@ -136,12 +184,12 @@ function vm.countReturnsOfFunction(func, onlyDoc, mark)
end
if not onlyDoc and not hasDocReturn and func.returns then
for _, ret in ipairs(func.returns) do
- local rmin, rmax, ddef = vm.countList(ret, mark)
- if not min or rmin < min then
- min = rmin
+ local dmin, dmax, ddef = vm.countList(ret, mark)
+ if not min or dmin < min then
+ min = dmin
end
- if not max or rmax > max then
- max = rmax
+ if not max or dmax > max then
+ max = dmax
end
if not def or ddef > def then
def = ddef
@@ -156,6 +204,62 @@ function vm.countReturnsOfFunction(func, onlyDoc, mark)
error('not a function')
end
+---@param source parser.object
+---@return integer min
+---@return number max
+---@return integer def
+function vm.countReturnsOfSource(source)
+ local overloads = {}
+ local hasDocFunction
+ local min, max, def
+ if source.bindDocs then
+ for _, doc in ipairs(source.bindDocs) do
+ if doc.type == 'doc.overload' then
+ overloads[doc.overload] = true
+ local dmin, dmax, ddef = vm.countReturnsOfFunction(doc.overload)
+ if not min or dmin < min then
+ min = dmin
+ end
+ if not max or dmax > max then
+ max = dmax
+ end
+ if not def or ddef > def then
+ def = ddef
+ end
+ end
+ end
+ end
+ for nd in vm.compileNode(source):eachObject() do
+ if nd.type == 'doc.type.function' and not overloads[nd] then
+ ---@cast nd parser.object
+ hasDocFunction = true
+ local dmin, dmax, ddef = vm.countReturnsOfFunction(nd)
+ if not min or dmin < min then
+ min = dmin
+ end
+ if not max or dmax > max then
+ max = dmax
+ end
+ if not def or ddef > def then
+ def = ddef
+ end
+ end
+ end
+ if not hasDocFunction then
+ local dmin, dmax, ddef = vm.countReturnsOfFunction(source, true)
+ if not min or dmin < min then
+ min = dmin
+ end
+ if not max or dmax > max then
+ max = dmax
+ end
+ if not def or ddef > def then
+ def = ddef
+ end
+ end
+ return min, max, def
+end
+
---@param func parser.object
---@param mark? table
---@return integer min
@@ -254,7 +358,7 @@ function vm.isVarargFunctionWithOverloads(func)
if not func.args then
return false
end
- if func.args[1].type ~= '...' then
+ if not func.args[1] or func.args[1].type ~= '...' then
return false
end
if not func.bindDocs then
diff --git a/test.lua b/test.lua
index 2a8d97c7..e255841f 100644
--- a/test.lua
+++ b/test.lua
@@ -56,8 +56,8 @@ local function testAll()
test 'references'
test 'hover'
test 'completion'
- test 'crossfile'
test 'diagnostics'
+ test 'crossfile'
test 'highlight'
test 'rename'
test 'signature'
diff --git a/test/diagnostics/type-check.lua b/test/diagnostics/type-check.lua
index 81cb1050..a8525ab0 100644
--- a/test/diagnostics/type-check.lua
+++ b/test/diagnostics/type-check.lua
@@ -3,7 +3,7 @@ local config = require 'config'
config.add(nil, 'Lua.diagnostics.disable', 'unused-local')
config.add(nil, 'Lua.diagnostics.disable', 'unused-function')
config.add(nil, 'Lua.diagnostics.disable', 'undefined-global')
-config.add(nil, 'Lua.diagnostics.disable', 'missing-return')
+config.add(nil, 'Lua.diagnostics.disable', 'redundant-return')
config.set(nil, 'Lua.type.castNumberToInteger', false)
TEST [[
@@ -522,7 +522,9 @@ end
TEST [[
---@return number, number
-local function f() end
+local function f()
+ return 1, 1
+end
---@return number, boolean
function F()
@@ -532,7 +534,9 @@ end
TEST [[
---@return boolean, number
-local function f() end
+local function f()
+ return true, 1
+end
---@return number, boolean
function F()
@@ -542,7 +546,9 @@ end
TEST [[
---@return boolean, number?
-local function f() end
+local function f()
+ return true, 1
+end
---@return number, boolean
function F()
@@ -552,7 +558,9 @@ end
TEST [[
---@return number, number?
-local function f() end
+local function f()
+ return 1, 1
+end
---@return number, boolean, number
function F()
@@ -969,8 +977,47 @@ local x
local t = { true, false, x }
]]
+TEST [[
+---@type fun():number
+local function f()
+<!!>end
+]]
+
+TEST [[
+---@type fun():number
+local function f()
+ <!return!>
+end
+]]
+
+TEST [[
+---@type fun():number
+local function f()
+ return 1, <!true!>
+end
+]]
+
+TEST [[
+---@type fun():number
+local function f()
+ return <!true!>
+end
+]]
+
+TEST [[
+---@type fun(x: number)
+local function f<!()!>
+end
+]]
+
+TEST [[
+---@type fun(x: number)
+local function f(x, <!y!>)
+end
+]]
+
config.remove(nil, 'Lua.diagnostics.disable', 'unused-local')
config.remove(nil, 'Lua.diagnostics.disable', 'unused-function')
config.remove(nil, 'Lua.diagnostics.disable', 'undefined-global')
-config.remove(nil, 'Lua.diagnostics.disable', 'missing-return')
+config.remove(nil, 'Lua.diagnostics.disable', 'redundant-return')
config.set(nil, 'Lua.type.castNumberToInteger', true)