summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--changelog.md1
-rw-r--r--locale/zh-cn/script.lua2
-rw-r--r--script/core/diagnostics/return-type-mismatch.lua70
-rw-r--r--script/parser/guide.lua8
-rw-r--r--script/proto/diagnostic.lua1
-rw-r--r--script/vm/compiler.lua35
-rw-r--r--test/diagnostics/type-check.lua61
7 files changed, 161 insertions, 17 deletions
diff --git a/changelog.md b/changelog.md
index 914d17d6..0d25a274 100644
--- a/changelog.md
+++ b/changelog.md
@@ -10,6 +10,7 @@
* `missing-return-value`
* `redundant-return-value`
* `missing-return`
+ * `return-type-mismatch`
* `NEW` settings:
* `diagnostics.groupSeverity`
* `diagnostics.groupFileStatus`
diff --git a/locale/zh-cn/script.lua b/locale/zh-cn/script.lua
index 4c468cea..aa0aea29 100644
--- a/locale/zh-cn/script.lua
+++ b/locale/zh-cn/script.lua
@@ -138,6 +138,8 @@ DIAG_REDUNDANT_RETURN_VALUE_RANGE =
'最多只有 {max} 个返回值,但此处返回了第 {rmin} 到第 {rmax} 个值。'
DIAG_MISSING_RETURN =
'此处需要返回值。'
+DIAG_RETURN_TYPE_MISMATCH =
+'第 {index} 个返回值的类型为 `{def}` ,但实际返回的是 `{ref}`。'
MWS_NOT_SUPPORT =
'{} 目前还不支持多工作目录,我可能需要重启才能支持新的工作目录...'
diff --git a/script/core/diagnostics/return-type-mismatch.lua b/script/core/diagnostics/return-type-mismatch.lua
new file mode 100644
index 00000000..ba23fa2c
--- /dev/null
+++ b/script/core/diagnostics/return-type-mismatch.lua
@@ -0,0 +1,70 @@
+local files = require 'files'
+local lang = require 'language'
+local guide = require 'parser.guide'
+local vm = require 'vm'
+local await = require 'await'
+
+---@param func parser.object
+---@return vm.node[]?
+local function getDocReturns(func)
+ if not func.bindDocs then
+ return nil
+ end
+ local returns = {}
+ for _, doc in ipairs(func.bindDocs) do
+ if doc.type == 'doc.return' then
+ for _, ret in ipairs(doc.returns) do
+ returns[ret.returnIndex] = vm.compileNode(ret)
+ end
+ end
+ end
+ if #returns == 0 then
+ return nil
+ end
+ return returns
+end
+---@async
+return function (uri, callback)
+ local state = files.getState(uri)
+ if not state then
+ return
+ end
+
+ ---@param docReturns vm.node[]
+ ---@param rets parser.object
+ local function checkReturn(docReturns, rets)
+ for i, docRet in ipairs(docReturns) do
+ local retNode, exp = vm.selectNode(rets, i)
+ if not exp then
+ break
+ end
+ if not vm.canCastType(uri, docRet, retNode) then
+ callback {
+ start = exp.start,
+ finish = exp.finish,
+ message = lang.script('DIAG_RETURN_TYPE_MISMATCH', {
+ def = vm.getInfer(docRet):view(uri),
+ ref = vm.getInfer(retNode):view(uri),
+ index = i,
+ }),
+ }
+ end
+ end
+ end
+
+ ---@async
+ guide.eachSourceType(state.ast, 'function', function (source)
+ if not source.returns then
+ return
+ end
+ local docReturns = getDocReturns(source)
+ if not docReturns then
+ return
+ end
+ await.delay()
+ for _, ret in ipairs(source.returns) do
+ checkReturn(docReturns, ret)
+ await.delay()
+ end
+ end)
+end
diff --git a/script/parser/guide.lua b/script/parser/guide.lua
index 83c84964..8de08ef9 100644
--- a/script/parser/guide.lua
+++ b/script/parser/guide.lua
@@ -1090,12 +1090,12 @@ end
---@param a table
---@param b table
---@return string|false mode
----@return table pathA?
----@return table pathB?
+---@return table? pathA
+---@return table? pathB
function m.getPath(a, b, sameFunction)
--- 首先测试双方在同一个函数内
if sameFunction and m.getParentFunction(a) ~= m.getParentFunction(b) then
- return false, nil, nil
+ return false
end
local mode
local objA
@@ -1139,7 +1139,7 @@ function m.getPath(a, b, sameFunction)
end
end
if not start then
- return false, nil, nil
+ return false
end
-- pathA: { 1, 2, 3}
-- pathB: {5, 6, 2, 3}
diff --git a/script/proto/diagnostic.lua b/script/proto/diagnostic.lua
index 43043b94..1065950d 100644
--- a/script/proto/diagnostic.lua
+++ b/script/proto/diagnostic.lua
@@ -74,6 +74,7 @@ m.register {
'assign-type-mismatch',
'param-type-mismatch',
'cast-type-mismatch',
+ 'return-type-mismatch',
} {
group = 'type-check',
severity = 'Warning',
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index f42a4768..e7af1f0f 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -469,9 +469,9 @@ function vm.getReturnOfFunction(func, index)
end
if not func._returns[index] then
func._returns[index] = {
- type = 'function.return',
- parent = func,
- index = index,
+ type = 'function.return',
+ parent = func,
+ returnIndex = index,
}
end
return func._returns[index]
@@ -736,14 +736,15 @@ function vm.compileByParentNode(source, key, ref, pushResult)
end
end
----@return vm.node?
-local function selectNode(source, list, index)
- if not list then
- return nil
- end
+---@param list parser.object[]
+---@param index integer
+---@return vm.node
+---@return parser.object?
+function vm.selectNode(list, index)
local exp
if list[index] then
exp = list[index]
+ index = 1
else
for i = index, 1, -1 do
if list[i] then
@@ -758,16 +759,14 @@ local function selectNode(source, list, index)
end
end
if not exp then
- vm.setNode(source, vm.declareGlobal('type', 'nil'))
- return vm.getNode(source)
+ return vm.createNode(vm.declareGlobal('type', 'nil')), nil
end
---@type vm.node?
local result
if exp.type == 'call' then
result = getReturn(exp.node, index, exp.args)
if not result then
- vm.setNode(source, vm.declareGlobal('type', 'unknown'))
- return vm.getNode(source)
+ return vm.createNode(vm.declareGlobal('type', 'unknown')), exp
end
else
---@type vm.node
@@ -776,6 +775,15 @@ local function selectNode(source, list, index)
result:merge(vm.declareGlobal('type', 'unknown'))
end
end
+ return result, exp
+end
+
+---@param source parser.object
+---@param list parser.object[]
+---@param index integer
+---@return vm.node
+local function selectNode(source, list, index)
+ local result = vm.selectNode(list, index)
if source.type == 'function.return' then
-- remove any for returns
local rtnNode = vm.createNode()
@@ -1513,9 +1521,10 @@ local compilerSwitch = util.switch()
vm.setNode(source, vm.compileNode(source.value))
end)
: case 'function.return'
+ ---@param source parser.object
: call(function (source)
local func = source.parent
- local index = source.index
+ local index = source.returnIndex
local hasMarkDoc
if func.bindDocs then
local sign = getObjectSign(func)
diff --git a/test/diagnostics/type-check.lua b/test/diagnostics/type-check.lua
index 34f7a492..d586c5de 100644
--- a/test/diagnostics/type-check.lua
+++ b/test/diagnostics/type-check.lua
@@ -496,5 +496,66 @@ TEST [[
local <!x!> = 'aaa'
]]
+TEST [[
+---@return number
+function F()
+ return <!true!>
+end
+]]
+
+TEST [[
+---@return number?
+function F()
+ return 1
+end
+]]
+
+TEST [[
+---@return number?
+function F()
+ return nil
+end
+]]
+
+TEST [[
+---@return number, number
+local function f() end
+
+---@return number, boolean
+function F()
+ return <!f()!>
+end
+]]
+
+TEST [[
+---@return boolean, number
+local function f() end
+
+---@return number, boolean
+function F()
+ return <!f()!>
+end
+]]
+
+TEST [[
+---@return boolean, number?
+local function f() end
+
+---@return number, boolean
+function F()
+ return 1, f()
+end
+]]
+
+TEST [[
+---@return number, number?
+local function f() end
+
+---@return number, boolean, number
+function F()
+ return 1, <!f()!>
+end
+]]
+
config.remove(nil, 'Lua.diagnostics.disable', 'unused-local')
config.remove(nil, 'Lua.diagnostics.disable', 'undefined-global')