diff options
-rw-r--r-- | changelog.md | 1 | ||||
-rw-r--r-- | locale/zh-cn/script.lua | 2 | ||||
-rw-r--r-- | script/core/diagnostics/return-type-mismatch.lua | 70 | ||||
-rw-r--r-- | script/parser/guide.lua | 8 | ||||
-rw-r--r-- | script/proto/diagnostic.lua | 1 | ||||
-rw-r--r-- | script/vm/compiler.lua | 35 | ||||
-rw-r--r-- | test/diagnostics/type-check.lua | 61 |
7 files changed, 161 insertions, 17 deletions
diff --git a/changelog.md b/changelog.md index 914d17d6..0d25a274 100644 --- a/changelog.md +++ b/changelog.md @@ -10,6 +10,7 @@ * `missing-return-value` * `redundant-return-value` * `missing-return` + * `return-type-mismatch` * `NEW` settings: * `diagnostics.groupSeverity` * `diagnostics.groupFileStatus` diff --git a/locale/zh-cn/script.lua b/locale/zh-cn/script.lua index 4c468cea..aa0aea29 100644 --- a/locale/zh-cn/script.lua +++ b/locale/zh-cn/script.lua @@ -138,6 +138,8 @@ DIAG_REDUNDANT_RETURN_VALUE_RANGE = '最多只有 {max} 个返回值,但此处返回了第 {rmin} 到第 {rmax} 个值。' DIAG_MISSING_RETURN = '此处需要返回值。' +DIAG_RETURN_TYPE_MISMATCH = +'第 {index} 个返回值的类型为 `{def}` ,但实际返回的是 `{ref}`。' MWS_NOT_SUPPORT = '{} 目前还不支持多工作目录,我可能需要重启才能支持新的工作目录...' diff --git a/script/core/diagnostics/return-type-mismatch.lua b/script/core/diagnostics/return-type-mismatch.lua new file mode 100644 index 00000000..ba23fa2c --- /dev/null +++ b/script/core/diagnostics/return-type-mismatch.lua @@ -0,0 +1,70 @@ +local files = require 'files' +local lang = require 'language' +local guide = require 'parser.guide' +local vm = require 'vm' +local await = require 'await' + +---@param func parser.object +---@return vm.node[]? +local function getDocReturns(func) + if not func.bindDocs then + return nil + end + local returns = {} + for _, doc in ipairs(func.bindDocs) do + if doc.type == 'doc.return' then + for _, ret in ipairs(doc.returns) do + returns[ret.returnIndex] = vm.compileNode(ret) + end + end + end + if #returns == 0 then + return nil + end + return returns +end +---@async +return function (uri, callback) + local state = files.getState(uri) + if not state then + return + end + + ---@param docReturns vm.node[] + ---@param rets parser.object + local function checkReturn(docReturns, rets) + for i, docRet in ipairs(docReturns) do + local retNode, exp = vm.selectNode(rets, i) + if not exp then + break + end + if not vm.canCastType(uri, docRet, retNode) then + callback { + start = exp.start, + finish = exp.finish, + message = lang.script('DIAG_RETURN_TYPE_MISMATCH', { + def = vm.getInfer(docRet):view(uri), + ref = vm.getInfer(retNode):view(uri), + index = i, + }), + } + end + end + end + + ---@async + guide.eachSourceType(state.ast, 'function', function (source) + if not source.returns then + return + end + local docReturns = getDocReturns(source) + if not docReturns then + return + end + await.delay() + for _, ret in ipairs(source.returns) do + checkReturn(docReturns, ret) + await.delay() + end + end) +end diff --git a/script/parser/guide.lua b/script/parser/guide.lua index 83c84964..8de08ef9 100644 --- a/script/parser/guide.lua +++ b/script/parser/guide.lua @@ -1090,12 +1090,12 @@ end ---@param a table ---@param b table ---@return string|false mode ----@return table pathA? ----@return table pathB? +---@return table? pathA +---@return table? pathB function m.getPath(a, b, sameFunction) --- 首先测试双方在同一个函数内 if sameFunction and m.getParentFunction(a) ~= m.getParentFunction(b) then - return false, nil, nil + return false end local mode local objA @@ -1139,7 +1139,7 @@ function m.getPath(a, b, sameFunction) end end if not start then - return false, nil, nil + return false end -- pathA: { 1, 2, 3} -- pathB: {5, 6, 2, 3} diff --git a/script/proto/diagnostic.lua b/script/proto/diagnostic.lua index 43043b94..1065950d 100644 --- a/script/proto/diagnostic.lua +++ b/script/proto/diagnostic.lua @@ -74,6 +74,7 @@ m.register { 'assign-type-mismatch', 'param-type-mismatch', 'cast-type-mismatch', + 'return-type-mismatch', } { group = 'type-check', severity = 'Warning', diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index f42a4768..e7af1f0f 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -469,9 +469,9 @@ function vm.getReturnOfFunction(func, index) end if not func._returns[index] then func._returns[index] = { - type = 'function.return', - parent = func, - index = index, + type = 'function.return', + parent = func, + returnIndex = index, } end return func._returns[index] @@ -736,14 +736,15 @@ function vm.compileByParentNode(source, key, ref, pushResult) end end ----@return vm.node? -local function selectNode(source, list, index) - if not list then - return nil - end +---@param list parser.object[] +---@param index integer +---@return vm.node +---@return parser.object? +function vm.selectNode(list, index) local exp if list[index] then exp = list[index] + index = 1 else for i = index, 1, -1 do if list[i] then @@ -758,16 +759,14 @@ local function selectNode(source, list, index) end end if not exp then - vm.setNode(source, vm.declareGlobal('type', 'nil')) - return vm.getNode(source) + return vm.createNode(vm.declareGlobal('type', 'nil')), nil end ---@type vm.node? local result if exp.type == 'call' then result = getReturn(exp.node, index, exp.args) if not result then - vm.setNode(source, vm.declareGlobal('type', 'unknown')) - return vm.getNode(source) + return vm.createNode(vm.declareGlobal('type', 'unknown')), exp end else ---@type vm.node @@ -776,6 +775,15 @@ local function selectNode(source, list, index) result:merge(vm.declareGlobal('type', 'unknown')) end end + return result, exp +end + +---@param source parser.object +---@param list parser.object[] +---@param index integer +---@return vm.node +local function selectNode(source, list, index) + local result = vm.selectNode(list, index) if source.type == 'function.return' then -- remove any for returns local rtnNode = vm.createNode() @@ -1513,9 +1521,10 @@ local compilerSwitch = util.switch() vm.setNode(source, vm.compileNode(source.value)) end) : case 'function.return' + ---@param source parser.object : call(function (source) local func = source.parent - local index = source.index + local index = source.returnIndex local hasMarkDoc if func.bindDocs then local sign = getObjectSign(func) diff --git a/test/diagnostics/type-check.lua b/test/diagnostics/type-check.lua index 34f7a492..d586c5de 100644 --- a/test/diagnostics/type-check.lua +++ b/test/diagnostics/type-check.lua @@ -496,5 +496,66 @@ TEST [[ local <!x!> = 'aaa' ]] +TEST [[ +---@return number +function F() + return <!true!> +end +]] + +TEST [[ +---@return number? +function F() + return 1 +end +]] + +TEST [[ +---@return number? +function F() + return nil +end +]] + +TEST [[ +---@return number, number +local function f() end + +---@return number, boolean +function F() + return <!f()!> +end +]] + +TEST [[ +---@return boolean, number +local function f() end + +---@return number, boolean +function F() + return <!f()!> +end +]] + +TEST [[ +---@return boolean, number? +local function f() end + +---@return number, boolean +function F() + return 1, f() +end +]] + +TEST [[ +---@return number, number? +local function f() end + +---@return number, boolean, number +function F() + return 1, <!f()!> +end +]] + config.remove(nil, 'Lua.diagnostics.disable', 'unused-local') config.remove(nil, 'Lua.diagnostics.disable', 'undefined-global') |