From dbb392e3f3953c6ce821755ad14bc857840f899c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=80=E8=90=8C=E5=B0=8F=E6=B1=90?= Date: Mon, 7 Nov 2022 19:59:11 +0800 Subject: full support for interface --- script/core/diagnostics/missing-parameter.lua | 20 ++++ script/core/diagnostics/missing-return-value.lua | 23 +--- script/core/diagnostics/missing-return.lua | 26 +---- script/core/diagnostics/redundant-parameter.lua | 18 ++++ script/core/diagnostics/redundant-return-value.lua | 23 +--- script/core/diagnostics/return-type-mismatch.lua | 25 +++-- script/parser/compile.lua | 13 ++- script/vm/function.lua | 116 +++++++++++++++++++-- test.lua | 2 +- test/diagnostics/type-check.lua | 59 +++++++++-- 10 files changed, 228 insertions(+), 97 deletions(-) diff --git a/script/core/diagnostics/missing-parameter.lua b/script/core/diagnostics/missing-parameter.lua index 78b94a09..1194f6d1 100644 --- a/script/core/diagnostics/missing-parameter.lua +++ b/script/core/diagnostics/missing-parameter.lua @@ -29,4 +29,24 @@ return function (uri, callback) message = lang.script('DIAG_MISS_ARGS', funcArgs, callArgs), } end) + + ---@async + guide.eachSourceType(state.ast, 'function', function (source) + await.delay() + if not source.args then + return + end + local funcArgs = vm.countParamsOfSource(source) + if funcArgs == 0 then + return + end + local myArgs = #source.args + if myArgs < funcArgs then + callback { + start = source.args.start, + finish = source.args.finish, + message = lang.script('DIAG_MISS_ARGS', funcArgs, myArgs), + } + end + end) end diff --git a/script/core/diagnostics/missing-return-value.lua b/script/core/diagnostics/missing-return-value.lua index 9eab7074..5c672b54 100644 --- a/script/core/diagnostics/missing-return-value.lua +++ b/script/core/diagnostics/missing-return-value.lua @@ -4,27 +4,6 @@ local vm = require 'vm' local lang = require 'language' local await = require 'await' ----@param func parser.object ----@return integer -local function getReturnsMin(func) - local min = vm.countReturnsOfFunction(func, true) - if min == 0 then - return 0 - end - for _, doc in ipairs(func.bindDocs) do - if doc.type == 'doc.overload' then - local n = vm.countReturnsOfFunction(doc.overload) - if n == 0 then - return 0 - end - if n < min then - min = n - end - end - end - return min -end - ---@async return function (uri, callback) local state = files.getState(uri) @@ -39,7 +18,7 @@ return function (uri, callback) if not returns then return end - local min = getReturnsMin(source) + local min = vm.countReturnsOfSource(source) if min == 0 then return end diff --git a/script/core/diagnostics/missing-return.lua b/script/core/diagnostics/missing-return.lua index 42ccaa9f..7333e5e3 100644 --- a/script/core/diagnostics/missing-return.lua +++ b/script/core/diagnostics/missing-return.lua @@ -38,27 +38,6 @@ local function hasReturn(block) return false end ----@param func parser.object ----@return integer -local function getReturnsMin(func) - local min = vm.countReturnsOfFunction(func, true) - if min == 0 then - return 0 - end - for _, doc in ipairs(func.bindDocs) do - if doc.type == 'doc.overload' then - local n = vm.countReturnsOfFunction(doc.overload) - if n == 0 then - return 0 - end - if n < min then - min = n - end - end - end - return min -end - ---@async return function (uri, callback) local state = files.getState(uri) @@ -75,7 +54,7 @@ return function (uri, callback) return end await.delay() - if getReturnsMin(source) == 0 then + if vm.countReturnsOfSource(source) == 0 then return end if hasReturn(source) then @@ -86,8 +65,7 @@ return function (uri, callback) if lastAction then pos = lastAction.range or lastAction.finish else - local row = guide.rowColOf(source.finish) - pos = guide.positionOf(row - 1, 0) + pos = source.keyword[3] or source.finish end callback { start = pos, diff --git a/script/core/diagnostics/redundant-parameter.lua b/script/core/diagnostics/redundant-parameter.lua index 9898d9bd..667f9c61 100644 --- a/script/core/diagnostics/redundant-parameter.lua +++ b/script/core/diagnostics/redundant-parameter.lua @@ -52,4 +52,22 @@ return function (uri, callback) end end end) + + ---@async + guide.eachSourceType(state.ast, 'function', function (source) + await.delay() + if not source.args then + return + end + local _, funcArgs = vm.countParamsOfSource(source) + local myArgs = #source.args + for i = funcArgs + 1, myArgs do + local arg = source.args[i] + callback { + start = arg.start, + finish = arg.finish, + message = lang.script('DIAG_OVER_MAX_ARGS', funcArgs, myArgs), + } + end + end) end diff --git a/script/core/diagnostics/redundant-return-value.lua b/script/core/diagnostics/redundant-return-value.lua index 9b913438..18667840 100644 --- a/script/core/diagnostics/redundant-return-value.lua +++ b/script/core/diagnostics/redundant-return-value.lua @@ -4,27 +4,6 @@ local vm = require 'vm' local lang = require 'language' local await = require 'await' ----@param func parser.object ----@return number -local function getReturnsMax(func) - local _, max = vm.countReturnsOfFunction(func, true) - if max == math.huge then - return max - end - for _, doc in ipairs(func.bindDocs) do - if doc.type == 'doc.overload' then - local _, n = vm.countReturnsOfFunction(doc.overload) - if n == math.huge then - return n - end - if n > max then - max = n - end - end - end - return max -end - ---@async return function (uri, callback) local state = files.getState(uri) @@ -39,7 +18,7 @@ return function (uri, callback) return end await.delay() - local max = getReturnsMax(source) + local _, max = vm.countReturnsOfSource(source) for _, ret in ipairs(returns) do local rmin, rmax = vm.countList(ret) if rmin > max then diff --git a/script/core/diagnostics/return-type-mismatch.lua b/script/core/diagnostics/return-type-mismatch.lua index 2ff8a909..1f335e9d 100644 --- a/script/core/diagnostics/return-type-mismatch.lua +++ b/script/core/diagnostics/return-type-mismatch.lua @@ -8,21 +8,28 @@ local util = require 'utility' ---@param func parser.object ---@return vm.node[]? local function getDocReturns(func) - if not func.bindDocs then - return nil - end ---@type table local returns = util.defaultTable(function () return vm.createNode() end) - for _, doc in ipairs(func.bindDocs) do - if doc.type == 'doc.return' then - for _, ret in ipairs(doc.returns) do - returns[ret.returnIndex]:merge(vm.compileNode(ret)) + if func.bindDocs then + for _, doc in ipairs(func.bindDocs) do + if doc.type == 'doc.return' then + for _, ret in ipairs(doc.returns) do + returns[ret.returnIndex]:merge(vm.compileNode(ret)) + end + end + if doc.type == 'doc.overload' then + for i, ret in ipairs(doc.overload.returns) do + returns[i]:merge(vm.compileNode(ret)) + end end end - if doc.type == 'doc.overload' then - for i, ret in ipairs(doc.overload.returns) do + end + for nd in vm.compileNode(func):eachObject() do + if nd.type == 'doc.type.function' then + ---@cast nd parser.object + for i, ret in ipairs(nd.returns) do returns[i]:merge(vm.compileNode(ret)) end end diff --git a/script/parser/compile.lua b/script/parser/compile.lua index 8b8a7770..b8040382 100644 --- a/script/parser/compile.lua +++ b/script/parser/compile.lua @@ -2289,16 +2289,15 @@ local function parseFunction(isLocal, isAction) end end if hasLeftParen then + params = params or {} local parenLeft = getPosition(Tokens[Index], 'left') Index = Index + 2 params = parseParams(params) - if params then - params.type = 'funcargs' - params.start = parenLeft - params.finish = lastRightPosition() - params.parent = func - func.args = params - end + params.type = 'funcargs' + params.start = parenLeft + params.finish = lastRightPosition() + params.parent = func + func.args = params skipSpace(true) if Tokens[Index + 1] == ')' then local parenRight = getPosition(Tokens[Index], 'right') diff --git a/script/vm/function.lua b/script/vm/function.lua index 5ad12acf..b4466668 100644 --- a/script/vm/function.lua +++ b/script/vm/function.lua @@ -63,6 +63,54 @@ function vm.countParamsOfFunction(func) return min, max, def end +---@param source parser.object +---@return integer min +---@return number max +---@return integer def +function vm.countParamsOfSource(source) + local min = 0 + local max = 0 + local def = 0 + local overloads = {} + if source.bindDocs then + for _, doc in ipairs(source.bindDocs) do + if doc.type == 'doc.overload' then + overloads[doc.overload] = true + end + end + end + local hasDocFunction + for nd in vm.compileNode(source):eachObject() do + if nd.type == 'doc.type.function' and not overloads[nd] then + hasDocFunction = true + ---@cast nd parser.object + local dmin, dmax, ddef = vm.countParamsOfFunction(nd) + if dmin > min then + min = dmin + end + if dmax > max then + max = dmax + end + if ddef > def then + def = ddef + end + end + end + if not hasDocFunction then + local dmin, dmax, ddef = vm.countParamsOfFunction(source) + if dmin > min then + min = dmin + end + if dmax > max then + max = dmax + end + if ddef > def then + def = ddef + end + end + return min, max, def +end + ---@param node vm.node ---@return integer min ---@return number max @@ -136,12 +184,12 @@ function vm.countReturnsOfFunction(func, onlyDoc, mark) end if not onlyDoc and not hasDocReturn and func.returns then for _, ret in ipairs(func.returns) do - local rmin, rmax, ddef = vm.countList(ret, mark) - if not min or rmin < min then - min = rmin + local dmin, dmax, ddef = vm.countList(ret, mark) + if not min or dmin < min then + min = dmin end - if not max or rmax > max then - max = rmax + if not max or dmax > max then + max = dmax end if not def or ddef > def then def = ddef @@ -156,6 +204,62 @@ function vm.countReturnsOfFunction(func, onlyDoc, mark) error('not a function') end +---@param source parser.object +---@return integer min +---@return number max +---@return integer def +function vm.countReturnsOfSource(source) + local overloads = {} + local hasDocFunction + local min, max, def + if source.bindDocs then + for _, doc in ipairs(source.bindDocs) do + if doc.type == 'doc.overload' then + overloads[doc.overload] = true + local dmin, dmax, ddef = vm.countReturnsOfFunction(doc.overload) + if not min or dmin < min then + min = dmin + end + if not max or dmax > max then + max = dmax + end + if not def or ddef > def then + def = ddef + end + end + end + end + for nd in vm.compileNode(source):eachObject() do + if nd.type == 'doc.type.function' and not overloads[nd] then + ---@cast nd parser.object + hasDocFunction = true + local dmin, dmax, ddef = vm.countReturnsOfFunction(nd) + if not min or dmin < min then + min = dmin + end + if not max or dmax > max then + max = dmax + end + if not def or ddef > def then + def = ddef + end + end + end + if not hasDocFunction then + local dmin, dmax, ddef = vm.countReturnsOfFunction(source, true) + if not min or dmin < min then + min = dmin + end + if not max or dmax > max then + max = dmax + end + if not def or ddef > def then + def = ddef + end + end + return min, max, def +end + ---@param func parser.object ---@param mark? table ---@return integer min @@ -254,7 +358,7 @@ function vm.isVarargFunctionWithOverloads(func) if not func.args then return false end - if func.args[1].type ~= '...' then + if not func.args[1] or func.args[1].type ~= '...' then return false end if not func.bindDocs then diff --git a/test.lua b/test.lua index 2a8d97c7..e255841f 100644 --- a/test.lua +++ b/test.lua @@ -56,8 +56,8 @@ local function testAll() test 'references' test 'hover' test 'completion' - test 'crossfile' test 'diagnostics' + test 'crossfile' test 'highlight' test 'rename' test 'signature' diff --git a/test/diagnostics/type-check.lua b/test/diagnostics/type-check.lua index 81cb1050..a8525ab0 100644 --- a/test/diagnostics/type-check.lua +++ b/test/diagnostics/type-check.lua @@ -3,7 +3,7 @@ local config = require 'config' config.add(nil, 'Lua.diagnostics.disable', 'unused-local') config.add(nil, 'Lua.diagnostics.disable', 'unused-function') config.add(nil, 'Lua.diagnostics.disable', 'undefined-global') -config.add(nil, 'Lua.diagnostics.disable', 'missing-return') +config.add(nil, 'Lua.diagnostics.disable', 'redundant-return') config.set(nil, 'Lua.type.castNumberToInteger', false) TEST [[ @@ -522,7 +522,9 @@ end TEST [[ ---@return number, number -local function f() end +local function f() + return 1, 1 +end ---@return number, boolean function F() @@ -532,7 +534,9 @@ end TEST [[ ---@return boolean, number -local function f() end +local function f() + return true, 1 +end ---@return number, boolean function F() @@ -542,7 +546,9 @@ end TEST [[ ---@return boolean, number? -local function f() end +local function f() + return true, 1 +end ---@return number, boolean function F() @@ -552,7 +558,9 @@ end TEST [[ ---@return number, number? -local function f() end +local function f() + return 1, 1 +end ---@return number, boolean, number function F() @@ -969,8 +977,47 @@ local x local t = { true, false, x } ]] +TEST [[ +---@type fun():number +local function f() +end +]] + +TEST [[ +---@type fun():number +local function f() + +end +]] + +TEST [[ +---@type fun():number +local function f() + return 1, +end +]] + +TEST [[ +---@type fun():number +local function f() + return +end +]] + +TEST [[ +---@type fun(x: number) +local function f +end +]] + +TEST [[ +---@type fun(x: number) +local function f(x, ) +end +]] + config.remove(nil, 'Lua.diagnostics.disable', 'unused-local') config.remove(nil, 'Lua.diagnostics.disable', 'unused-function') config.remove(nil, 'Lua.diagnostics.disable', 'undefined-global') -config.remove(nil, 'Lua.diagnostics.disable', 'missing-return') +config.remove(nil, 'Lua.diagnostics.disable', 'redundant-return') config.set(nil, 'Lua.type.castNumberToInteger', true) -- cgit v1.2.3